[MRG] Add a geomloss wrapper for sinkhorn solver by rflamary · Pull Request #571 · PythonOT/POT

This is a first attempt to build on the awesome geomloss and Keops of @jeanfeydy. this function empirical_sinkhorn2_geomloss is a simple wrapper that will be called in ot.sovle_sample with method='geomloss' .

import numpy as np
import scipy as sp
import ot

n = 1000
rng = np.random.RandomState(0)

x = rng.randn(n, 2)
x2 = rng.randn(n//2, 2)+5

xb = torch.tensor(x, dtype=torch.float32)
xb2 = torch.tensor(x2, dtype=torch.float32)

a = torch.ones(n, dtype=torch.float32, requires_grad=True) / n
b = torch.ones(n//2, dtype=torch.float32, requires_grad=True) / (n//2)

#%%  empirical_sinkhorn2_geomloss wrapper for geomloss
reg=1

ot.tic()
value0, log0 = ot.bregman.empirical_sinkhorn2(xb, xb2, reg=reg, lazy=False, log=True)
ot.toc('Classical sinhorn solver : {}s')

ot.tic()
value, log = empirical_sinkhorn2_geomloss(xb, xb2, reg=reg, log= True)
ot.toc('Geomloss : {}s')
T = log['lazy_plan'] # recover lazy plan

#%% ot.solve_sample wrapper for geomloss
reg=1

# automatic solver
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_auto', lazy=True)
ot.toc('Geomloss (automatic) : {}s')

# tensorized solver is fast but O(n^2) in memory
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_tensorized', lazy=True)
ot.toc('Geomloss tensorized : {}s')

# online solver compute the distanec marix when necessary and is O(n) in memory
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_online', lazy=True)
ot.toc('Geomloss online : {}s')

# multiscale is usually the fastest
ot.tic()
sol = ot.solve_sample(x,x2, reg=reg, method='geomloss_multiscale', lazy=True)
ot.toc('Geomloss multiscale : {}s')
[KeOps] Warning : Cuda libraries were not detected on the system ; using cpu only mode
Classical sinhorn solver : 0.6667807102203369s
Geomloss : 0.037354469299316406s
Geomloss (automatic) : 0.05784487724304199s
Geomloss tensorized : 0.052454471588134766s
Geomloss online : 0.2332136631011963s
Geomloss multiscale : 0.04137110710144043s

For the moment this solver is compatible only with pytorch and numpy (with pytorch conversion)