Constant factors are ignored when backpropagating through emd2 in torch
Describe the bug
Backpropagating through emd2 (earth mover's distance) in POT ignores any constant factor applied to its output afterwards.
The gradient of $c \times \text{emd2}(\dots)$ with respect to the weights is the same as the gradient of $\text{emd2}(\dots)$, for any constant $c$.
To Reproduce
Here is a minimal code sample that reproduces the unexpected backpropagation behaviour.
Code sample
import torch
import ot
# Fix the seed
torch.manual_seed(0)
# Number of samples / dimension of data
N, d = 5, 2
# Generate random dummy data
X = torch.randn(N, d).double()
# Distance matrix
M = ot.dist(X)
# Create random coefficient vectors (normalised) that require gradients
a = torch.abs(torch.randn(N, dtype=torch.float64))
b = torch.abs(torch.randn(N, dtype=torch.float64))
a = a / torch.sum(a)
b = b / torch.sum(b)
a.requires_grad = True
b.requires_grad = True
# Compute the earth mover's distance
emd = ot.emd2(a, b, M)
# Backprop
emd.backward()
# Print gradients
print(a.grad, b.grad)
tensor([ 4.0011, -4.1943, 0.0042, 1.6815, -1.1197], dtype=torch.float64) tensor([-4.0011, 4.1943, -0.0042, -1.6815, 1.1197], dtype=torch.float64)
# Now do the same operations, but maximise the loss instead of minimising it
a.grad, b.grad = None, None
# Add a minus sign here
emd = -ot.emd2(a, b, M)
emd.backward()
# Print gradients: only the sign should change, but it does not
print(a.grad, b.grad)
tensor([ 4.0011, -4.1943, 0.0042, 1.6815, -1.1197], dtype=torch.float64) tensor([-4.0011, 4.1943, -0.0042, -1.6815, 1.1197], dtype=torch.float64)
# The same applies to other constants
a.grad, b.grad = None, None
emd = 100 * ot.emd2(a, b, M)
# Backprop
emd.backward()
print(a.grad, b.grad)
tensor([ 4.0011, -4.1943, 0.0042, 1.6815, -1.1197], dtype=torch.float64) tensor([-4.0011, 4.1943, -0.0042, -1.6815, 1.1197], dtype=torch.float64)
Expected behavior
If the output of emd2 is multiplied by a constant $c$, the gradients obtained by backpropagation should be scaled by the same constant: multiplying by $-1$ should flip their sign, and multiplying by $100$ should make them 100 times larger.
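For reference, this expectation is just the chain rule. Writing $W(\mathbf a, \mathbf b) = \text{emd2}(\mathbf a, \mathbf b, M)$ for a fixed cost matrix $M$ and a loss $L = c\,W$ with a constant $c$:

$$
\frac{\partial L}{\partial \mathbf a} = \frac{\partial L}{\partial W}\,\frac{\partial W}{\partial \mathbf a} = c\,\frac{\partial W}{\partial \mathbf a},
\qquad
\frac{\partial L}{\partial \mathbf b} = \frac{\partial L}{\partial W}\,\frac{\partial W}{\partial \mathbf b} = c\,\frac{\partial W}{\partial \mathbf b}.
$$

The factor $\partial L / \partial W = c$ is the upstream gradient that autograd hands to the backward pass of emd2 (grad_output in torch), so with $c = -1$ the printed tensors should be the negatives of the first ones, and with $c = 100$ they should be 100 times larger.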
Environment
- OS: Linux
- Python version: 3.8.8
- How was POT installed: pip
Output of the following code snippet:
import platform; print(platform.platform())
import sys; print("Python", sys.version)
import numpy; print("NumPy", numpy.__version__)
import scipy; print("SciPy", scipy.__version__)
import ot; print("POT", ot.__version__)
import torch; print("Torch", torch.__version__)
Linux-5.11.0-40-generic-x86_64-with-glibc2.10
Python 3.8.8 (default, Feb 24 2021, 21:46:12)
[GCC 7.3.0]
NumPy 1.20.2
SciPy 1.6.1
POT 0.8.0
Torch 1.8.1+cu102
Additional context
Although it is not shown in the example, running on a GPU device with torch gives the same result.
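A possible cause, stated as an assumption since we have not traced POT's source: if the torch backend attaches precomputed gradients to the emd2 output through a custom torch.autograd.Function whose backward returns those gradients without multiplying them by grad_output, then any constant applied afterwards is dropped. The toy sketch below reproduces that pattern with a hypothetical linear function f(x) = sum(w * x) standing in for emd2. The names IgnoresUpstreamGrad, UsesChainRule and grad_of_scaled are illustrative only and are not POT code.

```python
import torch


class IgnoresUpstreamGrad(torch.autograd.Function):
    """Hypothetical op f(x) = sum(w * x) with a buggy backward."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(w)
        return (w * x).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        # Bug: returns df/dx = w directly and discards grad_output,
        # so any factor applied to the output after the op is lost.
        return w.clone(), None


class UsesChainRule(torch.autograd.Function):
    """Same op, but backward applies the chain rule."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(w)
        return (w * x).sum()

    @staticmethod
    def backward(ctx, grad_output):
        (w,) = ctx.saved_tensors
        # Correct: scale the local gradient by d(loss)/d(output).
        return grad_output * w, None


def grad_of_scaled(fn, c):
    """Return d(c * fn(x, w))/dx for fixed toy inputs."""
    x = torch.ones(3, dtype=torch.float64, requires_grad=True)
    w = torch.tensor([1.0, -2.0, 3.0], dtype=torch.float64)
    loss = c * fn.apply(x, w)
    loss.backward()
    return x.grad


for c in (1.0, -1.0, 100.0):
    print(c, grad_of_scaled(IgnoresUpstreamGrad, c), grad_of_scaled(UsesChainRule, c))
```

With IgnoresUpstreamGrad, every value of c gives [1, -2, 3], which mirrors the identical tensors in the report; with UsesChainRule, c = -1 gives [-1, 2, -3] and c = 100 gives [100, -200, 300]. If this assumption is right, the fix would be to multiply each stored gradient by grad_output in the backward of emd2.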