ot.solve_sample_batch raises a device mismatch error with CUDA tensors

Describe the bug

I called ot.solve_sample_batch with both inputs on the GPU and got the following error:

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

To Reproduce

Code sample

import ot
import torch

x = torch.zeros((24, 100, 1))
y = torch.zeros((24, 100, 1))

ot.solve_sample_batch(x.cuda(), y.cuda(), reg=1)
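
To rule out a mistake on my side, a check along the following lines can confirm that both inputs share one device before the call. This is only a minimal sketch: `device_report` and `assert_same_device` are hypothetical helper names, not part of POT or PyTorch, and they only assume that each argument exposes a `.device` attribute (as torch tensors do).

```python
def device_report(**named_tensors):
    """Map each argument name to the string form of its .device attribute."""
    return {name: str(t.device) for name, t in named_tensors.items()}


def assert_same_device(**named_tensors):
    """Raise ValueError if the named tensors do not all share one device.

    Hypothetical helper for debugging; returns the device report on success.
    """
    report = device_report(**named_tensors)
    if len(set(report.values())) > 1:
        raise ValueError(f"tensors are on different devices: {report}")
    return report
```

Calling `assert_same_device(x=x.cuda(), y=y.cuda())` right before `ot.solve_sample_batch` should pass for the sample above. If it passes and the RuntimeError is still raised, the CPU tensor is presumably created somewhere inside the call rather than passed in, although I have not verified that.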

Environment

The error occurs on both Colab and Linux.