fix array bounds issue by AdrienCorenflos · Pull Request #170 · PythonOT/POT

Expand Up @@ -157,12 +157,12 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, cost associated to the optimal transportation """ cdef double cost = 0. cdef int n = u_weights.shape[0] cdef int m = v_weights.shape[0] cdef Py_ssize_t n = u_weights.shape[0] cdef Py_ssize_t m = v_weights.shape[0]
cdef int i = 0 cdef Py_ssize_t i = 0 cdef double w_i = u_weights[0] cdef int j = 0 cdef Py_ssize_t j = 0 cdef double w_j = v_weights[0]
cdef double m_ij = 0. Expand All @@ -171,8 +171,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, dtype=np.float64) cdef np.ndarray[long, ndim=2, mode="c"] indices = np.zeros((n + m - 1, 2), dtype=np.int) cdef int cur_idx = 0 while i < n and j < m: cdef Py_ssize_t cur_idx = 0 while True: if metric == 'sqeuclidean': m_ij = (u[i] - v[j]) * (u[i] - v[j]) elif metric == 'cityblock' or metric == 'euclidean': Expand All @@ -188,6 +188,8 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j i += 1 if i == n: break w_j -= w_i w_i = u_weights[i] else: Expand All @@ -196,7 +198,10 @@ def emd_1d_sorted(np.ndarray[double, ndim=1, mode="c"] u_weights, indices[cur_idx, 0] = i indices[cur_idx, 1] = j j += 1 if j == m: break w_i -= w_j w_j = v_weights[j] cur_idx += 1 cur_idx += 1 return G[:cur_idx], indices[:cur_idx], cost