[MRG] Add a geomloss wrapper for sinkhorn solver by rflamary · Pull Request #571 · PythonOT/POT
This is a first attempt to build on the awesome geomloss and Keops of @jeanfeydy. this function empirical_sinkhorn2_geomloss is a simple wrapper that will be called in ot.sovle_sample with method='geomloss' .
import numpy as np import scipy as sp import ot n = 1000 rng = np.random.RandomState(0) x = rng.randn(n, 2) x2 = rng.randn(n//2, 2)+5 xb = torch.tensor(x, dtype=torch.float32) xb2 = torch.tensor(x2, dtype=torch.float32) a = torch.ones(n, dtype=torch.float32, requires_grad=True) / n b = torch.ones(n//2, dtype=torch.float32, requires_grad=True) / (n//2) #%% empirical_sinkhorn2_geomloss wrapper for geomloss reg=1 ot.tic() value0, log0 = ot.bregman.empirical_sinkhorn2(xb, xb2, reg=reg, lazy=False, log=True) ot.toc('Classical sinhorn solver : {}s') ot.tic() value, log = empirical_sinkhorn2_geomloss(xb, xb2, reg=reg, log= True) ot.toc('Geomloss : {}s') T = log['lazy_plan'] # recover lazy plan #%% ot.solve_sample wrapper for geomloss reg=1 # automatic solver ot.tic() sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_auto', lazy=True) ot.toc('Geomloss (automatic) : {}s') # tensorized solver is fast but O(n^2) in memory ot.tic() sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_tensorized', lazy=True) ot.toc('Geomloss tensorized : {}s') # online solver compute the distanec marix when necessary and is O(n) in memory ot.tic() sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_online', lazy=True) ot.toc('Geomloss online : {}s') # multiscale is usually the fastest ot.tic() sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_multiscale', lazy=True) ot.toc('Geomloss multiscale : {}s')
[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode
Classical sinhorn solver : 0.6667807102203369s
Geomloss : 0.037354469299316406s
Geomloss (automatic) : 0.05784487724304199s
Geomloss tensorized : 0.052454471588134766s
Geomloss online : 0.2332136631011963s
Geomloss multiscale : 0.04137110710144043s
For the moment this solver is compatible only with pytorch and numpy (with pytorch conversion)