ot.gromov.gromov_wasserstein2 loss does not backpropagate with torch CUDA tensors
As the title says, calling `.backward()` on the loss returned by `gromov_wasserstein2` fails when the input tensors are on a CUDA device. The following is a short snippet to reproduce the error:
```python
import numpy as np
import ot
import torch
from ot.gromov import gromov_wasserstein2


def gw_pytorch_exam(C1, C2, a1, a2, device, n_iter=1000, lr=1e-2):
    # C1 is the variable being optimized; C2 and the weights are fixed
    C1_torch = torch.tensor(C1, device=device, requires_grad=True)
    C2_torch = torch.tensor(C2, device=device)
    a1_torch = torch.tensor(a1, device=device)
    a2_torch = torch.tensor(a2, device=device)
    for i in range(n_iter):
        loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
        loss.backward()  # <-- fails here on CUDA
        with torch.no_grad():
            # plain gradient step on C1, projected back onto [0, 1]
            grad = C1_torch.grad
            C1_torch -= grad * lr
            C1_torch.grad.zero_()
            C1_torch.data = torch.clamp(C1_torch, 0, 1)
    return C1_torch


if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")  # maybe should disable this to force GPU usage

n = 10
C1 = np.eye(n)
C2 = np.random.randn(n, n)
a = ot.unif(n)
C1 = gw_pytorch_exam(C1, C2, a, a, device)
```
Running this code raises the following RuntimeError:
```
     36 a = ot.unif(n)
---> 37 C1 = gw_pytorch_exam(C1, C2, a, a, device)

<ipython-input-3-afc0d26d5054> in gw_pytorch_exam(C1, C2, a1, a2, device, n_iter, lr)
     16     for i in range(n_iter):
     17         loss = gromov_wasserstein2(C1_torch, C2_torch, a1_torch, a2_torch)
---> 18         loss.backward()
     19         with torch.no_grad():
     20             grad = C1_torch.grad

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    219                 retain_graph=retain_graph,
    220                 create_graph=create_graph)
--> 221         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    222
    223     def register_hook(self, hook):

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
    128         retain_graph = create_graph
    129
--> 130     Variable._execution_engine.run_backward(
    131         tensors, grad_tensors_, retain_graph, create_graph,
    132         allow_unreachable=True)  # allow_unreachable flag

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/torch/autograd/function.py in apply(self, *args)
     87     def apply(self, *args):
     88         # _forward_cls is defined by derived class
---> 89         return self._forward_cls.backward(self, *args)  # type: ignore
     90
     91

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/ot/backend.py in backward(ctx, grad_output)
   1381             def backward(ctx, grad_output):
   1382                 # the gradients are grad
-> 1383                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
   1384
   1385         self.ValFunction = ValFunction

~/.local/miniconda3/envs/test_pot/lib/python3.8/site-packages/ot/backend.py in <genexpr>(.0)
   1381             def backward(ctx, grad_output):
   1382                 # the gradients are grad
-> 1383                 return (None, None) + tuple(g * grad_output for g in ctx.grads)
   1384
   1385         self.ValFunction = ValFunction

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
```
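For context, the failing line in `ot/backend.py` multiplies the gradients stored at forward time (`ctx.grads`) by the incoming `grad_output`, so both must live on the same device. Below is a minimal sketch of the same mismatch, independent of POT; which of the two tensors ends up on the CPU inside POT is my guess from the error message, and the shapes are arbitrary:

```python
import torch

assert torch.cuda.is_available()

# grad_output comes from the CUDA loss tensor ...
grad_output = torch.ones((), device="cuda:0")
# ... while (presumably) one of the stored gradients is still a CPU tensor
stored_grad = torch.ones(10, 10)

# Raises: RuntimeError: Expected all tensors to be on the same device,
# but found at least two devices, cuda:0 and cpu!
stored_grad * grad_output
```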
PyTorch: 1.7.0
POT: 0.8.1
CUDA: 10.1 on NVIDIA Tesla P100
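If the cause is indeed that the stored gradients stay on the CPU, a possible fix, sketched only from the `backward` visible in the traceback and not tested against the rest of `ot/backend.py`, would be to move each stored gradient onto `grad_output`'s device before the multiplication:

```python
# Hypothetical patch for ValFunction.backward in ot/backend.py (POT 0.8.1);
# the original body is the one shown in the traceback above.
@staticmethod
def backward(ctx, grad_output):
    # move every stored gradient onto grad_output's device so CPU-side
    # gradients no longer clash with a CUDA grad_output
    return (None, None) + tuple(
        g.to(grad_output.device) * grad_output for g in ctx.grads
    )
```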