[MRG] Update example GAN to avoid the 10 minute CircleCI limit by rflamary · Pull Request #258 · PythonOT/POT

Expand Up @@ -50,6 +50,7 @@
import numpy as np import matplotlib.pyplot as pl import matplotlib.animation as animation import torch from torch import nn import ot Expand Down Expand Up @@ -112,10 +113,10 @@ def forward(self, x):

G = Generator() optimizer = torch.optim.RMSprop(G.parameters(), lr=0.001) optimizer = torch.optim.RMSprop(G.parameters(), lr=0.00019, eps=1e-5)
# number of iteration and size of the batches n_iter = 500 n_iter = 200 # set to 200 for doc build but 1000 is better ;) size_batch = 500
# generate statis samples to see their trajectory along training Expand Down Expand Up @@ -167,7 +168,7 @@ def forward(self, x):
pl.figure(3, (10, 10))
ivisu = [0, 10, 50, 100, 150, 200, 300, 400, 499] ivisu = [0, 10, 25, 50, 75, 125, 15, 175, 199]
for i in range(9): pl.subplot(3, 3, i + 1) Expand All @@ -179,6 +180,37 @@ def forward(self, x): if i == 0: pl.legend()
# %% # Animate trajectories of generated samples along iteration # -------------------------------------------------------
pl.figure(4, (8, 8))

def _update_plot(i): pl.clf() pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.xticks(()) pl.yticks(()) pl.xlim((-1.5, 1.5)) pl.ylim((-1.5, 1.5)) pl.title('Iter. {}'.format(i)) return 1

i = 0 pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.1) pl.scatter(xvisu[i, :, 0], xvisu[i, :, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.xticks(()) pl.yticks(()) pl.xlim((-1.5, 1.5)) pl.ylim((-1.5, 1.5)) pl.title('Iter. {}'.format(ivisu[i]))

ani = animation.FuncAnimation(pl.gcf(), _update_plot, n_iter, interval=100, repeat_delay=2000)
# %% # Generate and visualize data # --------------------------- Expand All @@ -188,7 +220,7 @@ def forward(self, x): xn = torch.randn(size_batch, 2) x = G(xn).detach().numpy()
pl.figure(4) pl.figure(5) pl.scatter(xd[:, 0], xd[:, 1], label='Data samples from $\mu_d$', alpha=0.5) pl.scatter(x[:, 0], x[:, 1], label='Data samples from $G\#\mu_n$', alpha=0.5) pl.title('Sources and Target distributions') Expand Down