[MRG] Fix barycenter mass by hichamjanati · Pull Request #375 · PythonOT/POT
Expand Up
@@ -490,6 +490,41 @@ def test_barycenter(nx, method, verbose, warn):
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
@pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], [True, False], [True, False])) def test_barycenter_assymetric_cost(nx, method, verbose, warn): n_bins = 20 # nb bins
# Gaussian distributions A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
# creating matrix A containing all distributions A = A[:, None]
# assymetric loss matrix + normalization rng = np.random.RandomState(42) M = rng.randn(n_bins, n_bins) ** 2 M /= M.max()
A_nx, M_nx = nx.from_numpy(A, M) reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.barycenter(A_nx, M_nx, reg, method=method) else: # wasserstein bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass)) np.testing.assert_allclose(bary_wass, bary_wass_np)
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
@pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False])) Expand Down
@pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_stabilized", "sinkhorn_log"], [True, False], [True, False])) def test_barycenter_assymetric_cost(nx, method, verbose, warn): n_bins = 20 # nb bins
# Gaussian distributions A = ot.datasets.make_1D_gauss(n_bins, m=30, s=10) # m= mean, s= std
# creating matrix A containing all distributions A = A[:, None]
# assymetric loss matrix + normalization rng = np.random.RandomState(42) M = rng.randn(n_bins, n_bins) ** 2 M /= M.max()
A_nx, M_nx = nx.from_numpy(A, M) reg = 1e-2
if nx.__name__ in ("jax", "tf") and method == "sinkhorn_log": with pytest.raises(NotImplementedError): ot.bregman.barycenter(A_nx, M_nx, reg, method=method) else: # wasserstein bary_wass_np = ot.bregman.barycenter(A, M, reg, method=method, verbose=verbose, warn=warn) bary_wass, _ = ot.bregman.barycenter(A_nx, M_nx, reg, method=method, log=True) bary_wass = nx.to_numpy(bary_wass)
np.testing.assert_allclose(1, np.sum(bary_wass)) np.testing.assert_allclose(bary_wass, bary_wass_np)
ot.bregman.barycenter(A_nx, M_nx, reg, log=True)
@pytest.mark.parametrize("method, verbose, warn", product(["sinkhorn", "sinkhorn_log"], [True, False], [True, False])) Expand Down