[MRG] Add set_gradients method for JAX backend. by AdrienCorenflos · Pull Request #278 · PythonOT/POT

Types of changes

  • Docs change / refactoring / dependency upgrade
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to change)

Motivation and context / Related issue

The set_gradient is possible in JAX.

#277

How has this been tested (if it applies)

Added a modified unittest for JAX.

Checklist

  • [ X ] The documentation is up-to-date with the changes I made.
  • [ X ] I have read the CONTRIBUTING document.
  • [ X ] All tests passed, and additional code has been covered with new tests.