fix the sign of gradient for kl gromov by KrzakalaPaul · Pull Request #610 · PythonOT/POT
Corrects a sign error in `nx.set_gradients` for `gromov_wasserstein2` (and the fused variant) when `loss_fun='kl_loss'`.

With the KL ground loss $L(a,b) = a\log(a/b) - a + b$, optimal coupling $T$, and second marginal $q$, differentiating the GW objective $\sum_{ijkl} L(C_{1,ij}, C_{2,kl})\, T_{ik} T_{jl}$ gives the correct gradient with respect to `C2`:

$$\nabla_{C_2} = q q^\top - (T^\top C_1 T) \oslash C_2$$

where $\oslash$ denotes elementwise division. The current implementation flips the sign, which turns gradient descent into ascent, as the following script shows:
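As a quick sanity check on the sign, here is a stdlib-only finite-difference check of the pointwise derivative of the KL ground loss with respect to its second argument (the loss form $L(a,b) = a\log(a/b) - a + b$ is the standard GW KL loss; the sample values are illustrative):

```python
import math

def kl_loss(a, b):
    # Pointwise KL-type ground loss used in Gromov-Wasserstein:
    # L(a, b) = a*log(a/b) - a + b  (assumes a, b > 0)
    return a * math.log(a / b) - a + b

def dL_db(a, b):
    # Analytic derivative with respect to the second argument: dL/db = 1 - a/b
    return 1.0 - a / b

a, b, h = 0.7, 0.3, 1e-6
numeric = (kl_loss(a, b + h) - kl_loss(a, b - h)) / (2 * h)
print(abs(dL_db(a, b) - numeric) < 1e-6)  # central difference matches the formula
```

Note the derivative is negative here (a > b), so flipping its sign reverses the descent direction.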
```python
from ot import gromov_wasserstein2, unif
import torch

C1 = torch.rand((2, 2), requires_grad=False)
C2 = 1 - C1
C2.requires_grad = True

eta = 1e-1
for step in range(100):
    loss = gromov_wasserstein2(
        C1=C1, C2=C2,
        p=unif(2, type_as=C2), q=unif(2, type_as=C2),
        loss_fun='kl_loss',
    )
    grad = torch.autograd.grad(loss, C2)[0]
    with torch.no_grad():
        # gradient step; keep entries strictly positive for the KL loss
        C2 = torch.clip(C2 - eta * grad, 1e-6, 1)
    C2.requires_grad = True
    print(loss)
```
You will see that the gradient descent diverges; once the sign error is fixed, it converges.
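For a fixed coupling $T$, the gradient of the inner GW objective with respect to `C2` works out to $q q^\top - (T^\top C_1 T) \oslash C_2$ (elementwise division). A stdlib-only sketch verifying this closed form against a central finite difference (the matrices below are illustrative; any positive `C1`, `C2` and valid coupling `T` work):

```python
import math

n = 2
C1 = [[0.2, 0.8], [0.8, 0.2]]
C2 = [[0.7, 0.3], [0.3, 0.7]]
T = [[0.3, 0.2], [0.2, 0.3]]  # a coupling with marginals p = q = (0.5, 0.5)
q = [sum(T[i][k] for i in range(n)) for k in range(n)]  # second marginal of T

def kl(a, b):
    # KL ground loss: L(a, b) = a*log(a/b) - a + b
    return a * math.log(a / b) - a + b

def gw_obj(C2):
    # Inner GW objective for the fixed coupling T:
    # sum_{ijkl} L(C1_ij, C2_kl) T_ik T_jl
    return sum(kl(C1[i][j], C2[k][l]) * T[i][k] * T[j][l]
               for i in range(n) for j in range(n)
               for k in range(n) for l in range(n))

def grad_C2(C2):
    # Closed-form gradient: q q^T - (T^T C1 T) / C2 (elementwise division)
    ttct = [[sum(T[i][k] * C1[i][j] * T[j][l]
                 for i in range(n) for j in range(n))
             for l in range(n)] for k in range(n)]
    return [[q[k] * q[l] - ttct[k][l] / C2[k][l]
             for l in range(n)] for k in range(n)]

# central finite-difference check of the (0, 1) entry
h = 1e-6
C2p = [row[:] for row in C2]; C2p[0][1] += h
C2m = [row[:] for row in C2]; C2m[0][1] -= h
num = (gw_obj(C2p) - gw_obj(C2m)) / (2 * h)
print(abs(num - grad_C2(C2)[0][1]) < 1e-6)
```

With the sign flipped, the update `C2 - eta * grad` moves uphill on the objective, which is exactly the divergence the script above exhibits.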