[DA] Sinkhorn L1L2 transport to work on JAX by kachayev · Pull Request #587 · PythonOT/POT
Types of changes
All operations required for SinkhornL1l2Transport to work on JAX are properly vectorized, including those implemented in the BaseTransport. In short
- per-labels for-loops are vectorized using mask tensors, implementation is moved to a
labels_to_maskshelper with corresponding tests - a new backend method
nan_to_num - JAX backend was removed from the exclusion list in BaseEstimator
- a few enhancements for label normalization and related operations (including avoid unnecessary computations when normalizing labels)
Motivation and context / Related issue
The next step towards making domain adaptation methods to work on JAX backend, continues the work started with #507.
How has this been tested (if it applies)
test_sinkhorn_l1l2_transport_classtest doesn't skip JAX backend- the test also updated to check semi-supervised use case
- additional test cases for
label_normalizationandlabels_to_maskshelpers
PR checklist
- I have read the CONTRIBUTING document.
- The documentation is up-to-date with the changes I made (check build artifacts).
- All tests passed, and additional code has been covered with new tests.
- I have added the PR and Issue fix to the RELEASES.md file.
Additional Context
Working on the implementation I spotted the following (potential issue). It seems that the tests for semi-supervised DA, in fact, do not cover semi-supervised use case. They test the different between unsupervised (no labels for target) and supervised (target labels are available). For the test_sinkhorn_l1l2_transport_class specifically I did update the implementation to use partially masked labels for targets (see otda_semi). Does it covers the expected functionality correctly?