Free support barycenters by vivienseguy · Pull Request #56 · PythonOT/POT

rflamary

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you @vivienseguy for all the work.

I have several comments that need to be addressed before merging (discussed more in detail below).

But most of all we need a test in the test_ot.py file that call your function and check stuff like the size of the output and reasonable solution.


##############################################################################
# Compute free support barycenter
# -------------

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

---- needs to have the proper length for good documentation generation.


def barycenter(A, M, weights=None, verbose=False, log=False, solver='interior-point'):
"""Compute the entropic regularized wasserstein barycenter of distributions A
"""Compute the Wasserstein barycenter of distributions A

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch !


Parameters
----------
data_positions : list of (k_i,d) np.ndarray

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

names in the documentation different from the code : data_positions vs measures_locations

Stop threshol on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

missing log parameter in the function.

would be nice to return the list of the displacement_square_norm along the iteration in a dictionnary if log=True (similar behavior as barycenter function above that retruns a log)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add log if log=True

import numpy as np
import scipy as sp
import scipy.sparse as sps
import ot

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you shouldn't import pot inside a module.

something with relative path like

from .__init__ import emd

is far better since it imports the emd function from the __init__.py

return b


def free_support_barycenter(measures_locations, measures_weights, X_init, b, weights=None, numItermax=100, stopThr=1e-6, verbose=False):

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also do b=None if the weights are supposed uniform (needs test an initialization in the function)

X_init = np.random.normal(0., 1., (k, d))
b = np.ones((k,)) / k

X = ot.lp.cvx.free_support_barycenter(measures_locations, measures_weights, X_init, b)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ot.lp.cvx.free_support_barycenter is very long.

you should import the function in ot.lp __init__.py and add it to __all__ like barycenter so that you can do ot.lp.free_support_barycenter