[MRG] Add set_gradients method for JAX backend. by AdrienCorenflos · Pull Request #278 · PythonOT/POT
Types of changes
- Docs change / refactoring / dependency upgrade
- Bug fix (non-breaking change which fixes an issue)
- New feature (non-breaking change which adds functionality)
- Breaking change (fix or feature that would cause existing functionality to change)
Motivation and context / Related issue
The set_gradient is possible in JAX.
How has this been tested (if it applies)
Added a modified unittest for JAX.
Checklist
- [ X ] The documentation is up-to-date with the changes I made.
- [ X ] I have read the CONTRIBUTING document.
- [ X ] All tests passed, and additional code has been covered with new tests.