Inconsistent handling of `log` in `entropic_gromov_barycenters` and `gromov_barycenters` when the inner GW solvers are called with log information
Describe the bug
To Reproduce
Steps to reproduce the behavior:
- Calculate the barycenter with the optional argument `log=True`. With `log=True`, `gromov_wasserstein` returns an additional log dictionary, just like `entropic_gromov_wasserstein`.
Screenshots
Code sample
```python
import networkx as nx
import numpy as np
from scipy.sparse.csgraph import shortest_path
from ot.gromov import gromov_barycenters

Gs = [nx.cycle_graph(4)]
Ds = [shortest_path(nx.adjacency_matrix(g)) for g in Gs]
ps = [np.ones(4) / 4]
lambdas = np.ones(len(Gs)) / len(Gs)
N = 4
p = np.ones(N) / N
C = gromov_barycenters(N, Ds, ps, p, lambdas, "square_loss", log=True)
```
Expected behavior
- The internal log information from the inner solver calls is not needed in the output of `gromov_barycenters`.
- Only the error computed in each iteration of the `while` loop needs to be recorded.
- Return `C` if `log=False`; return `C, {"err": [...]}` if `log=True`.
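The requested return convention can be sketched as follows. This is a minimal, hypothetical illustration, not POT's code: `barycenter_iterations` and the `update` callable are made-up names, where `update` stands in for one barycenter step (solving the inner GW problems without their own logs, then updating `C`). A toy contraction replaces the real update so the sketch runs without POT.

```python
import numpy as np


def barycenter_iterations(update, C0, max_iter=1000, tol=1e-9, log=False):
    """Hypothetical sketch of the requested `log` behaviour.

    `update` stands in for one barycenter step. Only the error of each
    iteration is recorded; no inner solver logs are kept.
    """
    C = C0
    err = np.inf
    errors = []
    cpt = 0
    while err > tol and cpt < max_iter:
        C_prev = C
        C = update(C)
        err = np.linalg.norm(C - C_prev)
        errors.append(err)
        cpt += 1

    if log:
        # Return only the error history, not the inner solvers' logs.
        return C, {"err": errors}
    return C


# Toy update (assumption): a contraction towards a fixed target matrix.
target = np.array([[0.0, 1.0], [1.0, 0.0]])


def step(C):
    return 0.5 * C + 0.5 * target


C = barycenter_iterations(step, np.zeros((2, 2)))
C_log, log_dict = barycenter_iterations(step, np.zeros((2, 2)), log=True)
```

With `log=False` the caller gets the matrix alone; with `log=True` it gets a `(C, {"err": [...]})` pair, which matches the expected behavior listed above.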
Environment (please complete the following information):
- OS (e.g. MacOS, Windows, Linux):
- Python version:
- How was POT installed (source, `pip`, `conda`):
- Build command you used (if compiling from source):
- Only for GPU related bugs:
- CUDA version:
- GPU models and configuration:
- Any other relevant information:
Output of the following code snippet:
```python
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
```

```
# output:
Linux-5.4.0-70-generic-x86_64-with-debian-bullseye-sid
Python 3.7.10 (default, Feb 26 2021, 18:47:35)
[GCC 7.3.0]
NumPy 1.19.2
SciPy 1.6.2
POT 0.7.0
```
Additional context
The issue occurs in version 0.7.0. I also checked the code in the latest version (0.8.0), and the problem exists there as well.
The issue is triggered by the following lines when `log=True`:
```python
T = [gromov_wasserstein(Cs[s], C, ps[s], p, loss_fun,
                        numItermax=max_iter, stopThr=1e-5, verbose=verbose, log=log) for s in range(S)]
```

```python
if log:
    log['err'].append(err)
```
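The failure mode of these lines can be reproduced without POT. The sketch below uses a hypothetical stub, `fake_gw`, that mimics only the return convention described above (a plan, plus a log dict when `log=True`); it is not POT's solver. Forwarding `log=True` turns each element of `T` into a tuple, and `log['err']` then indexes a boolean. One possible fix, under these assumptions, is to call the inner solver without `log` and keep the errors in a separate dictionary (named `log_` here purely for illustration).

```python
import numpy as np


def fake_gw(C1, C2, p, q, log=False):
    """Hypothetical stand-in for a GW solver: returns the independent
    coupling p q^T, plus a log dict when log=True."""
    T = np.outer(p, q)
    if log:
        return T, {"gw_dist": 0.0}
    return T


p = np.ones(4) / 4
Cs = [np.eye(4)]
C = np.eye(4)
log = True  # the boolean argument passed by the user

# Buggy pattern: forwarding `log` makes every element a (T, dict) tuple ...
T_buggy = [fake_gw(Cs[s], C, p, p, log=log) for s in range(len(Cs))]
print(type(T_buggy[0]))  # <class 'tuple'>

# ... and `log` is still a bool, so it cannot be indexed like a dict.
try:
    log["err"].append(0.1)
except TypeError as exc:
    print("TypeError:", exc)  # 'bool' object is not subscriptable

# Possible fix: inner solver without log, errors kept in a separate dict.
T_fixed = [fake_gw(Cs[s], C, p, p) for s in range(len(Cs))]
log_ = {"err": []}
if log:
    log_["err"].append(0.1)
```

Keeping the user-facing flag and the returned dictionary under different names avoids shadowing the boolean argument, and the inner plans stay plain arrays that the `C` update can consume.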

