[MRG] Extend unbalanced OT solvers for more flexibility by 6Ulm · Pull Request #539 · PythonOT/POT

Types of changes

  • Modify all unbalanced solvers in unbalanced.py (except for barycenter), so that the argument reg_m can take either a scalar, or an indexable object of length 1 or 2.
  • Edit mm_unbalanced method so that it can be used for both unregularized and regularized problems (previously supported unregularized only). Also, fix an implementation error of matrix $K$ in the $L_2$ case.
  • Fix the KL formula and the corresponding gradient for L-BFGS-B solver.
  • Add corresponding tests to test_unbalanced.py.

Motivation and context / Related issue

  1. The current solvers for unbalanced OT have some limitations, which may restrict their usage in practice.
  • The sinkhorn_unbalanced and lbfgsb_unbalanced methods only allow for reg_m to be a scalar, thus impose the same penalization on both marginals.
  • The mm_unbalanced method has the same limitation and restrict only to the unregularized problem.
  1. The _get_loss_unbalanced method uses incorrect formula of KL divergence, since we are working with unnormalized measures. As a result, the calculation of the gradient is also incorrect.

This PR fixes all of these issues. Moreover, as a byproduct, the new version of sinkhorn_unbalanced method also allows to solve the entropic balanced and semi-relaxed problems.

How has this been tested (if it applies)

Tested when reg_m is a scalar, indexable objects of length 1 and 2 (tuple, list, array/tensor).

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.