[MRG] Add implicit Sinkhorn gradients by rflamary · Pull Request #605 · PythonOT/POT
Types of changes
This PR aims at
- implementing the `detach` function in the backend to allow speedups on CPU/GPU in some solvers (already done in a previous PR, but with limited documentation).
- implementing variants of Sinkhorn where the iterations are detached from the computational graph and the gradient at convergence is returned instead.
This PR should solve #565 and greatly reduce memory usage for Sinkhorn when computing gradients w.r.t. the value.
In order to use implicit differentiation, one needs to set the `grad` parameter in `ot.solve` and `ot.solve_sample` as follows:

```python
sol = ot.solve(M, a, b, reg=10, grad='implicit')
sol.value.backward()
# beware: with grad='implicit', sol.value_linear and sol.plan
# are not differentiable (not implemented yet)
```
On a simple example with PyTorch arrays requiring gradients, I measured a ~1000x reduction in memory when solving problems that need a large number of Sinkhorn iterations.
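The memory gain comes from the envelope theorem: at a fixed point of the Sinkhorn iterations, the gradient of the regularized OT value w.r.t. the cost matrix is simply the optimal plan, so no backpropagation through the iterations is needed. Below is a pure-NumPy sketch of this idea (the helper names `sinkhorn` and `reg_ot_value` are hypothetical, not POT's internal code), checked against finite differences:

```python
import numpy as np

def sinkhorn(M, a, b, reg, n_iter=1000):
    # plain Sinkhorn iterations; no computational graph is needed,
    # which is exactly what "detached" computations exploit
    K = np.exp(-M / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    return u[:, None] * K * v[None, :]

def reg_ot_value(M, a, b, reg):
    # entropic OT value: <G, M> + reg * <G, log G - 1>
    G = sinkhorn(M, a, b, reg)
    return np.sum(G * M) + reg * np.sum(G * (np.log(G) - 1))

rng = np.random.default_rng(0)
n = 5
a = np.full(n, 1 / n)
b = np.full(n, 1 / n)
M = rng.random((n, n))
reg = 1.0

# envelope theorem: d(value)/dM equals the optimal plan itself
G = sinkhorn(M, a, b, reg)

# check one entry of the implicit gradient against central finite differences
eps = 1e-6
E = np.zeros_like(M)
E[1, 2] = eps
fd = (reg_ot_value(M + E, a, b, reg) - reg_ot_value(M - E, a, b, reg)) / (2 * eps)
assert abs(fd - G[1, 2]) < 1e-4
```

Because the plan is already available at convergence, the gradient costs no extra memory, whereas unrolling autodiff through every Sinkhorn iteration stores all intermediate scalings.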
Motivation and context / Related issue
How has this been tested (if it applies)
PR checklist
- I have read the CONTRIBUTING document.
- The documentation is up-to-date with the changes I made (check build artifacts).
- All tests passed, and additional code has been covered with new tests.
- I have added the PR and Issue fix to the RELEASES.md file.