[WIP] Fix gradient scaling bug in emd by rflamary · Pull Request #310 · PythonOT/POT

@@ -126,6 +126,22 @@ def test_emd2_gradients():
    assert b1.shape == b1.grad.shape
    assert M1.shape == M1.grad.shape
    # Testing for bug #309, checking for scaling of gradient
    a2 = torch.tensor(a, requires_grad=True)
    b2 = torch.tensor(a, requires_grad=True)
    M2 = torch.tensor(M, requires_grad=True)

    val = 10.0 * ot.emd2(a2, b2, M2)

    val.backward()

    assert np.allclose(10.0 * a1.grad.cpu().detach().numpy(),
                       a2.grad.cpu().detach().numpy())
    assert np.allclose(10.0 * b1.grad.cpu().detach().numpy(),
                       b2.grad.cpu().detach().numpy())
    assert np.allclose(10.0 * M1.grad.cpu().detach().numpy(),
                       M2.grad.cpu().detach().numpy())

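For context, the new assertions encode the chain rule: scaling a scalar loss by a constant must scale every gradient by the same constant, so a backward pass that ignores the incoming grad_output (the failure mode reported in #309) breaks exactly on `10.0 * ot.emd2(...)`. A minimal, self-contained sketch of the property, using a toy quadratic `torch.autograd.Function` as a hypothetical stand-in for `ot.emd2` (this is not the actual POT implementation):

import torch

class QuadLoss(torch.autograd.Function):
    # Hypothetical stand-in for a loss with a hand-written backward;
    # not POT code.
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return (x ** 2).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        # Correct backward: multiply the cached gradient by grad_output.
        # A buggy variant returning just `2.0 * x` would make
        # `10.0 * QuadLoss.apply(x)` yield gradients 10x too small,
        # which is the scaling error the test above guards against.
        return grad_output * 2.0 * x

x1 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
QuadLoss.apply(x1).backward()

x2 = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
(10.0 * QuadLoss.apply(x2)).backward()

# Chain rule: a 10x-scaled loss must produce 10x-scaled gradients.
assert torch.allclose(10.0 * x1.grad, x2.grad)

Note that multiplying the loss by anything other than 1.0 is what exposes the bug: with an unscaled loss, grad_output is 1.0, so a backward that drops it still happens to return the right numbers.
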
def test_emd_emd2():
    # test emd and emd2 for simple identity