ot.solve_sample_batch: device mismatch error with CUDA tensors
Describe the bug
Calling ot.solve_sample_batch on CUDA tensors fails with the following error: RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
To Reproduce
Code sample
import ot
import torch

x = torch.zeros((24, 100, 1))
y = torch.zeros((24, 100, 1))
ot.solve_sample_batch(x.cuda(), y.cuda(), reg=1)
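I haven't traced where the CPU tensor is created inside POT. A common source of this particular RuntimeError is a helper tensor (for example uniform sample weights) allocated with something like torch.ones(n) and no device= argument, which puts it on the CPU while the inputs live on cuda:0. For comparison, here is a minimal, hypothetical pure-PyTorch batched Sinkhorn sketch (the function name batched_sinkhorn_sketch is made up and this is not POT's implementation) that allocates every intermediate tensor on the inputs' device and dtype, with the same input shapes as the reproduction. It falls back to the CPU when no GPU is available.

```python
import torch


def batched_sinkhorn_sketch(x, y, reg=1.0, n_iter=100):
    """Hypothetical batched entropic OT between point clouds (not POT's code).

    x: (B, n, d), y: (B, m, d). Returns plans P of shape (B, n, m) and
    per-batch transport costs of shape (B,).
    """
    B, n, _ = x.shape
    m = y.shape[1]
    # Squared Euclidean cost for each batch element: (B, n, m)
    M = torch.cdist(x, y) ** 2
    # Uniform marginals created on the SAME device and dtype as the inputs;
    # leaving out device=... here would silently place them on the CPU.
    a = torch.full((B, n), 1.0 / n, device=x.device, dtype=x.dtype)
    b = torch.full((B, m), 1.0 / m, device=x.device, dtype=x.dtype)
    K = torch.exp(-M / reg)
    # ones_like inherits device and dtype from a and b
    u = torch.ones_like(a)
    v = torch.ones_like(b)
    for _ in range(n_iter):
        # Sinkhorn scaling updates, batched with bmm
        u = a / torch.bmm(K, v.unsqueeze(-1)).squeeze(-1)
        v = b / torch.bmm(K.transpose(1, 2), u.unsqueeze(-1)).squeeze(-1)
    P = u.unsqueeze(-1) * K * v.unsqueeze(-2)
    cost = (P * M).sum(dim=(1, 2))
    return P, cost


# Same shapes as the reproduction; use the GPU when available
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.zeros((24, 100, 1), device=device)
y = torch.zeros((24, 100, 1), device=device)
P, cost = batched_sinkhorn_sketch(x, y, reg=1.0)
print(P.shape, P.device, cost.shape)
```

If a sketch like this runs on cuda:0 while ot.solve_sample_batch does not, that points toward a tensor inside the library being created without the input's device.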
Environment
The error occurs on both Google Colab and Linux.
