Solving Many Optimal Transport Problems in Parallel — POT Python Optimal Transport 0.9.6 documentation
Note
Go to the end to download the full example code.
In some situations, one may want to solve many OT problems with the same structure (same number of samples, same cost function, etc.) at the same time.
In that case, solving the problems sequentially in a Python for loop is inefficient. This example shows how to use the batch solvers implemented in POT to solve many problems in parallel on CPU or GPU (the speedup is even larger on GPU).
# Author: Paul Krzakala <paul.krzakala@gmail.com>
# License: MIT License
# sphinx_gallery_thumbnail_number = 1
Computing the Cost Matrices
We want to create a batch of optimal transport problems with \(n\) samples in \(d\) dimensions.
To do this, we first need to compute the cost matrices for each problem.
Note
A straightforward approach would be to use a Python loop and
ot.dist().
However, this is inefficient when working with batches.
Instead, you can directly use ot.batch.dist_batch(), which computes
all cost matrices in parallel.
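For intuition, what a batched cost computation amounts to (for the default squared-Euclidean cost) can be sketched in plain NumPy: the same broadcasting trick used for one problem extends over a leading batch dimension. This is an illustration only, not POT's implementation; the variable names `M_list` and `M_batch` simply mirror those used in the snippet below.

```python
import numpy as np

rng = np.random.default_rng(0)
n_problems, n_samples, dim = 16, 8, 2
xs = rng.standard_normal((n_problems, n_samples, dim))  # batch of source samples
xt = rng.standard_normal((n_problems, n_samples, dim))  # batch of target samples

# Per-problem cost matrices, one at a time (the "Python loop" approach):
M_list = [
    ((xs[b, :, None, :] - xt[b, None, :, :]) ** 2).sum(-1)
    for b in range(n_problems)
]

# All cost matrices at once: M_batch[b, i, j] = ||xs[b, i] - xt[b, j]||^2
M_batch = ((xs[:, :, None, :] - xt[:, None, :, :]) ** 2).sum(-1)

# Both approaches produce the same (n_problems, n_samples, n_samples) array.
assert np.allclose(np.stack(M_list), M_batch)
```

The batched version replaces `n_problems` small kernel launches (or BLAS calls) with a single large one, which is where the speedup comes from.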
Solving the Problems
Once the cost matrices are computed, we can solve the corresponding optimal transport problems.
Note
One option is to solve them sequentially with a Python loop using
ot.solve().
This is simple but inefficient for large batches.
Instead, you can use ot.batch.solve_batch(), which solves all
problems in parallel.
reg = 1.0
max_iter = 100
tol = 1e-3

# Naive approach
results_values_list = []
for i in range(n_problems):
    res = ot.solve(M_list[i], reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
    results_values_list.append(res.value_linear)

# Batched approach
results_batch = ot.solve_batch(
    M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
)
results_values_batch = results_batch.value_linear

assert np.allclose(np.array(results_values_list), results_values_batch, atol=tol * 10)
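To see why entropic OT batches so well, note that the Sinkhorn scaling updates vectorize naturally over a leading batch dimension. The following is a plain-NumPy sketch under uniform marginals, for intuition only; it is not POT's implementation.

```python
import numpy as np

def sinkhorn_batch(M, reg, n_iter=200):
    """Entropic OT for a batch of cost matrices M of shape (B, n, m),
    with uniform marginals. Illustrative sketch, not POT's solver."""
    B, n, m = M.shape
    K = np.exp(-M / reg)          # (B, n, m) Gibbs kernels, one per problem
    a = np.full((B, n), 1.0 / n)  # uniform source marginals
    b = np.full((B, m), 1.0 / m)  # uniform target marginals
    u = np.ones((B, n))
    v = np.ones((B, m))
    for _ in range(n_iter):
        # Each einsum contracts over all B problems at once:
        u = a / np.einsum("bij,bj->bi", K, v)  # rescale rows
        v = b / np.einsum("bij,bi->bj", K, u)  # rescale columns
    return u[:, :, None] * K * v[:, None, :]   # transport plans, (B, n, m)
```

Each iteration costs one batched matrix-vector product instead of `B` separate ones, which is exactly the structure GPUs exploit.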
Comparing Computation Time
We now compare the runtime of the two approaches on larger problems.
Note
The speedup obtained with ot.batch can be even more
significant when computations are performed on a GPU.
from time import perf_counter

n_problems = 128
n_samples = 8
dim = 2
reg = 10.0
max_iter = 1000
tol = 1e-3

samples_source = np.random.randn(n_problems, n_samples, dim)
samples_target = samples_source + 0.1 * np.random.randn(n_problems, n_samples, dim)


def benchmark_naive(samples_source, samples_target):
    start = perf_counter()
    for i in range(n_problems):
        M = ot.dist(samples_source[i], samples_target[i])
        res = ot.solve(M, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy")
    end = perf_counter()
    return end - start


def benchmark_batch(samples_source, samples_target):
    start = perf_counter()
    M_batch = ot.dist_batch(samples_source, samples_target)
    res_batch = ot.solve_batch(
        M=M_batch, reg=reg, max_iter=max_iter, tol=tol, reg_type="entropy"
    )
    end = perf_counter()
    return end - start


time_naive = benchmark_naive(samples_source, samples_target)
time_batch = benchmark_batch(samples_source, samples_target)

print(f"Naive approach time: {time_naive:.4f} seconds")
print(f"Batched approach time: {time_batch:.4f} seconds")
Naive approach time: 0.2942 seconds
Batched approach time: 0.0094 seconds
Gromov-Wasserstein
The ot.batch module also provides a batched Gromov-Wasserstein solver.
Note
This solver is not equivalent to calling ot.solve_gromov()
repeatedly in a loop.
Key differences:
ot.solve_gromov()
    Uses the conditional gradient algorithm. Each inner iteration relies on an exact EMD solver.
ot.batch.solve_gromov_batch()
    Uses a proximal variant, where each inner iteration applies entropic regularization.
As a result:
ot.solve_gromov() is usually faster on CPU.
ot.batch.solve_gromov_batch() is slower on CPU, but provides better objective values.
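To make "objective values" concrete: both solvers minimize the squared-loss Gromov-Wasserstein objective \(\sum_{i,j,k,l} (C_1[i,k] - C_2[j,l])^2\, T[i,j]\, T[k,l]\), so lower is better. A brute-force batched evaluation of this objective can be sketched in NumPy as follows; this is illustrative only (feasible for small \(n\)), and not how POT evaluates it internally.

```python
import numpy as np

def gw_objective_batch(C1, C2, T):
    """Batched squared-loss GW objective.
    C1: (B, n, n), C2: (B, m, m), T: (B, n, m) transport plans."""
    # L[b, i, j, k, l] = (C1[b, i, k] - C2[b, j, l])^2, built by broadcasting.
    L = (C1[:, :, None, :, None] - C2[:, None, :, None, :]) ** 2
    # Contract L against the plan twice, one value per problem in the batch.
    return np.einsum("bijkl,bij,bkl->b", L, T, T)
```

Sanity check: two identical structures matched by the scaled identity plan have objective zero, and the objective is nonnegative for any plan.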
from ot import solve_gromov, solve_gromov_batch


def benchmark_naive_gw(samples_source, samples_target):
    start = perf_counter()
    avg_value = 0
    for i in range(n_problems):
        C1 = ot.dist(samples_source[i], samples_source[i])
        C2 = ot.dist(samples_target[i], samples_target[i])
        res = solve_gromov(C1, C2, max_iter=1000, tol=tol)
        avg_value += res.value
    avg_value /= n_problems
    end = perf_counter()
    return end - start, avg_value


def benchmark_batch_gw(samples_source, samples_target):
    start = perf_counter()
    C1_batch = ot.dist_batch(samples_source, samples_source)
    C2_batch = ot.dist_batch(samples_target, samples_target)
    res_batch = solve_gromov_batch(
        C1_batch, C2_batch, reg=1, max_iter=100, max_iter_inner=50, tol=tol
    )
    avg_value = np.mean(res_batch.value)
    end = perf_counter()
    return end - start, avg_value


time_naive_gw, avg_value_naive_gw = benchmark_naive_gw(samples_source, samples_target)
time_batch_gw, avg_value_batch_gw = benchmark_batch_gw(samples_source, samples_target)

print(f"{'Method':<20}{'Time (s)':<15}{'Avg Value':<15}")
print(f"{'Naive GW':<20}{time_naive_gw:<15.4f}{avg_value_naive_gw:<15.4f}")
print(f"{'Batched GW':<20}{time_batch_gw:<15.4f}{avg_value_batch_gw:<15.4f}")
Method              Time (s)       Avg Value
Naive GW            0.0897         0.7070
Batched GW          0.3746         0.2914
In summary: no more for loops!
import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(4, 4))
ax.text(0.5, 0.5, "For", fontsize=160, ha="center", va="center", zorder=0)
ax.axis("off")
ax.plot([0, 1], [0, 1], color="red", linewidth=10, zorder=1)
ax.plot([0, 1], [1, 0], color="red", linewidth=10, zorder=1)
plt.show()

Total running time of the script: (0 minutes 0.818 seconds)