[MRG] Extend unbalanced OT solvers for more flexibility by 6Ulm · Pull Request #539 · PythonOT/POT
Types of changes
- Modify all unbalanced solvers in
unbalanced.py(except for barycenter), so that the argumentreg_mcan take either a scalar, or an indexable object of length 1 or 2. - Edit
mm_unbalancedmethod so that it can be used for both unregularized and regularized problems (previously supported unregularized only). Also, fix an implementation error of matrix $K$ in the $L_2$ case. - Fix the KL formula and the corresponding gradient for L-BFGS-B solver.
- Add corresponding tests to
test_unbalanced.py.
Motivation and context / Related issue
- The current solvers for unbalanced OT have some limitations, which may restrict their usage in practice.
- The
sinkhorn_unbalancedandlbfgsb_unbalancedmethods only allow forreg_mto be a scalar, thus impose the same penalization on both marginals. - The
mm_unbalancedmethod has the same limitation and restrict only to the unregularized problem.
- The
_get_loss_unbalancedmethod uses incorrect formula of KL divergence, since we are working with unnormalized measures. As a result, the calculation of the gradient is also incorrect.
This PR fixes all of these issues. Moreover, as a byproduct, the new version of sinkhorn_unbalanced method also allows to solve the entropic balanced and semi-relaxed problems.
How has this been tested (if it applies)
Tested when reg_m is a scalar, indexable objects of length 1 and 2 (tuple, list, array/tensor).
PR checklist
- I have read the CONTRIBUTING document.
- The documentation is up-to-date with the changes I made (check build artifacts).
- All tests passed, and additional code has been covered with new tests.
- I have added the PR and Issue fix to the RELEASES.md file.