[MRG] Add exact line-search for (f)gw solvers with kl_loss by cedricvincentcuaz · Pull Request #556 · PythonOT/POT

Skip to content

Navigation Menu

Sign in

Appearance settings

Conversation

@cedricvincentcuaz

Types of changes

  • Add exact line-search for ot.gromov.gromov_wasserstein and ot.gromov.fused_gromov_wasserstein when loss_fun = 'kl_loss' which can be called by setting armijo=False(default for the method). Note that these solvers ignored the provided armijo parameter with this loss enforcing armijo=True.
  • Extend ot.gromov.solve_gromov_linesearch to any loss with common decomposition.
  • Adapt the behavior of dependent solvers that also enforced armijo=True: ot.gromov.gromov_barycenters and ot.gromov.fgw_barycenters
  • Add gradients in this setting for ot.gromov.gromov_wasserstein2 and ot.gromov.fused_gromov_wasserstein2

Motivation and context / Related issue

How has this been tested (if it applies)

  • Extended existing test for ot.gromov.gromov_wasserstein and ot.gromov.fused_gromov_wasserstein with kl loss, looping over armijo values.

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

@codecov

Codecov Report

Merging #556 (e3f35cd) into master (53dde7a) will not change coverage.
The diff coverage is 100.00%.

Additional details and impacted files
@@           Coverage Diff           @@
##           master     #556   +/-   ##
=======================================
  Coverage   96.49%   96.49%           
=======================================
  Files          67       67           
  Lines       14663    14663           
=======================================
  Hits        14149    14149           
  Misses        514      514           

1 participant

@cedricvincentcuaz