[MRG] Restructure and Augment Partial Gromov-Wasserstein solvers by cedricvincentcuaz · Pull Request #663 · PythonOT/POT
Types of changes
-
create duplicates of partial GW related functions from
ot/partial.pytoot/gromov/_partial.pywhile adding a depreciation notes and warnings withinot/partial.pyto maintain the API for now. This includes the following functions:partial_gromov_wasserstein,partial_gromov_wasserstein2,entropic_partial_gromov_wasserstein,entropic_partial_gromov_wasserstein2. -
Add generic (hidden function)
_transform_matrixreturning transformed structure matrices for GW inot/gromov/_utils.py(used to compute only once these transformations in pGW even though there is not constant terms a priori as with the GW problem) and use it to factor code ininit_matrixandsemirelaxed_init_matrix.
Then adapt solvers to mimic other GW related solvers in order to ease future integration of new solvers (e.g pFGW ones):
- Make marginals
pandqoptional and set by default to uniform ones. - Add
loss_funparameter within ['square_loss', 'kl_loss'] -> now solvers supportkl_losstoo. - Add
symmetryparameter within [None, True, False] -> now solvers support asymmetric structure matrices. - restructure implementation of solvers :
- gradient and loss computations now only rely on the generic corresponding functions for the gw loss (
gwloss,gwggrad), therefore functionsgwgrad_partialandgwloss_partialinot/partial.pywill also be depreciated (kept for now with a de. - Implement generic partial CG solver
partial_cginot/optim.py. - Improve docs for functions in
ot/optim.py. - Call to
solve_partial_gromov_linesearchfor the exact line-search of pGW. Note that the latter requires the computation of the gradient of the regularizer atG(like other cg functions) but also at the cg directionGc. This allows to deduce the next step gradient as a convex combinations of these previous gradients (should reduce the computation time of about 33 %), within thegeneric_cgsolver. This trick will be implemented for all (F)GW based regularizer in a next PR.
- gradient and loss computations now only rely on the generic corresponding functions for the gw loss (
Motivation and context / Related issue
How has this been tested (if it applies)
PR checklist
- I have read the CONTRIBUTING document.
- The documentation is up-to-date with the changes I made (check build artifacts).
- All tests passed, and additional code has been covered with new tests.
- I have added the PR and Issue fix to the RELEASES.md file.