GPU changes: by toto6 · Pull Request #32 · PythonOT/POT
added 4 commits
September 13, 2017 21:13- Replace cudamat by cupy for GPU implementations (cupy is still in active development, while cudamat is not) - Use the new DA class instead of the old deprecated one TODO for another PR: - Performances are still a bit lower than with cudamat (even if better than CPU for large matrices). Some speedups should be possible by tweaking the code
Add function pairwiseEuclidean that can be used with numpy or cupy. cupy (GPU) is used if parameter gpu==True and cupy is available. Otherwise compute with numpy. This function is faster than scipy.spatial.distance.cdist for sqeuclidean even when computing using the CPU (numpy).
TODO: - add parameter "gpu" in init of all classes extending BaseTransport - pass parameter "gpu" to function pairwiseEuclidean - change in file bregman.py the function sinkhorn_knopp to use cupy or numpy - change in file da.py the function sinkhorn_lpl1_mm to use cupy or numpy - same but for other functions...
- modified sinkhorn knopp code to be executed on numpy or cupy depending on the type of input matrices - at the moment GPU version is slow compared to CPU. with the test I added I obtain these results: ``` Normal, time: 4.96 sec GPU, time: 4.65 sec ``` - TODO: - improve performances of sinkhorn knopp for GPU - add gpu support for LpL1
Before ``` Normal, time: 4.96 sec GPU, time: 4.65 sec ``` After ``` Normal, time: 4.21 sec GPU, time: 3.45 sec ```
Before ``` Normal, time: 4.21 sec GPU, time: 3.45 sec ``` After ``` Normal, time: 3.70 sec GPU, time: 2.65 sec ```
Improve the benchmark comparison script between CPU and GPU.
For me it now output the following:
scipy's cdist, time: 10.28 sec
pairwiseEuclidean CPU, time: 0.63 sec
pairwiseEuclidean GPU, time: 0.33 sec
Sinkhorn CPU, time: 3.58 sec
Sinkhorn GPU, time: 2.63 sec
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters. Learn more about bidirectional Unicode characters