POT generating incorrect result for very simple OT problem
Apologies if I'm doing something stupid, though I don't think I am. The simple example
    import numpy as np
    import ot

    phi = np.array((0.5, 0.5))  # distribution 1
    psi = np.array((0.5, 0.5))  # distribution 2
    c = ((2, 1), (1, 1))
    c = np.array(c)
    pi = ot.emd(phi, psi, c)
produces the incorrect result
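(Clearly we should send all mass at 1 to 2 and all mass at 2 to 1: that plan costs 0.5 × 1 + 0.5 × 1 = 1, whereas leaving the mass in place costs 0.5 × 2 + 0.5 × 1 = 1.5.)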
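One thing that may be worth ruling out (an assumption on my part, not a confirmed diagnosis): `np.array(((2, 1), (1, 1)))` has an integer dtype, whereas the `ot.emd` docstring describes the loss matrix `M` as a C-ordered NumPy array of type float64. A minimal sketch of building the cost matrix explicitly as float64 before handing it to `ot.emd`:

```python
import numpy as np

# Cost matrix built exactly as in the example above
c = np.array(((2, 1), (1, 1)))
print(c.dtype)  # an integer dtype (int64 on 64-bit Linux)

# Explicit float64, C-contiguous copy, matching what the ot.emd docstring asks for
c_float = np.ascontiguousarray(c, dtype=np.float64)
print(c_float.dtype, c_float.flags['C_CONTIGUOUS'])

# Then, assuming POT is installed: pi = ot.emd(phi, psi, c_float)
```

If the float64 version returns the swap plan while the integer version does not, that would point at dtype handling rather than the solver itself.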
Direct application of linear programming produces the correct result:
    array([[ 0. ,  0.5],
           [ 0.5, -0. ]])
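To double-check which plan is actually optimal, here is a small NumPy-only sketch. The helper `check_plan` is hypothetical, written just for this check: it verifies the marginal constraints and computes the total cost, the sum of `pi * c`, of a candidate plan, using float copies of the same `phi`, `psi` and `c`.

```python
import numpy as np

phi = np.array([0.5, 0.5])
psi = np.array([0.5, 0.5])
c = np.array([[2.0, 1.0], [1.0, 1.0]])

def check_plan(pi, a, b, cost):
    """Hypothetical helper: return (is_feasible, total_cost) for a transport plan pi."""
    feasible = bool(
        np.all(pi >= -1e-12)                 # nonnegative (allowing -0. from the solver)
        and np.allclose(pi.sum(axis=1), a)   # row sums match the source marginal
        and np.allclose(pi.sum(axis=0), b)   # column sums match the target marginal
    )
    return feasible, float(np.sum(pi * cost))

swap = np.array([[0.0, 0.5], [0.5, 0.0]])  # plan returned by linprog
stay = np.array([[0.5, 0.0], [0.0, 0.5]])  # "leave mass in place" plan

print(check_plan(swap, phi, psi, c))  # (True, 1.0)
print(check_plan(stay, phi, psi, c))  # (True, 1.5)

# Every feasible plan here has the form [[t, 0.5 - t], [0.5 - t, t]] with 0 <= t <= 0.5,
# and its cost is 2t + (0.5 - t) + (0.5 - t) + t = 1 + t, so t = 0 is the unique optimum.
ts = np.linspace(0.0, 0.5, 11)
costs = [check_plan(np.array([[t, 0.5 - t], [0.5 - t, t]]), phi, psi, c)[1] for t in ts]
print(ts[int(np.argmin(costs))])  # 0.0
```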
Here's the direct linear programming code (continuing from the snippet above, so phi, psi and c are already defined):
    import numpy as np
    from scipy.optimize import linprog

    # Define parameters
    m = n = 2
    # Vectorize matrix C in column-major order (1-D, as linprog expects)
    c_vec = c.reshape(m * n, order='F')
    # Construct matrix A by Kronecker product
    A1 = np.kron(np.ones((1, n)), np.identity(m))  # row-sum constraints
    A2 = np.kron(np.identity(n), np.ones((1, m)))  # column-sum constraints
    A = np.vstack([A1, A2])
    # Construct vector b
    b = np.hstack([phi, psi])
    # Solve the primal problem
    res = linprog(c_vec, A_eq=A, b_eq=b, method='highs-ipm')
    # Print results
    pi = res.x.reshape((m, n), order='F')
    print(pi)
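The Kronecker-product construction relies on column-major (`order='F'`) vectorization of the plan. A quick NumPy sketch confirming that `A1` picks out row sums (matched to `phi`) and `A2` picks out column sums (matched to `psi`); `pi2` is an arbitrary illustrative plan, not from the original:

```python
import numpy as np

m = n = 2
phi = np.array([0.5, 0.5])
psi = np.array([0.5, 0.5])

# Same constraint matrix as in the linprog code
A1 = np.kron(np.ones((1, n)), np.identity(m))  # [[1,0,1,0],[0,1,0,1]]: row sums
A2 = np.kron(np.identity(n), np.ones((1, m)))  # [[1,1,0,0],[0,0,1,1]]: column sums
A = np.vstack([A1, A2])
b = np.hstack([phi, psi])

# Column-major vectorization: [pi[0,0], pi[1,0], pi[0,1], pi[1,1]]
pi = np.array([[0.0, 0.5], [0.5, 0.0]])
print(np.allclose(A @ pi.ravel(order='F'), b))  # True: the LP plan meets both marginals

# With an asymmetric plan the roles of A1 and A2 become visible
pi2 = np.array([[0.1, 0.4], [0.2, 0.3]])
expected = np.hstack([pi2.sum(axis=1), pi2.sum(axis=0)])  # [0.5, 0.5, 0.3, 0.7]
print(np.allclose(A @ pi2.ravel(order='F'), expected))  # True
```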
Environment:
Manjaro Linux, POT installed via pip in an Anaconda environment.
Output of the following code snippet:
    import platform; print(platform.platform())
    import sys; print("Python", sys.version)
    import numpy; print("NumPy", numpy.__version__)
    import scipy; print("SciPy", scipy.__version__)
    import ot; print("POT", ot.__version__)
Linux-5.13.19-2-MANJARO-x86_64-with-glibc2.17
Python 3.8.12 (default, Oct 12 2021, 13:49:34)
[GCC 7.5.0]
NumPy 1.20.3
SciPy 1.7.1
POT 0.8.0