[MRG] New ot.gpu with cupy by rflamary · Pull Request #67 · PythonOT/POT

The PR is a cupy implementation of the functions currently implemented in ot.gpu. I also removed all the classes that were deprecated anyways. It still needs proper updated test but i like this solution since it stays mostly compatible with the old ot.gpu.

I have received a large number of queries about ot.gpu but cudamat is not maintained and the problem will only grow so we need to do something before release 0.5.

This solution is far less elegant than PR #32 of @toto6 with all the decorators but having a cupy specific implementation leaves more room for code optimization than a generic implementation IMHO. Which means that we can make it better in the future without compromizing the numpy implmentation.

I give an example of use for the ot.gpu functions below with different format for input/output, i.e. if there are numpy.array of cupy.array . The output is obtained on my Titan X GPU after two run of the script in ipython.

import numpy as np
import pylab as pl
import ot
import ot.gpu

#%%
n=2000

tp=np.float32

xs=np.random.randn(n,2).astype(tp)
xt=np.random.randn(n,2).astype(tp)

w=ot.unif(n)

lab=np.zeros(n)
lab[n//2:]=1


print('Upload data to GPU:')
print('===================')
ot.tic()
xs2,xt2= ot.gpu.to_gpu(xs,xt)
ot.toc()

#%% test dist computation

ot.tic()
M=ot.dist(xs.copy(),xt.copy())
t0=ot.toq()


ot.tic()
M1=ot.gpu.dist(xs.copy(),xt.copy(),to_numpy=True)
t1=ot.toq()

ot.tic()
M2=ot.gpu.dist(xs.copy(),xt.copy(),to_numpy=False)
t2=ot.toq()

ot.tic()
M3=ot.gpu.dist(xs2,xt2,to_numpy=False)
t3=ot.toq()

print('\nDist computation:')
print('===================')
print('CPU                     : {:1.4f}s'.format(t0))
print('GPU (src=cpu,tgt=cpu)   : {:1.4f}s (x{:1.2f})'.format(t1,t0/t1))
print('GPU (src=cpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t2,t0/t2))
print('GPU (src=gpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t3,t0/t3))
print('Err= {:e}'.format(np.abs(M-M1).max()))

#%% Sinkhorn computation

reg=1

ot.tic()
G=ot.sinkhorn(w,w,M.copy(),reg)
t0=ot.toq()

ot.tic()
G1=ot.gpu.sinkhorn(w,w,M.copy(),reg,to_numpy=True)
t1=ot.toq()

ot.tic()
G2=ot.gpu.sinkhorn(w,w,M.copy(),reg,to_numpy=False)
t2=ot.toq()

ot.tic()
G3=ot.gpu.sinkhorn(w,w,M3,reg,to_numpy=False)
t2=ot.toq()

print('\nSinkhorn computation:')
print('=======================')
print('CPU                     : {:1.4f}s'.format(t0))
print('GPU (src=cpu,tgt=cpu)   : {:1.4f}s (x{:1.2f})'.format(t1,t0/t1))
print('GPU (src=cpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t2,t0/t2))
print('GPU (src=gpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t3,t0/t3))
print('Err= {:e}'.format(np.abs(G-G1).max()))


#%% test sinkhorn multi distrub

reg=1

w2=np.random.rand(n,20)
w2/=w2.sum(0,keepdims=True)

ot.tic()
wass=ot.sinkhorn(w,w2,M.copy(),reg)
t0=ot.toq()

ot.tic()
wass1=ot.gpu.sinkhorn(w,w2,M.copy(),reg,to_numpy=True)
t1=ot.toq()

ot.tic()
wass2=ot.gpu.sinkhorn(w,w2,M.copy(),reg,to_numpy=False)
t2=ot.toq()

ot.tic()
wass2=ot.gpu.sinkhorn(w,w2,M3,reg,to_numpy=False)
t2=ot.toq()

print('\nSinkhorn multiple target:')
print('==========================')
print('CPU                     : {:1.4f}s'.format(t0))
print('GPU (src=cpu,tgt=cpu)   : {:1.4f}s (x{:1.2f})'.format(t1,t0/t1))
print('GPU (src=cpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t2,t0/t2))
print('GPU (src=gpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t3,t0/t3))
print('Err= {:e}'.format(np.abs(wass-wass1).max()))


#%
ot.tic()
G1p=ot.da.sinkhorn_lpl1_mm(w,lab,w,M.copy(),reg)
t0=ot.toq()

ot.tic()
G1p1=ot.gpu.da.sinkhorn_lpl1_mm(w,lab,w,M.copy(),reg,to_numpy=True)
t1=ot.toq()

ot.tic()
G1p2=ot.gpu.da.sinkhorn_lpl1_mm(w,lab,w,M.copy(),reg,to_numpy=False)
t2=ot.toq()

ot.tic()
G1p2=ot.gpu.da.sinkhorn_lpl1_mm(w,lab,w,M3,reg,to_numpy=False)
t3=ot.toq()

print('\nSinkhorn lpl1 :')
print('==========================')
print('CPU                     : {:1.4f}s'.format(t0))
print('GPU (src=cpu,tgt=cpu)   : {:1.4f}s (x{:1.2f})'.format(t1,t0/t1))
print('GPU (src=cpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t2,t0/t2))
print('GPU (src=gpu,tgt=gpu)   : {:1.4f}s (x{:1.2f})'.format(t3,t0/t3))
print('Err= {:e}'.format(np.abs(G1p-G1p1).max()))
Upload data to GPU:
===================
Elapsed time : 0.28782010078430176 s

Dist computation:
===================
CPU                     : 0.1933s
GPU (src=cpu,tgt=cpu)   : 0.5164s (x0.37)
GPU (src=cpu,tgt=gpu)   : 0.0010s (x184.93)
GPU (src=gpu,tgt=gpu)   : 0.0011s (x180.36)
Err= 0.000000e+00

Sinkhorn computation:
=======================
CPU                     : 1.8513s
GPU (src=cpu,tgt=cpu)   : 0.6724s (x2.75)
GPU (src=cpu,tgt=gpu)   : 0.2524s (x7.33)
GPU (src=gpu,tgt=gpu)   : 0.0011s (x1727.06)
Err= 1.985125e-12

Sinkhorn multiple target:
==========================
CPU                     : 12.7924s
GPU (src=cpu,tgt=cpu)   : 1.1502s (x11.12)
GPU (src=cpu,tgt=gpu)   : 0.9587s (x13.34)
GPU (src=gpu,tgt=gpu)   : 0.0011s (x11933.96)
Err= 1.294231e-09

Sinkhorn lpl1 :
==========================
CPU                     : 22.6899s
GPU (src=cpu,tgt=cpu)   : 2.9365s (x7.73)
GPU (src=cpu,tgt=gpu)   : 2.7254s (x8.33)
GPU (src=gpu,tgt=gpu)   : 2.5752s (x8.81)
Err= 2.574980e-19