CUDA out of memory when using ot.sinkhorn2 as a loss function

Hi, I'm trying to use the EMD/Sinkhorn distance as a loss function for 2D matrices.
However, ot.sinkhorn2 raises a CUDA out of memory error while it is being computed:

[image: screenshot of the error]

ot.emd2 also raises this error when I use a larger dataset.

According to nvidia-smi, when training on the same dataset, ot.emd2 uses about 7 GB of the 12 GB of GPU memory (which is fine), while ot.sinkhorn2 uses 11 GB of 12 GB and triggers the error above.

Thank you!