fix the sign of gradient for kl gromov by KrzakalaPaul · Pull Request #610 · PythonOT/POT

Correct a sign error of the nx.set_gradient for gromov (and fused gromov) when loss_fun = 'kl_loss'.
The correct formula is:

from ot import gromov_wasserstein2,unif
import torch 

C1 = torch.rand((2,2), requires_grad = False)
C2 = 1 - C1
C2.requires_grad = True

eta = 1e-1

for step in range(100):
    
    loss = gromov_wasserstein2(C1=C1,C2=C2,p=unif(2,type_as=C2),q=unif(2,type_as=C2),loss_fun='square_loss')
    grad = torch.autograd.grad(loss, C2)[0]
    C2 = C2 - eta*grad
    C2 = torch.clip(C2,0,1)
    
    print(loss)

You will see that the gradient descent diverges. It converges when we fix the sign error.