Possible error: transform_labels() function in ot.da.BaseTransport
Describe the bug
The output of transform_labels() is not a probability distribution over labels, i.e. its rows do not sum to 1.
To Reproduce
Run the following code snippet:
Code sample
import ot
import numpy as np

ot_sinkhorn = ot.da.SinkhornTransport()
# Two source samples, one per class
Xs = np.array([[1, 0], [0, 0]])
ys = np.array([0, 1])
# Three target samples
Xt = np.array([[1, 0], [2, 0], [3, 0]])
ot_sinkhorn.fit(Xs=Xs, Xt=Xt)
# Propagate the source labels to the target samples
print(ot_sinkhorn.transform_labels(ys))
Output:
[[0.07946862 0.58719805]
 [0.33333333 0.33333333]
 [0.58719805 0.07946861]]
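A quick check on the values printed above, assuming rows index the three target samples and columns index the two classes. It only sums the reported numbers along each axis. The rows come out at roughly 0.667 while the columns sum to 1, which may suggest the normalization is applied along the other axis.

```python
import numpy as np

# Output reported above, copied verbatim.
# Assumption: rows are target samples, columns are classes.
out = np.array([
    [0.07946862, 0.58719805],
    [0.33333333, 0.33333333],
    [0.58719805, 0.07946861],
])

row_sums = out.sum(axis=1)  # one value per target sample
col_sums = out.sum(axis=0)  # one value per class

print("row sums:", row_sums)
print("column sums:", col_sums)
print("rows sum to 1:", np.allclose(row_sums, 1.0))
print("columns sum to 1:", np.allclose(col_sums, 1.0))
```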
Expected behavior
Each row of the output should sum to 1, since the labels of each target sample form a probability distribution over classes.
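As a stopgap until the behavior is clarified, the rows can be rescaled on the caller's side. This is a minimal sketch, not the library's own fix: `normalize_rows` is a hypothetical helper that is not part of POT, and it assumes the returned array has one row per target sample.

```python
import numpy as np

def normalize_rows(probs):
    """Rescale each row to sum to 1 (hypothetical helper, not part of POT)."""
    probs = np.asarray(probs, dtype=float)
    totals = probs.sum(axis=1, keepdims=True)
    # Leave all-zero rows unchanged to avoid dividing by zero.
    safe_totals = np.where(totals == 0, 1.0, totals)
    return probs / safe_totals

# Output reported in this issue, copied verbatim.
out = np.array([
    [0.07946862, 0.58719805],
    [0.33333333, 0.33333333],
    [0.58719805, 0.07946861],
])

normalized = normalize_rows(out)
print(normalized)
print("row sums:", normalized.sum(axis=1))
```

Whether this rescaling gives the intended label distribution depends on how the library is meant to normalize the coupling, so it should be treated as a workaround only.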