[MRG] Gromov_Wasserstein2 not performing backward properly on GPU by ncassereau · Pull Request #352 · PythonOT/POT
Types of changes
The backpropagation is not working in gromov_wasserstein2 if the given tensors are located on a GPU. This is due to the fact that part of the computation is performed with numpy and the device was forgotten when casting back to torch.
Motivation and context / Related issue
Resolves #351
How has this been tested (if it applies)
PR checklist
- I have read the CONTRIBUTING document.
- The documentation is up-to-date with the changes I made (check build artifacts).
- All tests passed, and additional code has been covered with new tests.
- I have added the PR and Issue fix to the RELEASES.md file.