[MRG] add kl_loss to all semi-relaxed (f)gw solvers by cedricvincentcuaz · Pull Request #559 · PythonOT/POT

Expand Up @@ -56,7 +56,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. Expand Down Expand Up @@ -92,8 +91,6 @@ def semirelaxed_gromov_wasserstein(C1, C2, p=None, loss_fun='square_loss', symme "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ if loss_fun == 'kl_loss': raise NotImplementedError() arr = [C1, C2] if p is not None: arr.append(list_to_array(p)) Expand Down Expand Up @@ -139,7 +136,7 @@ def df(G): return 0.5 * (gwggrad(constC + marginal_product_1, hC1, hC2, G, nx) + gwggrad(constCt + marginal_product_2, hC1t, hC2t, G, nx))
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M=0., reg=1., nx=nx, **kwargs) return solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, hC1, hC2, ones_p, M=0., reg=1., fC2t=fC2t, nx=nx, **kwargs)
if log: res, log = semirelaxed_cg(p, q, 0., 1., f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) Expand Down Expand Up @@ -190,7 +187,6 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. Expand Down Expand Up @@ -243,7 +239,12 @@ def semirelaxed_gromov_wasserstein2(C1, C2, p=None, loss_fun='square_loss', symm if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
srgw = nx.set_gradients(srgw, (C1, C2), (gC1, gC2))
if log: return srgw, log_srgw Expand Down Expand Up @@ -291,7 +292,6 @@ def semirelaxed_fused_gromov_wasserstein( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. Expand Down Expand Up @@ -332,9 +332,6 @@ def semirelaxed_fused_gromov_wasserstein( "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ if loss_fun == 'kl_loss': raise NotImplementedError()
arr = [M, C1, C2] if p is not None: arr.append(list_to_array(p)) Expand Down Expand Up @@ -382,7 +379,7 @@ def df(G):
def line_search(cost, G, deltaG, Mi, cost_G, **kwargs): return solve_semirelaxed_gromov_linesearch( G, deltaG, cost_G, C1, C2, ones_p, M=(1 - alpha) * M, reg=alpha, nx=nx, **kwargs) G, deltaG, cost_G, hC1, hC2, ones_p, M=(1 - alpha) * M, reg=alpha, fC2t=fC2t, nx=nx, **kwargs)
if log: res, log = semirelaxed_cg(p, q, (1 - alpha) * M, alpha, f, df, G0, line_search, log=True, numItermax=max_iter, stopThr=tol_rel, stopThr2=tol_abs, **kwargs) Expand Down Expand Up @@ -434,7 +431,6 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo If let to its default value None, uniform distribution is taken. loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. symmetric : bool, optional Either C1 and C2 are to be assumed symmetric or not. If let to its default None value, a symmetry test will be conducted. Expand Down Expand Up @@ -494,15 +490,20 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo if loss_fun == 'square_loss': gC1 = 2 * C1 * nx.outer(p, p) - 2 * nx.dot(T, nx.dot(C2, T.T)) gC2 = 2 * C2 * nx.outer(q, q) - 2 * nx.dot(T.T, nx.dot(C1, T)) if isinstance(alpha, int) or isinstance(alpha, float): srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), (alpha * gC1, alpha * gC2, (1 - alpha) * T)) else: lin_term = nx.sum(T * M) srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha), (alpha * gC1, alpha * gC2, (1 - alpha) * T, srgw_term - lin_term))
elif loss_fun == 'kl_loss': gC1 = nx.log(C1 + 1e-15) * nx.outer(p, p) - nx.dot(T, nx.dot(nx.log(C2 + 1e-15), T.T)) gC2 = nx.dot(T.T, nx.dot(C1, T)) / (C2 + 1e-15) + nx.outer(q, q)
if isinstance(alpha, int) or isinstance(alpha, float): srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M), (alpha * gC1, alpha * gC2, (1 - alpha) * T)) else: lin_term = nx.sum(T * M) srgw_term = (srfgw_dist - (1 - alpha) * lin_term) / alpha srfgw_dist = nx.set_gradients(srfgw_dist, (C1, C2, M, alpha), (alpha * gC1, alpha * gC2, (1 - alpha) * T, srgw_term - lin_term))
if log: return srfgw_dist, log_fgw Expand All @@ -511,7 +512,7 @@ def semirelaxed_fused_gromov_wasserstein2(M, C1, C2, p=None, loss_fun='square_lo

def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, M, reg, alpha_min=None, alpha_max=None, nx=None, **kwargs): M, reg, fC2t=None, alpha_min=None, alpha_max=None, nx=None, **kwargs): """ Solve the linesearch in the Conditional Gradient iterations for the semi-relaxed Gromov-Wasserstein divergence.
Expand All @@ -524,16 +525,22 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p, Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration cost_G : float Value of the cost at `G` C1 : array-like (ns,ns) Structure matrix in the source domain. C2 : array-like (nt,nt) Structure matrix in the target domain. C1 : array-like (ns,ns), optional Transformed Structure matrix in the source domain. Note that for the 'square_loss' and 'kl_loss', we provide hC1 from ot.gromov.init_matrix_semirelaxed C2 : array-like (nt,nt), optional Transformed Structure matrix in the source domain. Note that for the 'square_loss' and 'kl_loss', we provide hC2 from ot.gromov.init_matrix_semirelaxed ones_p: array-like (ns,1) Array of ones of size ns M : array-like (ns,nt) Cost matrix between the features. reg : float Regularization parameter. fC2t: array-like (nt,nt), optional Transformed Structure matrix in the source domain. Note that for the 'square_loss' and 'kl_loss', we provide fC2t from ot.gromov.init_matrix_semirelaxed. If fC2t is not provided, it is by default fC2t corresponding to the 'square_loss'. alpha_min : float, optional Minimum value for alpha alpha_max : float, optional Expand Down Expand Up @@ -565,11 +572,14 @@ def solve_semirelaxed_gromov_linesearch(G, deltaG, cost_G, C1, C2, ones_p,
qG, qdeltaG = nx.sum(G, 0), nx.sum(deltaG, 0) dot = nx.dot(nx.dot(C1, deltaG), C2.T) C2t_square = C2.T ** 2 dot_qG = nx.dot(nx.outer(ones_p, qG), C2t_square) dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), C2t_square) a = reg * nx.sum((dot_qdeltaG - 2 * dot) * deltaG) b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - 2 * dot) * G) + nx.sum((dot_qG - 2 * nx.dot(nx.dot(C1, G), C2.T)) * deltaG)) if fC2t is None: fC2t = C2.T ** 2 dot_qG = nx.dot(nx.outer(ones_p, qG), fC2t) dot_qdeltaG = nx.dot(nx.outer(ones_p, qdeltaG), fC2t)
a = reg * nx.sum((dot_qdeltaG - dot) * deltaG) b = nx.sum(M * deltaG) + reg * (nx.sum((dot_qdeltaG - dot) * G) + nx.sum((dot_qG - nx.dot(nx.dot(C1, G), C2.T)) * deltaG))
alpha = solve_1d_linesearch_quad(a, b) if alpha_min is not None or alpha_max is not None: alpha = np.clip(alpha, alpha_min, alpha_max) Expand Down Expand Up @@ -620,7 +630,6 @@ def entropic_semirelaxed_gromov_wasserstein( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional Expand Down Expand Up @@ -655,8 +664,6 @@ def entropic_semirelaxed_gromov_wasserstein( "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ if loss_fun == 'kl_loss': raise NotImplementedError() arr = [C1, C2] if p is not None: arr.append(list_to_array(p)) Expand Down Expand Up @@ -777,7 +784,6 @@ def entropic_semirelaxed_gromov_wasserstein2( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional Expand Down Expand Up @@ -869,7 +875,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein( If let to its default value None, uniform distribution is taken. loss_fun : str loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional Expand Down Expand Up @@ -907,8 +912,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein( "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs" International Conference on Learning Representations (ICLR), 2022. """ if loss_fun == 'kl_loss': raise NotImplementedError() arr = [M, C1, C2] if p is not None: arr.append(list_to_array(p)) Expand Down Expand Up @@ -1032,7 +1035,6 @@ def entropic_semirelaxed_fused_gromov_wasserstein2( If let to its default value None, uniform distribution is taken. loss_fun : str, optional loss function used for the solver either 'square_loss' or 'kl_loss'. 'kl_loss' is not implemented yet and will raise an error. epsilon : float Regularization term >0 symmetric : bool, optional Expand Down