[DA] Sinkhorn L1L2 transport to work on JAX by kachayev · Pull Request #587 · PythonOT/POT

Types of changes

All operations required for SinkhornL1l2Transport to work on JAX are properly vectorized, including those implemented in the BaseTransport. In short

  • per-labels for-loops are vectorized using mask tensors, implementation is moved to a labels_to_masks helper with corresponding tests
  • a new backend method nan_to_num
  • JAX backend was removed from the exclusion list in BaseEstimator
  • a few enhancements for label normalization and related operations (including avoid unnecessary computations when normalizing labels)

Motivation and context / Related issue

The next step towards making domain adaptation methods to work on JAX backend, continues the work started with #507.

How has this been tested (if it applies)

  • test_sinkhorn_l1l2_transport_class test doesn't skip JAX backend
  • the test also updated to check semi-supervised use case
  • additional test cases for label_normalization and labels_to_masks helpers

PR checklist

  • I have read the CONTRIBUTING document.
  • The documentation is up-to-date with the changes I made (check build artifacts).
  • All tests passed, and additional code has been covered with new tests.
  • I have added the PR and Issue fix to the RELEASES.md file.

Additional Context

Working on the implementation I spotted the following (potential issue). It seems that the tests for semi-supervised DA, in fact, do not cover semi-supervised use case. They test the different between unsupervised (no labels for target) and supervised (target labels are available). For the test_sinkhorn_l1l2_transport_class specifically I did update the implementation to use partially masked labels for targets (see otda_semi). Does it covers the expected functionality correctly?