Gromov-Wasserstein distance by ncourty · Pull Request #23 · PythonOT/POT
Hi everyone,
This is a new implementation of the Gromov-Wasserstein distance, mostly programmed by Erwan Vautier and myself. In the next commit, I will add a new example on how to compute barycenters and also tests for this new functionality.
| * Joint OT matrix and mapping estimation [8]. | ||
| * Wasserstein Discriminant Analysis [11] (requires autograd + pymanopt). | ||
|
|
||
| * Gromov-Wasserstein distances [12] |
and barycenters
|
|
||
| [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016). [Wasserstein Discriminant Analysis](https://arxiv.org/pdf/1608.08063.pdf). arXiv preprint arXiv:1608.08063. | ||
|
|
||
| [12] Peyré, Gabriel, Marco Cuturi, and Justin Solomon, [Gromov-Wasserstein averaging of kernel and distance matrices](http://proceedings.mlr.press/v48/peyre16.html) International Conference on Machine Learning (ICML). 2016. |
Gabriel Peyré to be consistent
| """ | ||
| ==================== | ||
| Gromov-Wasserstein example | ||
| ==================== |
not enough ===
| import numpy as np | ||
|
|
||
| import ot | ||
| import matplotlib.pylab as pl |
import pl before ot
|
|
||
| """ | ||
| Sample two Gaussian distributions (2D and 3D) | ||
| ==================== |
not enough ===
it won't render well in sphinx
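For reference, reStructuredText needs the overline and underline of a title to be at least as long as the title text. A minimal sketch of generating a matching frame; the helper name `rst_title` is hypothetical and not part of the PR:

```python
def rst_title(title, char="="):
    """Return `title` framed by an overline and underline of matching length."""
    bar = char * len(title)  # one marker character per title character
    return "\n".join([bar, title, bar])


header = rst_title("Gromov-Wasserstein example")
print(header)
```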
| For demonstration purpose, we sample two Gaussian distributions in 2- and 3-dimensional spaces. | ||
| """ | ||
|
|
||
| n = 30 # nb samples |
n -> n_samples
you won't need to write # nb samples :)
| Returns the value of L(a,b)=(1/2)*|a-b|^2 | ||
| """ | ||
|
|
||
| return (1 / 2) * (a - b)**2 |
1 / 2 will be 0 in Python 2
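To spell this out: in Python 2, `/` between two ints is floor division, so `(1 / 2)` evaluates to 0 and the loss would be zero everywhere. A minimal sketch of a version-independent form, assuming a float literal is acceptable here; the function name `square_loss` is illustrative and not necessarily the one used in the PR:

```python
def square_loss(a, b):
    """Return L(a, b) = (1/2) * |a - b|^2."""
    # 0.5 is a float literal, so no integer division is involved
    # on either Python 2 or Python 3.
    return 0.5 * (a - b) ** 2


value = square_loss(3.0, 1.0)
```

Adding `from __future__ import division` at the top of the module would be another way to get the same behaviour on both versions.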
| return b | ||
|
|
||
| tens = -np.dot(h1(C1), T).dot(h2(C2).T) | ||
| tens = tens - tens.min() |
tens -= tens.min()
|
|
||
| Parameters | ||
| ---------- | ||
| C1 : np.ndarray(ns,ns) |
C1 : ndarray, shape (ns, ns)
is the standard of numpydoc
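As an illustration of the numpydoc layout being asked for, here is a small sketch. The function `apply_cost` and its parameter descriptions are hypothetical; only the `name : ndarray, shape (...)` pattern comes from the comment above:

```python
import numpy as np


def apply_cost(C1, T):
    """Multiply a cost matrix by a coupling matrix.

    Parameters
    ----------
    C1 : ndarray, shape (ns, ns)
        Cost matrix in the source space.
    T : ndarray, shape (ns, nt)
        Coupling between the source and target spaces.

    Returns
    -------
    out : ndarray, shape (ns, nt)
        Matrix product of C1 and T.
    """
    return np.dot(C1, T)
```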
| cpt = 0 | ||
| err = 1 | ||
|
|
||
| while (err > stopThr and cpt < numItermax): |
Avoid while loops; use `for` with `break` instead. It's a much safer way to rule out infinite loops.
You can use the `for ... else` syntax to detect that no `break` occurred.
OK, for this one I will keep consistency with the rest of the optimization methods (especially those in the Bregman module).
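For readers unfamiliar with the `for ... else` pattern suggested above, a minimal sketch follows. The names `iterate_until_converged`, `step`, `max_iter` and `stop_thr` are assumptions made for the illustration; this is not the loop from the PR:

```python
def iterate_until_converged(step, max_iter=1000, stop_thr=1e-9):
    """Call `step` until the error it returns drops to `stop_thr` or below.

    `step` is a hypothetical callable that performs one update and
    returns the current error. Returns (n_iter, err, converged).
    """
    err = float("inf")
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        err = step()
        if err <= stop_thr:
            converged = True
            break
    else:
        # Runs only when the loop finished without hitting `break`,
        # i.e. the iteration budget was exhausted.
        converged = False
    return n_iter, err, converged


errors = iter([1.0, 0.1, 1e-12])
result = iterate_until_converged(lambda: next(errors))
```

The loop is bounded by construction, and the `else` branch gives a natural place to warn about non-convergence.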
| """ | ||
|
|
||
|
|
||
| def smacof_mds(C, dim, maxIter=3000, eps=1e-9): |
maxIter -> max_iter
|
|
||
| Parameters | ||
| ---------- | ||
| C : np.ndarray(ns,ns) |
C : ndarray, shape (ns, ns)
| ---------- | ||
| C : np.ndarray(ns,ns) | ||
| dissimilarity matrix | ||
| dim : Integer |
Integer -> int
| dissimilarity matrix | ||
| dim : Integer | ||
| dimension of the targeted space | ||
| maxIter : Maximum number of iterations of the SMACOF algorithm for a single run |
max_iter : int
Maximum number of iterations of the SMACOF algorithm for a single run
| Ct01 = [0 for i in range(2)] | ||
| for i in range(2): | ||
| Ct01[i] = ot.gromov.gromov_barycenters(N, [Cs[0], Cs[1]], [ | ||
| ps[0], ps[1]], p, lambdast[i], 'square_loss', 5e-4, numItermax=100, stopThr=1e-3) |
numItermax -> max_iter?
| triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256 | ||
| fleche = spi.imread('../data/coeur.png').astype(np.float64) / 256 | ||
|
|
||
| shapes = [carre, rond, triangle, fleche] |
I guess you meant: square, circle, triangle and arrow :)
@ncourty please go over the full diff for docstrings and naming. If you're OK with me bugging you :) I'll do one more pass once you've done it.
| 'It.', 'Err') + '\n' + '-' * 19) | ||
| print('{:5d}|{:8e}|'.format(cpt, err)) | ||
|
|
||
| cpt = cpt + 1 |
cpt += 1
Nicolas Courty added 2 commits
September 1, 2017 15:37
| square = spi.imread('../data/carre.png').astype(np.float64) / 256 | ||
| circle = spi.imread('../data/rond.png').astype(np.float64) / 256 | ||
| triangle = spi.imread('../data/triangle.png').astype(np.float64) / 256 | ||
| arrow = spi.imread('../data/coeur.png').astype(np.float64) / 256 |
Can you maybe rename the png files? Also, I see arrow = coeur (heart). Is this a bug?
|
|
||
| xs = ot.datasets.get_2D_samples_gauss(n_samples, mu_s, cov_s) | ||
|
|
||
| xt = xs[::-1] |
I would have written:
xt = xs[::-1].copy()
and removed the array below
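Some context on why `.copy()` matters here: `xs[::-1]` is a reversed view that shares memory with `xs`, so later in-place changes to one array show up in the other. A small sketch with made-up data (the array contents are illustrative only):

```python
import numpy as np

xs = np.arange(6, dtype=float).reshape(3, 2)

view = xs[::-1]            # reversed view, shares memory with xs
copied = xs[::-1].copy()   # independent array with its own memory

xs[0, 0] = 100.0           # modify the original in place

view_shares = np.shares_memory(xs, view)      # True: same buffer
copy_shares = np.shares_memory(xs, copied)    # False: separate buffer
```

After the in-place write, `view[2, 0]` reflects the new value while `copied[2, 0]` keeps the old one.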
| npos : ndarray, shape (R, dim) | ||
| Embedded coordinates of the interpolated point cloud (defined with one isometry) | ||
|
|
||
|
|
remove unnecessary empty lines here and one before Returns
| """ | ||
| Sample two Gaussian distributions (2D and 3D) | ||
| ============================================= | ||
| The Gromov-Wasserstein distance allows to compute distances with samples that do not belong to the same metric space. |
line too long
| tens : ndarray, shape (ns, nt) | ||
| \mathcal{L}(C1,C2) \otimes T tensor-matrix multiplication result | ||
|
|
||
|
|
remove empty lines
| ===================================== | ||
| Gromov-Wasserstein Barycenter example | ||
| ===================================== | ||
| This example is designed to show how to use the Gromov-Wassertsein distance |
Wassertsein -> Wasserstein
|
|
||
| def smacof_mds(C, dim, max_iter=3000, eps=1e-9): | ||
| """ | ||
| Returns an interpolated point cloud following the dissimilarity matrix C using SMACOF |
| Embedded coordinates of the interpolated point cloud (defined with one isometry) | ||
| """ | ||
|
|
||
| rng = np.random.RandomState(seed=3) |
You should expose the random_state and use check_random_state like sklearn does.
Wait, it's an example, so maybe it's not necessary...
| ---------- | ||
| p : ndarray, shape (N,) | ||
| weights in the targeted barycenter | ||
| lambdas : list of the S spaces' weights |
bad format
| sample weights in the S spaces | ||
| p : ndarray, shape(N,) | ||
| weights in the targeted barycenter | ||
| lambdas : list of the S spaces' weights |
bad format
| lambdas = np.asarray(lambdas, dtype=np.float64) | ||
|
|
||
| # Initialization of C : random SPD matrix | ||
| xalea = np.random.randn(N, 2) |
expose random_state to make results deterministic if one wants
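A sketch of what exposing `random_state` could look like. scikit-learn ships `sklearn.utils.check_random_state` for this purpose; to keep the snippet free of that dependency, the version below is a simplified stand-in (for `None` it returns a fresh unseeded generator, which differs from sklearn's behaviour). The function `init_points` and its parameters are hypothetical, not the PR's code:

```python
import numbers

import numpy as np


def check_random_state(seed):
    """Turn `seed` into a np.random.RandomState (simplified stand-in)."""
    if seed is None:
        return np.random.RandomState()      # fresh, unseeded generator
    if isinstance(seed, numbers.Integral):
        return np.random.RandomState(seed)  # reproducible generator
    if isinstance(seed, np.random.RandomState):
        return seed                         # caller-supplied generator
    raise ValueError("%r cannot be used to seed a RandomState" % (seed,))


def init_points(n_points, dim=2, random_state=None):
    """Draw a Gaussian point cloud, deterministic when a seed is given."""
    rng = check_random_state(random_state)
    return rng.randn(n_points, dim)


a = init_points(4, random_state=3)
b = init_points(4, random_state=3)
same = np.array_equal(a, b)  # identical draws for identical seeds
```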
Thanks for the careful reading @agramfort. And congrats on your NIPS paper :) See you in LA?
Hello @ncourty ,
I think we should merge shortly since the review has converged. Could you please update from master?
This looks good to me: you have taken all comments into account, and the contribution is a very nice addition to the toolbox.
I think we can merge.