diff --git a/README.md b/README.md
index f8880a166..3e69af448 100644
--- a/README.md
+++ b/README.md
@@ -54,9 +54,10 @@ POT provides the following generic OT solvers:
   Large-scale Optimal Transport (semi-dual problem [18] and dual problem [19])
 * [Sampled solver of Gromov Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_gromov.html) for large-scale problem with any loss functions [33]
 * Non regularized [free support Wasserstein barycenters](https://pythonot.github.io/auto_examples/barycenters/plot_free_support_barycenter.html) [20].
-* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
+* [One dimensional Unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_1D.html) with KL relaxation [73] and [barycenter](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_barycenter_1D.html) [10, 25]. Also [exact unbalanced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_unbalanced_ot.html) with KL and quadratic regularization and the [regularization path of UOT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_regpath.html) [41]
 * [Partial Wasserstein and Gromov-Wasserstein](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_partial_wass_and_gromov.html) and [Partial Fused Gromov-Wasserstein](https://pythonot.github.io/auto_examples/gromov/plot_partial_fgw.html) (exact [29] and entropic [3] formulations).
 * [Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance.html) [31, 32] and Max-sliced Wasserstein [35] that can be used for gradient flows [36].
+* [Sliced Unbalanced OT and Unbalanced Sliced OT](https://pythonot.github.io/auto_examples/unbalanced-partial/plot_UOT_sliced.html) [82]
 * [Wasserstein distance on the circle](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_compute_wasserstein_circle.html) [44, 45] and [Spherical Sliced Wasserstein](https://pythonot.github.io/auto_examples/sliced-wasserstein/plot_variance_ssw.html) [46]
@@ -367,7 +368,7 @@ Dictionary Learning](https://arxiv.org/pdf/2102.06555.pdf), International Confer
 [42] Delon, J., Gozlan, N., and Saint-Dizier, A. [Generalized Wasserstein barycenters between probability measures living on different subspaces](https://arxiv.org/pdf/2105.09755). arXiv preprint arXiv:2105.09755, 2021.
 
-[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762. 
+[43] Álvarez-Esteban, Pedro C., et al. [A fixed-point approach to barycenters in Wasserstein space.](https://arxiv.org/pdf/1511.05355.pdf) Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.
 
 [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. [Fast transport optimization for Monge costs on the circle.](https://arxiv.org/abs/0902.3527) SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
@@ -449,5 +450,4 @@ Artificial Intelligence.
 
 [81] Xu, H., Luo, D., & Carin, L. (2019).
[Scalable Gromov-Wasserstein learning for graph partitioning and matching](https://proceedings.neurips.cc/paper/2019/hash/6e62a992c676f611616097dbea8ea030-Abstract.html). Neural Information Processing Systems (NeurIPS).
 
-
-```
+[82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2024). [Slicing Unbalanced Optimal Transport](https://openreview.net/forum?id=AjJTg5M0r8). Transactions on Machine Learning Research.
diff --git a/RELEASES.md b/RELEASES.md
index 1660913cd..d6abb1618 100644
--- a/RELEASES.md
+++ b/RELEASES.md
@@ -9,6 +9,8 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver
 - Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
 - Geomloss function now handles both scalar and slice indices for i and j (PR #785)
 - Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
+- Add 1D UOT solved with Frank-Wolfe in `ot.unbalanced.uot_1d` (PR #765)
+- Add Sliced UOT and Unbalanced Sliced OT in `ot/unbalanced/_sliced.py` (PR #765)
 
 #### Closed issues
diff --git a/examples/unbalanced-partial/plot_UOT_1D.py b/examples/unbalanced-partial/plot_UOT_1D.py
index ade4bbb0c..747be8ce3 100644
--- a/examples/unbalanced-partial/plot_UOT_1D.py
+++ b/examples/unbalanced-partial/plot_UOT_1D.py
@@ -9,6 +9,7 @@
 """
 
 # Author: Hicham Janati
+# Clément Bonet
 #
 # License: MIT License
 
@@ -19,6 +20,7 @@
 import ot
 import ot.plot
 from ot.datasets import make_1D_gauss as gauss
+import torch
 
 ##############################################################################
 # Generate data
 # -------------
@@ -41,7 +44,6 @@
 
 # loss matrix
 M = ot.dist(x.reshape((n, 1)), x.reshape((n, 1)))
-M /= M.max()
 
 
 ##############################################################################
@@ -62,25 +64,81 @@
 ##############################################################################
-# Solve Unbalanced Sinkhorn
-# -------------------------
+# Solve Unbalanced OT with MM Unbalanced
+# --------------------------------------
 
-# Sinkhorn
+# %% MM Unbalanced
 
-epsilon = 0.1  # entropy parameter
 alpha = 1.0  # Unbalanced KL relaxation parameter
-Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M, epsilon, alpha, verbose=True)
+
+Gs, log = ot.unbalanced.mm_unbalanced(a, b, M / M.max(), alpha, verbose=False, log=True)
 
 pl.figure(3, figsize=(5, 5))
-ot.plot.plot1D_mat(a, b, Gs, "UOT matrix Sinkhorn")
+ot.plot.plot1D_mat(a, b, Gs, "UOT plan")
+pl.show()
 
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, "b", label="Source distribution")
+pl.plot(x, b, "r", label="Target distribution")
+pl.fill(x, Gs.sum(1), "b", alpha=0.5, label="Transported source")
+pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target")
+pl.legend(loc="upper right")
+pl.title("Distributions and transported mass for UOT")
 pl.show()
 
+print("Mass of reweighted marginals:", Gs.sum())
+print("Unbalanced OT loss:", log["total_cost"] * M.max())
+
+
+##############################################################################
+# Solve 1D UOT with Frank-Wolfe
+# -----------------------------
+
-
-# %%
-# plot the transported mass
+
+# %% 1D UOT with FW
+
+
+alpha = M.max()  # Unbalanced KL relaxation parameter
+
+a_reweighted, b_reweighted, loss = ot.unbalanced.uot_1d(
+    torch.tensor(x, dtype=torch.float64),
+    torch.tensor(x, dtype=torch.float64),
+    alpha,
+    u_weights=torch.tensor(a, dtype=torch.float64),
+    v_weights=torch.tensor(b, dtype=torch.float64),
+    p=2,
+    returnCost="total",
+)
+
+pl.figure(4, figsize=(6.4, 3))
+pl.plot(x, a, "b", label="Source distribution")
b, "r", label="Target distribution") +pl.fill(x, a_reweighted, "b", alpha=0.5, label="Transported source") +pl.fill(x, b_reweighted, "r", alpha=0.5, label="Transported target") +pl.legend(loc="upper right") +pl.title("Distributions and transported mass for UOT") +pl.show() + +print("Mass of reweighted marginals:", a_reweighted.sum().item()) +print("Unbalanced OT loss:", loss.item()) + + +############################################################################## +# Solve Unbalanced Sinkhorn # ------------------------- +# %% Sinkhorn UOT + +# Sinkhorn + +epsilon = 0.1 # entropy parameter +alpha = 1.0 # Unbalanced KL relaxation parameter +Gs = ot.unbalanced.sinkhorn_unbalanced(a, b, M / M.max(), epsilon, alpha, verbose=True) + +pl.figure(3, figsize=(5, 5)) +ot.plot.plot1D_mat(a, b, Gs, "Entropic UOT plan") +pl.show() + pl.figure(4, figsize=(6.4, 3)) pl.plot(x, a, "b", label="Source distribution") pl.plot(x, b, "r", label="Target distribution") @@ -88,3 +146,6 @@ pl.fill(x, Gs.sum(0), "r", alpha=0.5, label="Transported target") pl.legend(loc="upper right") pl.title("Distributions and transported mass for UOT") +pl.show() + +print("Mass of reweighted marginals:", Gs.sum()) diff --git a/examples/unbalanced-partial/plot_UOT_sliced.py b/examples/unbalanced-partial/plot_UOT_sliced.py new file mode 100644 index 000000000..cd06ecb73 --- /dev/null +++ b/examples/unbalanced-partial/plot_UOT_sliced.py @@ -0,0 +1,274 @@ +# -*- coding: utf-8 -*- +""" +=================================== +Sliced Unbalanced optimal transport +=================================== + +This example illustrates the behavior of Sliced UOT versus +Unbalanced Sliced OT. + +The first one removes outliers on each slice while the second one +removes outliers of the original marginals. +""" + +# Author: Clément Bonet +# Nicolas Courty +# +# License: MIT License + +import numpy as np +import matplotlib.pylab as pl +import ot +import torch +import matplotlib.pyplot as plt +import matplotlib as mpl + +from sklearn.neighbors import KernelDensity + +############################################################################## +# Generate data +# ------------- + + +# %% parameters + +get_rot = lambda theta: np.array( + [[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]] +) + + +# regular distribution of Gaussians around a circle +def make_blobs_reg(n_samples, n_blobs, scale=0.5): + per_blob = int(n_samples / n_blobs) + result = np.random.randn(per_blob, 2) * scale + 5 + theta = (2 * np.pi) / (n_blobs) + for r in range(1, n_blobs): + new_blob = (np.random.randn(per_blob, 2) * scale + 5).dot(get_rot(theta * r)) + result = np.vstack((result, new_blob)) + return result + + +def make_blobs_random(n_samples, n_blobs, scale=0.5, offset=3): + per_blob = int(n_samples / n_blobs) + result = np.random.randn(per_blob, 2) * scale + np.random.randn(1, 2) * offset + for r in range(1, n_blobs): + new_blob = np.random.randn(per_blob, 2) * scale + np.random.randn(1, 2) * offset + result = np.vstack((result, new_blob)) + return result + + +def make_spiral(n_samples, noise=0.5): + n = np.sqrt(np.random.rand(n_samples, 1)) * 780 * (2 * np.pi) / 360 + d1x = -np.cos(n) * n + np.random.rand(n_samples, 1) * noise + d1y = np.sin(n) * n + np.random.rand(n_samples, 1) * noise + return np.array(np.hstack((d1x, d1y))) + + +n_samples = 500 +expe = "outlier" + +np.random.seed(42) + +nb_outliers = 200 +Xs = make_blobs_random(n_samples=n_samples, scale=0.2, n_blobs=1, offset=0) - 0.5 +Xs_outlier = make_blobs_random( + n_samples=nb_outliers, scale=0.05, 
+Xs_outlier = make_blobs_random(
+    n_samples=nb_outliers, scale=0.05, n_blobs=1, offset=0
+) - [2, 0.5]
+
+Xs = np.vstack((Xs, Xs_outlier))
+Xt = make_blobs_random(n_samples=n_samples, scale=0.2, n_blobs=1, offset=0) + 1.5
+y = np.hstack(([0] * (n_samples + nb_outliers), [1] * n_samples))
+X = np.vstack((Xs, Xt))
+
+
+Xs_torch = torch.from_numpy(Xs).type(torch.float)
+Xt_torch = torch.from_numpy(Xt).type(torch.float)
+
+p = 2
+num_proj = 180
+
+a = torch.ones(Xs.shape[0], dtype=torch.float)
+b = torch.ones(Xt.shape[0], dtype=torch.float)
+
+# construct projections
+thetas = np.linspace(0, np.pi, num_proj)
+dir = np.array([(np.cos(theta), np.sin(theta)) for theta in thetas])
+dir_torch = torch.from_numpy(dir).type(torch.float)
+
+Xps = (Xs_torch @ dir_torch.T).T  # shape (n_projs, n)
+Xpt = (Xt_torch @ dir_torch.T).T
+
+##############################################################################
+# Compute SUOT and USOT
+# ---------------------
+
+# %%
+
+rho1_SUOT = 1
+rho2_SUOT = 1
+_, log = ot.unbalanced.sliced_unbalanced_ot(
+    Xs_torch,
+    Xt_torch,
+    (rho1_SUOT, rho2_SUOT),
+    a,
+    b,
+    num_proj,
+    p,
+    numItermax=10,
+    projections=dir_torch.T,
+    log=True,
+)
+A_SUOT, B_SUOT = log["a_reweighted"].T, log["b_reweighted"].T
+
+
+rho1_USOT = 1
+rho2_USOT = 1
+A_USOT, B_USOT, _ = ot.unbalanced_sliced_ot(
+    Xs_torch,
+    Xt_torch,
+    (rho1_USOT, rho2_USOT),
+    a,
+    b,
+    num_proj,
+    p,
+    numItermax=10,
+    projections=dir_torch.T,
+)
+
+
+##############################################################################
+# Utils plot
+# ----------
+
+# %%
+
+
+def kde_sklearn(x, x_grid, weights=None, bandwidth=0.2, **kwargs):
+    """Kernel Density Estimation with Scikit-learn"""
+    kde_skl = KernelDensity(bandwidth=bandwidth, **kwargs)
+    if weights is not None:
+        kde_skl.fit(x[:, np.newaxis], sample_weight=weights)
+    else:
+        kde_skl.fit(x[:, np.newaxis])
+    # score_samples() returns the log-likelihood of the samples
+    log_pdf = kde_skl.score_samples(x_grid[:, np.newaxis])
+    return np.exp(log_pdf)
+
+
+def plot_slices(
+    col, nb_slices, x_grid, Xps, Xpt, Xps_weights, Xpt_weights, method, rho1, rho2
+):
+    for i in range(nb_slices):
+        ax = plt.subplot2grid((nb_slices, 3), (i, col))
+        if len(Xps_weights.shape) > 1:  # SUOT
+            weights_src = Xps_weights[i * offset_degree, :].cpu().numpy()
+            weights_tgt = Xpt_weights[i * offset_degree, :].cpu().numpy()
+        else:  # USOT
+            weights_src = Xps_weights.cpu().numpy()
+            weights_tgt = Xpt_weights.cpu().numpy()
+
+        samples_src = Xps[i * offset_degree, :].cpu().numpy()
+        samples_tgt = Xpt[i * offset_degree, :].cpu().numpy()
+
+        pdf_source = kde_sklearn(samples_src, x_grid, weights=weights_src, bandwidth=bw)
+        pdf_target = kde_sklearn(samples_tgt, x_grid, weights=weights_tgt, bandwidth=bw)
+        pdf_source_without_w = kde_sklearn(samples_src, x_grid, bandwidth=bw)
+        pdf_target_without_w = kde_sklearn(samples_tgt, x_grid, bandwidth=bw)
+
+        ax.plot(x_grid, pdf_source, color=c2, alpha=0.8, lw=2)
+        ax.fill(x_grid, pdf_source_without_w, ec="grey", fc="grey", alpha=0.3)
+        ax.fill(x_grid, pdf_source, ec=c2, fc=c2, alpha=0.3)
+
+        ax.plot(x_grid, pdf_target, color=c1, alpha=0.8, lw=2)
+        ax.fill(x_grid, pdf_target_without_w, ec="grey", fc="grey", alpha=0.3)
+        ax.fill(x_grid, pdf_target, ec=c2, fc=c1, alpha=0.3)
+
+        ax.set_xlim(xlim_min, xlim_max)
+
+        if col == 1:
+            ax.set_ylabel(
+                r"$\theta=${}$^o$".format(i * offset_degree),
+                color=colors[i],
+                fontsize=13,
+            )
+
+        ax.set_yticks([])
+        ax.set_xticks([])
+
+    ax.set_xlabel(
+        r"{} $\rho_1={}$ $\rho_2={}$".format(method, rho1, rho2), fontsize=13
+    )
+
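+
+##############################################################################
+# Mass of the reweighted marginals
+# --------------------------------
+# A rough sanity check (a sketch reusing the variables computed above):
+# summing the reweighted marginals gives the mass kept by each method.
+# ``A_SUOT`` holds one weight vector per slice, while ``A_USOT`` is a single
+# weight vector shared by all slices, so we average the SUOT masses.
+
+# %%
+
+print("SUOT: average source mass kept per slice:", A_SUOT.sum(-1).mean().item())
+print("USOT: source mass kept:", A_USOT.sum().item())
+
+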
+##############################################################################
+# Plot reweighted distributions on several slices
+# -----------------------------------------------
+# We plot the reweighted distributions on several slices. We see that for SUOT,
+# the mode of outliers is kept on some slices (e.g. for :math:`\theta=120°`) while USOT
+# is able to get rid of the outlier mode.

+# %%
+
+c1 = np.array(mpl.colors.to_rgb("red"))
+c2 = np.array(mpl.colors.to_rgb("blue"))
+
+# define plotting grid
+xlim_min = -3
+xlim_max = 3
+x_grid = np.linspace(xlim_min, xlim_max, 200)
+bw = 0.05
+
+# visu parameters
+nb_slices = 3  # 4
+offset_degree = int(180 / nb_slices)
+
+delta_degree = np.pi / nb_slices
+colors = plt.cm.Reds(np.linspace(0.3, 1, nb_slices))
+
+X1 = np.array([-4, 0])
+X2 = np.array([4, 0])
+
+
+fig = plt.figure(figsize=(9, 3))
+
+ax1 = plt.subplot2grid((nb_slices, 3), (0, 0), rowspan=nb_slices)
+
+for i in range(nb_slices):
+    R = get_rot(delta_degree * (-i))
+    X1_r = X1.dot(R)
+    X2_r = X2.dot(R)
+    if i == 0:
+        ax1.plot(
+            [X1_r[0], X2_r[0]],
+            [X1_r[1], X2_r[1]],
+            color=colors[i],
+            alpha=0.8,
+            zorder=0,
+            label="Directions",
+        )
+    else:
+        ax1.plot(
+            [X1_r[0], X2_r[0]], [X1_r[1], X2_r[1]], color=colors[i], alpha=0.8, zorder=0
+        )
+
+ax1.scatter(Xs[:, 0], Xs[:, 1], zorder=1, color=c2, label="Source data")
+ax1.scatter(Xt[:, 0], Xt[:, 1], zorder=1, color=c1, label="Target data")
+ax1.set_xlim([-3, 3])
+ax1.set_ylim([-3, 3])
+ax1.set_yticks([])
+ax1.set_xticks([])
+# ax1.legend(loc='best',fontsize=13)
+ax1.set_xlabel("Original distributions", fontsize=13)
+
+
+fig.subplots_adjust(hspace=0)
+fig.subplots_adjust(wspace=0.15)
+
+plot_slices(
+    1, nb_slices, x_grid, Xps, Xpt, A_SUOT, B_SUOT, "SUOT", rho1_SUOT, rho2_SUOT
+)
+plot_slices(
+    2, nb_slices, x_grid, Xps, Xpt, A_USOT, B_USOT, "USOT", rho1_USOT, rho2_USOT
+)
+
+plt.show()
diff --git a/ignore-words.txt b/ignore-words.txt
index 00c1f5edb..573400137 100644
--- a/ignore-words.txt
+++ b/ignore-words.txt
@@ -6,4 +6,5 @@ wass
 ccompiler
 ist
 lik
-ges
\ No newline at end of file
+ges
+mapp
diff --git a/ot/__init__.py b/ot/__init__.py
index 26f428aa1..75f17fed6 100644
--- a/ot/__init__.py
+++ b/ot/__init__.py
@@ -44,6 +44,7 @@
     emd2_lazy,
     emd_1d,
     emd2_1d,
+    emd_1d_dual_backprop,
     wasserstein_1d,
     binary_search_circle,
     wasserstein_circle,
@@ -51,7 +52,14 @@
     linear_circular_ot,
 )
 from .bregman import sinkhorn, sinkhorn2, barycenter
-from .unbalanced import sinkhorn_unbalanced, barycenter_unbalanced, sinkhorn_unbalanced2
+from .unbalanced import (
+    sinkhorn_unbalanced,
+    barycenter_unbalanced,
+    sinkhorn_unbalanced2,
+    uot_1d,
+    unbalanced_sliced_ot,
+    sliced_unbalanced_ot,
+)
 from .da import sinkhorn_lpl1_mm
 from .sliced import (
     sliced_wasserstein_distance,
@@ -96,6 +104,7 @@
     "toq",
     "gromov",
     "emd2_1d",
+    "emd_1d_dual_backprop",
     "wasserstein_1d",
     "backend",
     "gaussian",
@@ -110,6 +120,9 @@
     "sinkhorn_unbalanced2",
     "sliced_wasserstein_distance",
     "sliced_wasserstein_sphere",
+    "uot_1d",
+    "unbalanced_sliced_ot",
+    "sliced_unbalanced_ot",
     "linear_sliced_wasserstein_sphere",
     "gromov_wasserstein",
     "gromov_wasserstein2",
diff --git a/ot/backend.py b/ot/backend.py
index 6b03f5cd1..72f296ddf 100644
--- a/ot/backend.py
+++ b/ot/backend.py
@@ -89,6 +89,7 @@
 import os
 import time
 import warnings
+import functools
 
 import numpy as np
 import scipy
@@ -122,7 +123,27 @@
     from jax.extend.backend import get_backend as _jax_get_backend
 
     jax_type = jax.numpy.ndarray
float(".".join(jax.__version__.split(".")[1:])) > 4.24 + jax_new_version = tuple([float(s) for s in jax.__version__.split(".")]) > ( + 0, + 4, + 24, + ) + + @jax.custom_jvp + def norm_1d_jax(z): + return jnp.abs(z) + + @norm_1d_jax.defjvp + def norm_1d_jax_jvp(primals, tangents): + (z,) = primals + z_is_zero = jnp.all(jnp.logical_not(z)) + clean_z = jnp.where(z_is_zero, jnp.ones_like(z), z) + primals, tangents = jax.jvp( + functools.partial(jnp.abs), (clean_z,), tangents + ) + return jnp.abs(z), jnp.where(z_is_zero, 0.0, tangents) + except ImportError: jax = False jax_type = float @@ -588,7 +609,7 @@ def flip(self, a, axis=None): """ raise NotImplementedError() - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): """ Limits the values in a tensor. @@ -1015,7 +1036,7 @@ def eigh(self, a): """ raise NotImplementedError() - def kl_div(self, p, q, mass=False, eps=1e-16): + def kl_div(self, p, q, mass=False, eps=1e-16, axis=None): r""" Computes the (Generalized) Kullback-Leibler divergence. @@ -1145,6 +1166,22 @@ def slogdet(self, a): """ raise NotImplementedError() + def index_select(self, input, axis, index): + r""" + Returns a new tensor which indexes the input tensor along dimension dim using the entries in index. + + See: https://docs.pytorch.org/docs/stable/generated/torch.index_select.html + """ + raise NotImplementedError() + + def nonzero(self, input, as_tuple=False): + r""" + Returns a tensor containing the indices of all non-zero elements of input. + + See: https://docs.pytorch.org/docs/stable/generated/torch.nonzero.html + """ + raise NotImplementedError() + class NumpyBackend(Backend): """ @@ -1286,7 +1323,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return np.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return np.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -1463,10 +1500,10 @@ def sqrtm(self, a): def eigh(self, a): return np.linalg.eigh(a) - def kl_div(self, p, q, mass=False, eps=1e-16): - value = np.sum(p * np.log(p / q + eps)) + def kl_div(self, p, q, mass=False, eps=1e-16, axis=None): + value = np.sum(p * np.log(p / q + eps), axis=axis) if mass: - value = value + np.sum(q - p) + value = value + np.sum(q - p, axis=axis) return value def isfinite(self, a): @@ -1528,6 +1565,16 @@ def det(self, a): def slogdet(self, a): return np.linalg.slogdet(a) + def index_select(self, input, axis, index): + return np.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return np.nonzero(input) + else: + L_tuple = np.nonzero(input) + return np.concatenate([t[None] for t in L_tuple], axis=0).T + _register_backend_implementation(NumpyBackend) @@ -1653,7 +1700,7 @@ def dot(self, a, b): return jnp.dot(a, b) def abs(self, a): - return jnp.abs(a) + return norm_1d_jax(a) def exp(self, a): return jnp.exp(a) @@ -1703,7 +1750,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return jnp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return jnp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -1898,10 +1945,10 @@ def sqrtm(self, a): def eigh(self, a): return jnp.linalg.eigh(a) - def kl_div(self, p, q, mass=False, eps=1e-16): - value = jnp.sum(p * jnp.log(p / q + eps)) + def kl_div(self, p, q, mass=False, eps=1e-16, axis=None): + value = jnp.sum(p * jnp.log(p / q + eps), axis=axis) if mass: - value = value + jnp.sum(q - p) + value = value + jnp.sum(q - p, axis=axis) return value def isfinite(self, 
a): @@ -1946,6 +1993,16 @@ def det(self, x): def slogdet(self, a): return jnp.linalg.slogdet(a) + def index_select(self, input, axis, index): + return jnp.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return jnp.nonzero(input) + else: + L_tuple = jnp.nonzero(input) + return jnp.concatenate([t[None] for t in L_tuple], axis=0).T + if jax: # Only register jax backend if it is installed @@ -2200,7 +2257,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return torch.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return torch.clamp(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -2489,10 +2546,10 @@ def sqrtm(self, a): def eigh(self, a): return torch.linalg.eigh(a) - def kl_div(self, p, q, mass=False, eps=1e-16): - value = torch.sum(p * torch.log(p / q + eps)) + def kl_div(self, p, q, mass=False, eps=1e-16, axis=None): + value = torch.sum(p * torch.log(p / q + eps), axis=axis) if mass: - value = value + torch.sum(q - p) + value = value + torch.sum(q - p, axis=axis) return value def isfinite(self, a): @@ -2540,6 +2597,12 @@ def det(self, x): def slogdet(self, a): return torch.linalg.slogdet(a) + def index_select(self, input, axis, index): + return torch.index_select(input, axis, index) + + def nonzero(self, input, as_tuple=False): + return torch.nonzero(input, as_tuple=as_tuple) + if torch: # Only register torch backend if it is installed @@ -2701,7 +2764,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return cp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return cp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -2915,10 +2978,10 @@ def sqrtm(self, a): def eigh(self, a): return cp.linalg.eigh(a) - def kl_div(self, p, q, mass=False, eps=1e-16): - value = cp.sum(p * cp.log(p / q + eps)) + def kl_div(self, p, q, mass=False, eps=1e-16, axis=None): + value = cp.sum(p * cp.log(p / q + eps), axis=axis) if mass: - value = value + cp.sum(q - p) + value = value + cp.sum(q - p, axis=axis) return value def isfinite(self, a): @@ -2963,6 +3026,16 @@ def det(self, x): def slogdet(self, a): return cp.linalg.slogdet(a) + def index_select(self, input, axis, index): + return cp.take(input, index, axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return cp.nonzero(input) + else: + L_tuple = cp.nonzero(input) + return cp.concatenate([t[None] for t in L_tuple], axis=0).T + if cp: # Only register cp backend if it is installed @@ -3135,7 +3208,7 @@ def flip(self, a, axis=None): def outer(self, a, b): return tnp.outer(a, b) - def clip(self, a, a_min, a_max): + def clip(self, a, a_min=None, a_max=None): return tnp.clip(a, a_min, a_max) def repeat(self, a, repeats, axis=None): @@ -3372,10 +3445,10 @@ def sqrtm(self, a): def eigh(self, a): return tf.linalg.eigh(a) - def kl_div(self, p, q, mass=False, eps=1e-16): - value = tnp.sum(p * tnp.log(p / q + eps)) + def kl_div(self, p, q, mass=False, eps=1e-16, axis=None): + value = tnp.sum(p * tnp.log(p / q + eps), axis=axis) if mass: - value = value + tnp.sum(q - p) + value = value + tnp.sum(q - p, axis=axis) return value def isfinite(self, a): @@ -3423,6 +3496,16 @@ def det(self, x): def slogdet(self, a): return tf.linalg.slogdet(a) + def index_select(self, input, axis, index): + return tf.gather(input, index, axis=axis) + + def nonzero(self, input, as_tuple=False): + if as_tuple: + return tf.where(input) + else: + indices = tf.where(input) + return tf.reshape(indices, (-1, 
indices.shape[-1]))
+
 
 if tf:
     # Only register tensorflow backend if it is installed
diff --git a/ot/lp/__init__.py b/ot/lp/__init__.py
index 8e88d63c8..c9fa676c4 100644
--- a/ot/lp/__init__.py
+++ b/ot/lp/__init__.py
@@ -26,6 +26,10 @@
     emd_1d,
     emd2_1d,
     wasserstein_1d,
+    emd_1d_dual_backprop,
+)
+
+from .solver_circle import (
     binary_search_circle,
     wasserstein_circle,
     semidiscrete_wasserstein2_unif_circle,
@@ -43,6 +47,7 @@
     "emd_1d",
     "emd2_1d",
     "wasserstein_1d",
+    "emd_1d_dual_backprop",
     "generalized_free_support_barycenter",
     "binary_search_circle",
     "wasserstein_circle",
diff --git a/ot/lp/_network_simplex.py b/ot/lp/_network_simplex.py
index ec06298bc..1f376b707 100644
--- a/ot/lp/_network_simplex.py
+++ b/ot/lp/_network_simplex.py
@@ -44,31 +44,46 @@ def center_ot_dual(alpha0, beta0, a=None, b=None):
 
     Parameters
     ----------
-    alpha0 : (ns,) numpy.ndarray, float64
+    alpha0 : (ns, ...) numpy.ndarray, float64
         Source dual potential
-    beta0 : (nt,) numpy.ndarray, float64
+    beta0 : (nt, ...) numpy.ndarray, float64
         Target dual potential
-    a : (ns,) numpy.ndarray, float64
+    a : (ns, ...) numpy.ndarray, float64
         Source histogram (uniform weight if empty list)
-    b : (nt,) numpy.ndarray, float64
+    b : (nt, ...) numpy.ndarray, float64
         Target histogram (uniform weight if empty list)
 
     Returns
    -------
-    alpha : (ns,) numpy.ndarray, float64
+    alpha : (ns, ...) numpy.ndarray, float64
         Source centered dual potential
-    beta : (nt,) numpy.ndarray, float64
+    beta : (nt, ...) numpy.ndarray, float64
         Target centered dual potential
 
     """
+    if a is not None and b is not None:
+        nx = get_backend(alpha0, beta0, a, b)
+    else:
+        nx = get_backend(alpha0, beta0)
+
+    n = alpha0.shape[0]
+    m = beta0.shape[0]
+
     # if no weights are provided, use uniform
     if a is None:
-        a = np.ones(alpha0.shape[0]) / alpha0.shape[0]
+        a = nx.full(alpha0.shape, 1.0 / n, type_as=alpha0)
+    elif a.ndim != alpha0.ndim:
+        a = nx.repeat(a[..., None], alpha0.shape[-1], -1)
+
     if b is None:
-        b = np.ones(beta0.shape[0]) / beta0.shape[0]
+        b = nx.full(beta0.shape, 1.0 / m, type_as=beta0)
+    elif b.ndim != beta0.ndim:
+        b = nx.repeat(b[..., None], beta0.shape[-1], -1)
 
     # compute constant that balances the weighted sums of the duals
-    c = (b.dot(beta0) - a.dot(alpha0)) / (a.sum() + b.sum())
+    ips = nx.sum(b * beta0, axis=0) - nx.sum(a * alpha0, axis=0)
+    denom = nx.sum(a, axis=0) + nx.sum(b, axis=0)
+    c = ips / denom
 
     # update duals
     alpha = alpha0 + c
diff --git a/ot/lp/solver_1d.py b/ot/lp/solver_1d.py
index 49e0c9c41..96ae35776 100644
--- a/ot/lp/solver_1d.py
+++ b/ot/lp/solver_1d.py
@@ -5,6 +5,7 @@
 
 # Author: Remi Flamary
 # Author: Nicolas Courty
+# Author: Clément Bonet
 #
 # License: MIT License
 
@@ -14,9 +15,10 @@
 from .emd_wrap import emd_1d_sorted
 from ..backend import get_backend
 from ..utils import list_to_array
+from ._network_simplex import center_ot_dual
 
 
-def quantile_function(qs, cws, xs):
+def quantile_function(qs, cws, xs, return_index=False):
     r"""Computes the quantile function of an empirical distribution
 
     Parameters
@@ -27,6 +29,8 @@
         cumulative weights of the 1D empirical distribution, if batched, must be similar to xs
     xs: array-like, shape (n, ...)
        locations of the 1D empirical distribution, batched against the `xs.ndim - 1` first dimensions
+    return_index: bool, optional
+        if True, also return the indices of the computed quantiles in `xs` (default is False)
 
     Returns
     -------
@@ -43,8 +46,11 @@
     else:
         cws = cws.T
         qs = qs.T
-    idx = nx.searchsorted(cws, qs).T
-    return nx.take_along_axis(xs, nx.clip(idx, 0, n - 1), axis=0)
+    idx = nx.clip(nx.searchsorted(cws, qs).T, 0, n - 1)
+    if return_index:
+        return nx.take_along_axis(xs, idx, axis=0), idx
+    else:
+        return nx.take_along_axis(xs, idx, axis=0)
 
 
 def wasserstein_1d(
@@ -399,848 +408,109 @@ def emd2_1d(
     return cost
 
 
-def roll_cols(M, shifts):
-    r"""
-    Utils functions which allow to shift the order of each row of a 2d matrix
-
-    Parameters
-    ----------
-    M : ndarray, shape (nr, nc)
-        Matrix to shift
-    shifts: int or ndarray, shape (nr,)
-
-    Returns
-    -------
-    Shifted array
-
-    Examples
-    --------
-    >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
-    >>> roll_cols(M, 2)
-    array([[2, 3, 1],
-           [5, 6, 4],
-           [8, 9, 7]])
-    >>> roll_cols(M, np.array([[1],[2],[1]]))
-    array([[3, 1, 2],
-           [5, 6, 4],
-           [9, 7, 8]])
-
-    References
-    ----------
-    https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
-    """
-    nx = get_backend(M)
-
-    n_rows, n_cols = M.shape
-
-    arange1 = nx.tile(
-        nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1)
-    )
-    arange2 = (arange1 - shifts) % n_cols
-
-    return nx.take_along_axis(M, arange2, 1)
-
-
-def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
-    r"""Computes the left and right derivative of the cost (Equation (6.3) and (6.4) of [1])
-
-    Parameters
-    ----------
-    theta: array-like, shape (n_batch, n)
-        Cuts on the circle
-    u_values: array-like, shape (n_batch, n)
-        locations of the first empirical distribution
-    v_values: array-like, shape (n_batch, n)
-        locations of the second empirical distribution
-    u_cdf: array-like, shape (n_batch, n)
-        cdf of the first empirical distribution
-    v_cdf: array-like, shape (n_batch, n)
-        cdf of the second empirical distribution
-    p: float, optional = 2
-        Power p used for computing the Wasserstein distance
-
-    Returns
-    -------
-    dCp: array-like, shape (n_batch, 1)
-        The batched right derivative
-    dCm: array-like, shape (n_batch, 1)
-        The batched left derivative
-
-    References
-    ---------
-    .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258.
- """ - nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) - - v_values = nx.copy(v_values) - - n = u_values.shape[-1] - m_batch, m = v_values.shape - - v_cdf_theta = v_cdf - (theta - nx.floor(theta)) - - mask_p = v_cdf_theta >= 0 - mask_n = v_cdf_theta < 0 - - v_values[mask_n] += nx.floor(theta)[mask_n] + 1 - v_values[mask_p] += nx.floor(theta)[mask_p] - - if nx.any(mask_n) and nx.any(mask_p): - v_cdf_theta[mask_n] += 1 - - v_cdf_theta2 = nx.copy(v_cdf_theta) - v_cdf_theta2[mask_n] = np.inf - shift = -nx.argmin(v_cdf_theta2, axis=-1) - - v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) - v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) - v_values = nx.concatenate( - [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 - ) - - if nx.__name__ == "torch": - # this is to ensure the best performance for torch searchsorted - # and avoid a warning related to non-contiguous arrays - u_cdf = u_cdf.contiguous() - v_cdf_theta = v_cdf_theta.contiguous() - - # quantiles of F_u evaluated in F_v^\theta - u_index = nx.searchsorted(u_cdf, v_cdf_theta) - u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1) - - # Deal with 1 - u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) - u_valuesm = nx.concatenate( - [u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1 - ) - - if nx.__name__ == "torch": - # this is to ensure the best performance for torch searchsorted - # and avoid a warning related to non-contiguous arrays - u_cdfm = u_cdfm.contiguous() - v_cdf_theta = v_cdf_theta.contiguous() - - u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") - u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) - - dCp = nx.sum( - nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) - - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), - axis=-1, - ) - - dCm = nx.sum( - nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) - - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), - axis=-1, - ) - - return dCp.reshape(-1, 1), dCm.reshape(-1, 1) - - -def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): - r"""Computes the the cost (Equation (6.2) of [1]) - - Parameters - ---------- - theta: array-like, shape (n_batch, n) - Cuts on the circle - u_values: array-like, shape (n_batch, n) - locations of the first empirical distribution - v_values: array-like, shape (n_batch, n) - locations of the second empirical distribution - u_cdf: array-like, shape (n_batch, n) - cdf of the first empirical distribution - v_cdf: array-like, shape (n_batch, n) - cdf of the second empirical distribution - p: float, optional = 2 - Power p used for computing the Wasserstein distance - - Returns - ------- - ot_cost: array-like, shape (n_batch,) - OT cost evaluated at theta - - References - --------- - .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. 
- """ - nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) - - v_values = nx.copy(v_values) - - m_batch, m = v_values.shape - n_batch, n = u_values.shape - - v_cdf_theta = v_cdf - (theta - nx.floor(theta)) - - mask_p = v_cdf_theta >= 0 - mask_n = v_cdf_theta < 0 - - v_values[mask_n] += nx.floor(theta)[mask_n] + 1 - v_values[mask_p] += nx.floor(theta)[mask_p] - - if nx.any(mask_n) and nx.any(mask_p): - v_cdf_theta[mask_n] += 1 - - # Put negative values at the end - v_cdf_theta2 = nx.copy(v_cdf_theta) - v_cdf_theta2[mask_n] = np.inf - shift = -nx.argmin(v_cdf_theta2, axis=-1) - - v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) - v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) - v_values = nx.concatenate( - [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 - ) - - # Compute absciss - cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) - cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)]) - - delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] - - if nx.__name__ == "torch": - # this is to ensure the best performance for torch searchsorted - # and avoid a warning related to non-contiguous arrays - u_cdf = u_cdf.contiguous() - v_cdf_theta = v_cdf_theta.contiguous() - cdf_axis = cdf_axis.contiguous() - - # Compute icdf - u_index = nx.searchsorted(u_cdf, cdf_axis) - u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) - - v_values = nx.concatenate( - [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 - ) - v_index = nx.searchsorted(v_cdf_theta, cdf_axis) - v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) - - if p == 1: - ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1) - else: - ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1) - - return ot_cost - - -def binary_search_circle( - u_values, - v_values, - u_weights=None, - v_weights=None, - p=1, - Lm=10, - Lp=10, - tm=-1, - tp=1, - eps=1e-6, - require_sort=True, - log=False, +def emd_1d_dual_backprop( + u_values, v_values, u_weights=None, v_weights=None, p=1, require_sort=True ): - r"""Computes the Wasserstein distance on the circle using the Binary search algorithm proposed in [44]. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates - using e.g. the atan2 function. + r""" + Computes the 1 dimensional OT loss between two (batched) empirical + distributions .. math:: - W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q - - where: - - - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` + OT_{loss} = \int_0^1 |cdf_u^{-1}(q) - cdf_v^{-1}(q)|^p dq - For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + and returns the dual potentials and the loss, i.e. such that .. math:: - u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} - - using e.g. ot.utils.get_coordinate_circle(x) + OT_{loss}(u,v) = \int f(x)\mathrm{d}u(x) + \int g(y)\mathrm{d}v(y). - The function runs on backend but tensorflow and jax are not supported. + We do so by backpropagating through the `wasserstein_1d` function. Thus, the function + only works in torch and jax. Parameters ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...) 
- samples in the target domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - p : float, optional (default=1) - Power p used for computing the Wasserstein distance - Lm : int, optional - Lower bound dC - Lp : int, optional - Upper bound dC - tm: float, optional - Lower bound theta - tp: float, optional - Upper bound theta - eps: float, optional - Stopping condition + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1, default is 1 require_sort: bool, optional - If True, sort the values. - log: bool, optional - If True, returns also the optimal theta + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True Returns ------- + f: array-like shape (n, ...) + First dual potential + g: array-like shape (m, ...) + Second dual potential loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation - log: dict, optional - log dictionary returned only if log==True in parameters - - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> binary_search_circle(u.T, v.T, p=1) - array([0.1]) - - References - ---------- - .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. - .. 
Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + the batched EMD """ - assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) - if u_weights is not None and v_weights is not None: nx = get_backend(u_values, v_values, u_weights, v_weights) else: nx = get_backend(u_values, v_values) - n = u_values.shape[0] - m = v_values.shape[0] - - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - if len(v_values.shape) == 1: - v_values = nx.reshape(v_values, (m, 1)) - - if u_values.shape[1] != v_values.shape[1]: - raise ValueError( - "u and v must have the same number of batches {} and {} respectively given".format( - u_values.shape[1], v_values.shape[1] - ) - ) - - u_values = u_values % 1 - v_values = v_values % 1 - - if u_weights is None: - u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) - elif u_weights.ndim != u_values.ndim: - u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) - if v_weights is None: - v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) - elif v_weights.ndim != v_values.ndim: - v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) - - if require_sort: - u_sorter = nx.argsort(u_values, 0) - u_values = nx.take_along_axis(u_values, u_sorter, 0) - - v_sorter = nx.argsort(v_values, 0) - v_values = nx.take_along_axis(v_values, v_sorter, 0) - - u_weights = nx.take_along_axis(u_weights, u_sorter, 0) - v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - - u_cdf = nx.cumsum(u_weights, 0).T - v_cdf = nx.cumsum(v_weights, 0).T - - u_values = u_values.T - v_values = v_values.T - - L = max(Lm, Lp) - - tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) - tm = nx.tile(tm, (1, m)) - tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) - tp = nx.tile(tp, (1, m)) - tc = (tm + tp) / 2 - - done = nx.zeros((u_values.shape[0], m)) - - cpt = 0 - while nx.any(1 - done): - cpt += 1 - - dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) - done = ((dCp * dCm) <= 0) * 1 - - mask = ((tp - tm) < eps / L) * (1 - done) - - if nx.any(mask): - # can probably be improved by computing only relevant values - dCptp, dCmtp = derivative_cost_on_circle( - tp, u_values, v_values, u_cdf, v_cdf, p - ) - dCptm, dCmtm = derivative_cost_on_circle( - tm, u_values, v_values, u_cdf, v_cdf, p - ) - Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape( - -1, 1 - ) - Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape( - -1, 1 - ) - - # Avoid warning raised when dCptm - dCmtp == 0, for which - # tc is not updated as mask_end is False, - # see Issue #738 - with warnings.catch_warnings(): - warnings.simplefilter("ignore", category=RuntimeWarning) - mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) - tc[mask_end > 0] = ( - (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) - )[mask_end > 0] - done[nx.prod(mask, axis=-1) > 0] = 1 - elif nx.any(1 - done): - tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] - tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] - tc[((1 - mask) * (1 - done)) > 0] = ( - tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0] - ) / 2 - - w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p) - - if log: - return w, {"optimal_theta": tc[:, 0]} - return w - - -def wasserstein1_circle( - u_values, v_values, u_weights=None, v_weights=None, require_sort=True -): - r"""Computes the 1-Wasserstein 
distance on the circle using the level median [45]. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates - using e.g. the atan2 function. - The function runs on backend but tensorflow and jax are not supported. - - .. math:: - W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t - - Parameters - ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...) - samples in the target domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - require_sort: bool, optional - If True, sort the values. - - Returns - ------- - loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation - - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> wasserstein1_circle(u.T, v.T) - array([0.1]) - - References - ---------- - .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. - .. Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ - """ - - if u_weights is not None and v_weights is not None: - nx = get_backend(u_values, v_values, u_weights, v_weights) - else: - nx = get_backend(u_values, v_values) + assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax" n = u_values.shape[0] m = v_values.shape[0] - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - if len(v_values.shape) == 1: - v_values = nx.reshape(v_values, (m, 1)) - - if u_values.shape[1] != v_values.shape[1]: - raise ValueError( - "u and v must have the same number of batchs {} and {} respectively given".format( - u_values.shape[1], v_values.shape[1] - ) - ) - - u_values = u_values % 1 - v_values = v_values % 1 - + # Init weights or broadcast if necessary if u_weights is None: u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) elif u_weights.ndim != u_values.ndim: u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) elif v_weights.ndim != v_values.ndim: v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) - if require_sort: - u_sorter = nx.argsort(u_values, 0) - u_values = nx.take_along_axis(u_values, u_sorter, 0) - - v_sorter = nx.argsort(v_values, 0) - v_values = nx.take_along_axis(v_values, v_sorter, 0) - - u_weights = nx.take_along_axis(u_weights, u_sorter, 0) - v_weights = nx.take_along_axis(v_weights, v_sorter, 0) - - # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ - values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) - - cdf_diff = nx.cumsum( - nx.take_along_axis( - nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0 - ), - 0, - ) - cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) - - values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) - delta = values_sorted[1:, ...] - values_sorted[:-1, ...] 
- weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) - - sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5 - sum_weights[sum_weights < 0] = np.inf - inds = nx.argmin(sum_weights, axis=0) - - levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0) - - return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) - - -def wasserstein_circle( - u_values, - v_values, - u_weights=None, - v_weights=None, - p=1, - Lm=10, - Lp=10, - tm=-1, - tp=1, - eps=1e-6, - require_sort=True, -): - r"""Computes the Wasserstein distance on the circle using either :ref:`[45] ` for p=1 or - the binary search algorithm proposed in :ref:`[44] ` otherwise. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates - using e.g. the atan2 function. - - General loss returned: - - .. math:: - OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q - - For p=1, [45] - - .. math:: - W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t - - For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with - - .. math:: - u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} - - using e.g. ot.utils.get_coordinate_circle(x) - - The function runs on backend but tensorflow and jax are not supported. - - Parameters - ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...) - samples in the target domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - p : float, optional (default=1) - Power p used for computing the Wasserstein distance - Lm : int, optional - Lower bound dC. For p>1. - Lp : int, optional - Upper bound dC. For p>1. - tm: float, optional - Lower bound theta. For p>1. - tp: float, optional - Upper bound theta. For p>1. - eps: float, optional - Stopping condition. For p>1. - require_sort: bool, optional - If True, sort the values. - - Returns - ------- - loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation - - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> wasserstein_circle(u.T, v.T) - array([0.1]) - - - .. _references-wasserstein-circle: - References - ---------- - .. [44] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. - .. [45] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. - """ - assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) - - return binary_search_circle( - u_values, - v_values, - u_weights, - v_weights, - p=p, - Lm=Lm, - Lp=Lp, - tm=tm, - tp=tp, - eps=eps, - require_sort=require_sort, - ) - - -def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): - r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` - Samples need to be in :math:`S^1\cong [0,1[`. 
If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates - using e.g. the atan2 function. - - .. math:: - W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} - - where: - - - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` - - For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with - - .. math:: - u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, - - using e.g. ot.utils.get_coordinate_circle(x). - - Parameters - ---------- - u_values : ndarray, shape (n, ...) - Samples - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - - Returns - ------- - loss: float/array-like, shape (...) - Batched cost associated to the optimal transportation - - Examples - -------- - >>> x0 = np.array([[0], [0.2], [0.4]]) - >>> semidiscrete_wasserstein2_unif_circle(x0) - array([0.02111111]) - - References - ---------- - .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations. - """ - - if u_weights is not None: - nx = get_backend(u_values, u_weights) - else: - nx = get_backend(u_values) - - n = u_values.shape[0] - - u_values = u_values % 1 - - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - - if u_weights is None: - u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) - elif u_weights.ndim != u_values.ndim: - u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) - - u_values = nx.sort(u_values, 0) - u_cdf = nx.cumsum(u_weights, 0) - u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) - - cpt1 = nx.sum(u_weights * u_values**2, axis=0) - u_mean = nx.sum(u_weights * u_values, axis=0) - - ns = 1 - u_weights - 2 * u_cdf[:-1] - cpt2 = nx.sum(u_values * u_weights * ns, axis=0) - - return cpt1 - u_mean**2 + cpt2 + 1 / 12 - - -def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True): - r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference - :math:`\eta=\mathrm{Unif}(S^1)` evaluated in :math:`x`. - - For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[78] `) - - .. math`` - \hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12) - x. - - Parameters - ---------- - x : ndary, shape (m,) - Points in [0,1[ where to evaluate the embedding - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - - Returns - ------- - embedding: ndarray of shape (m, ...) - Embedding evaluated at :math:`x` - - .. _references-lcot: - References - ---------- - .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. 
- """ - if u_weights is not None: - nx = get_backend(u_values, u_weights) - else: - nx = get_backend(u_values) - - n = u_values.shape[0] - u_values = u_values % 1 - - if len(u_values.shape) == 1: - u_values = nx.reshape(u_values, (n, 1)) - - if u_weights is None: - u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) - elif u_weights.ndim != u_values.ndim: - u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) - - if require_sort: - u_sorter = nx.argsort(u_values, 0) - u_values = nx.take_along_axis(u_values, u_sorter, 0) - u_weights = nx.take_along_axis(u_weights, u_sorter, 0) - - u_cdf = nx.cumsum(u_weights, 0) - u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) - - q_s = ( - x[:, None] - nx.sum(u_values * u_weights, axis=0)[None] + 0.5 - ) # shape (m, ...) - - u_quantiles = quantile_function(q_s % 1, u_cdf, u_values) - - return (u_quantiles - x[:, None]) % 1 - - -def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None): - r"""Computes the Linear Circular Optimal Transport distance from :ref:`[78] ` using :math:`\eta=\mathrm{Unif}(S^1)` - as reference measure. - Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, - takes the value modulo 1. - If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates - using e.g. the atan2 function. - - General loss returned: - - .. math:: - \mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t - - where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`, - and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`. - - Parameters - ---------- - u_values : ndarray, shape (n, ...) - samples in the source domain (coordinates on [0,1[) - v_values : ndarray, shape (n, ...), optional - samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution - u_weights : ndarray, shape (n, ...), optional - samples weights in the source domain - v_weights : ndarray, shape (n, ...), optional - samples weights in the target domain - - Returns - ------- - loss: float/array-like, shape (...) - Batched cost associated to the linear optimal transportation - - Examples - -------- - >>> u = np.array([[0.2,0.5,0.8]])%1 - >>> v = np.array([[0.4,0.5,0.7]])%1 - >>> linear_circular_ot(u.T, v.T) - array([0.0127]) - - - .. _references-lcot: - References - ---------- - .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. 
-    """
-    if u_weights is not None:
-        nx = get_backend(u_values, u_weights)
-    else:
-        nx = get_backend(u_values)
-
-    n = u_values.shape[0]
-    u_values = u_values % 1
-
-    if len(u_values.shape) == 1:
-        u_values = nx.reshape(u_values, (n, 1))
-
-    if u_weights is None:
-        u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values)
-    elif u_weights.ndim != u_values.ndim:
-        u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1)
-
-    unif_s1 = nx.linspace(0, 1, 101, type_as=u_values)[:-1]
+    if nx.__name__ == "torch":
+        u_weights_diff = nx.copy(u_weights)
+        v_weights_diff = nx.copy(v_weights)
+
+        u_weights_diff.requires_grad_(True)
+        v_weights_diff.requires_grad_(True)
+
+        cost_output = wasserstein_1d(
+            u_values,
+            v_values,
+            u_weights_diff,
+            v_weights_diff,
+            p=p,
+            require_sort=require_sort,
+        )
+        loss = cost_output.sum()
+        loss.backward()
+
+        f, g = center_ot_dual(
+            u_weights_diff.grad.detach(),
+            v_weights_diff.grad.detach(),
+            u_weights,
+            v_weights,
+        )
 
-    emb_u = linear_circular_embedding(unif_s1, u_values, u_weights)
+        return f, g, cost_output.detach()  # the value can no longer be backpropagated
+    elif nx.__name__ == "jax":
+        import jax
 
-    if v_values is None:
-        dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u))
-        return nx.mean(dist_u**2, axis=0)
-    else:
-        m = v_values.shape[0]
-        if len(v_values.shape) == 1:
-            v_values = nx.reshape(v_values, (m, 1))
+        def ot_1d(a, b):
+            return wasserstein_1d(
+                u_values, v_values, a, b, p=p, require_sort=require_sort
+            ).sum()
 
-        if u_values.shape[1] != v_values.shape[1]:
-            raise ValueError(
-                "u and v must have the same number of batchs {} and {} respectively given".format(
-                    u_values.shape[1], v_values.shape[1]
-                )
-            )
+        f, g = jax.grad(ot_1d, argnums=[0, 1])(u_weights, v_weights)
 
-        emb_v = linear_circular_embedding(unif_s1, v_values, v_weights)
+        cost_output = wasserstein_1d(
+            u_values, v_values, u_weights, v_weights, p=p, require_sort=require_sort
+        )
 
-        dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v))
-        return nx.mean(dist_uv**2, axis=0)
+        f, g = center_ot_dual(f, g, u_weights, v_weights)
+        return f, g, cost_output
diff --git a/ot/lp/solver_circle.py b/ot/lp/solver_circle.py
new file mode 100644
index 000000000..8fcdef49e
--- /dev/null
+++ b/ot/lp/solver_circle.py
@@ -0,0 +1,861 @@
+# -*- coding: utf-8 -*-
+"""
+Exact solvers for the Wasserstein distance on the circle
+"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import numpy as np
+import warnings
+
+from ..backend import get_backend
+from .solver_1d import quantile_function
+
+
+def roll_cols(M, shifts):
+    r"""
+    Utility function to shift the order of each row of a 2d matrix
+
+    Parameters
+    ----------
+    M : ndarray, shape (nr, nc)
+        Matrix to shift
+    shifts: int or ndarray, shape (nr,)
+
+    Returns
+    -------
+    Shifted array
+
+    Examples
+    --------
+    >>> M = np.array([[1,2,3],[4,5,6],[7,8,9]])
+    >>> roll_cols(M, 2)
+    array([[2, 3, 1],
+           [5, 6, 4],
+           [8, 9, 7]])
+    >>> roll_cols(M, np.array([[1],[2],[1]]))
+    array([[3, 1, 2],
+           [5, 6, 4],
+           [9, 7, 8]])
+
+    References
+    ----------
+    https://stackoverflow.com/questions/66596699/how-to-shift-columns-or-rows-in-a-tensor-with-different-offsets-in-pytorch
+    """
+    nx = get_backend(M)
+
+    n_rows, n_cols = M.shape
+
+    arange1 = nx.tile(
+        nx.reshape(nx.arange(n_cols, type_as=shifts), (1, n_cols)), (n_rows, 1)
+    )
+    arange2 = (arange1 - shifts) % n_cols
+
+    return nx.take_along_axis(M, arange2, 1)
+
+
+def derivative_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p=2):
and right derivative of the cost (Equation (6.3) and (6.4) of [44]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + dCp: array-like, shape (n_batch, 1) + The batched right derivative + dCm: array-like, shape (n_batch, 1) + The batched left derivative + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + n = u_values.shape[-1] + m_batch, m = v_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = -nx.argmin(v_cdf_theta2, axis=-1) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate( + [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 + ) + + if nx.__name__ == "torch": + # this is to ensure the best performance for torch searchsorted + # and avoid a warning related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + # quantiles of F_u evaluated in F_v^\theta + u_index = nx.searchsorted(u_cdf, v_cdf_theta) + u_icdf_theta = nx.take_along_axis(u_values, nx.clip(u_index, 0, n - 1), -1) + + # Deal with 1 + u_cdfm = nx.concatenate([u_cdf, nx.reshape(u_cdf[:, 0], (-1, 1)) + 1], axis=1) + u_valuesm = nx.concatenate( + [u_values, nx.reshape(u_values[:, 0], (-1, 1)) + 1], axis=1 + ) + + if nx.__name__ == "torch": + # this is to ensure the best performance for torch searchsorted + # and avoid a warning related to non-contiguous arrays + u_cdfm = u_cdfm.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + + u_indexm = nx.searchsorted(u_cdfm, v_cdf_theta, side="right") + u_icdfm_theta = nx.take_along_axis(u_valuesm, nx.clip(u_indexm, 0, n), -1) + + dCp = nx.sum( + nx.power(nx.abs(u_icdf_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdf_theta - v_values[:, :-1]), p), + axis=-1, + ) + + dCm = nx.sum( + nx.power(nx.abs(u_icdfm_theta - v_values[:, 1:]), p) + - nx.power(nx.abs(u_icdfm_theta - v_values[:, :-1]), p), + axis=-1, + ) + + return dCp.reshape(-1, 1), dCm.reshape(-1, 1) + + +def ot_cost_on_circle(theta, u_values, v_values, u_cdf, v_cdf, p): + r"""Computes the cost (Equation (6.2) of [44]) + + Parameters + ---------- + theta: array-like, shape (n_batch, n) + Cuts on the circle + u_values: array-like, shape (n_batch, n) + locations of the first empirical distribution + v_values: array-like, shape (n_batch, n) + locations of the second empirical distribution + u_cdf: array-like, shape (n_batch, n) + cdf of the first empirical distribution + v_cdf: array-like, shape (n_batch, n) + cdf
of the second empirical distribution + p: float, optional = 2 + Power p used for computing the Wasserstein distance + + Returns + ------- + ot_cost: array-like, shape (n_batch,) + OT cost evaluated at theta + + References + --------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + """ + nx = get_backend(theta, u_values, v_values, u_cdf, v_cdf) + + v_values = nx.copy(v_values) + + m_batch, m = v_values.shape + n_batch, n = u_values.shape + + v_cdf_theta = v_cdf - (theta - nx.floor(theta)) + + mask_p = v_cdf_theta >= 0 + mask_n = v_cdf_theta < 0 + + v_values[mask_n] += nx.floor(theta)[mask_n] + 1 + v_values[mask_p] += nx.floor(theta)[mask_p] + + if nx.any(mask_n) and nx.any(mask_p): + v_cdf_theta[mask_n] += 1 + + # Put negative values at the end + v_cdf_theta2 = nx.copy(v_cdf_theta) + v_cdf_theta2[mask_n] = np.inf + shift = -nx.argmin(v_cdf_theta2, axis=-1) + + v_cdf_theta = roll_cols(v_cdf_theta, nx.reshape(shift, (-1, 1))) + v_values = roll_cols(v_values, nx.reshape(shift, (-1, 1))) + v_values = nx.concatenate( + [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 + ) + + # Compute abscissa + cdf_axis = nx.sort(nx.concatenate((u_cdf, v_cdf_theta), -1), -1) + cdf_axis_pad = nx.zero_pad(cdf_axis, pad_width=[(0, 0), (1, 0)]) + + delta = cdf_axis_pad[..., 1:] - cdf_axis_pad[..., :-1] + + if nx.__name__ == "torch": + # this is to ensure the best performance for torch searchsorted + # and avoid a warning related to non-contiguous arrays + u_cdf = u_cdf.contiguous() + v_cdf_theta = v_cdf_theta.contiguous() + cdf_axis = cdf_axis.contiguous() + + # Compute icdf + u_index = nx.searchsorted(u_cdf, cdf_axis) + u_icdf = nx.take_along_axis(u_values, u_index.clip(0, n - 1), -1) + + v_values = nx.concatenate( + [v_values, nx.reshape(v_values[:, 0], (-1, 1)) + 1], axis=1 + ) + v_index = nx.searchsorted(v_cdf_theta, cdf_axis) + v_icdf = nx.take_along_axis(v_values, v_index.clip(0, m), -1) + + if p == 1: + ot_cost = nx.sum(delta * nx.abs(u_icdf - v_icdf), axis=-1) + else: + ot_cost = nx.sum(delta * nx.power(nx.abs(u_icdf - v_icdf), p), axis=-1) + + return ot_cost + + +def binary_search_circle( + u_values, + v_values, + u_weights=None, + v_weights=None, + p=1, + Lm=10, + Lp=10, + tm=-1, + tp=1, + eps=1e-6, + require_sort=True, + log=False, +): + r"""Computes the Wasserstein distance on the circle using the binary search algorithm proposed in [44]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_p^p(u,v) = \inf_{\theta\in\mathbb{R}}\int_0^1 |F_u^{-1}(q) - (F_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + where: + + - :math:`F_u` and :math:`F_v` are respectively the cdfs of :math:`u` and :math:`v` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow and jax are not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...)
+ samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC + Lp : int, optional + Upper bound dC + tm: float, optional + Lower bound theta + tp: float, optional + Upper bound theta + eps: float, optional + Stopping condition + require_sort: bool, optional + If True, sort the values. + log: bool, optional + If True, returns also the optimal theta + + Returns + ------- + loss: float/array-like, shape (...) + Batched cost associated to the optimal transportation + log: dict, optional + log dictionary returned only if log==True in parameters + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> binary_search_circle(u.T, v.T, p=1) + array([0.1]) + + References + ---------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + .. Matlab Code: https://users.mccme.ru/ansobol/otarie/software.html + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batches {} and {} respectively given".format( + u_values.shape[1], v_values.shape[1] + ) + ) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0).T + v_cdf = nx.cumsum(v_weights, 0).T + + u_values = u_values.T + v_values = v_values.T + + L = max(Lm, Lp) + + tm = tm * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tm = nx.tile(tm, (1, m)) + tp = tp * nx.reshape(nx.ones((u_values.shape[0],), type_as=u_values), (-1, 1)) + tp = nx.tile(tp, (1, m)) + tc = (tm + tp) / 2 + + done = nx.zeros((u_values.shape[0], m)) + + cpt = 0 + while nx.any(1 - done): + cpt += 1 + + dCp, dCm = derivative_cost_on_circle(tc, u_values, v_values, u_cdf, v_cdf, p) + done = ((dCp * dCm) <= 0) * 1 + + mask = ((tp - tm) < eps / L) * (1 - done) + + if nx.any(mask): + # can probably be improved by computing only relevant values + dCptp, dCmtp = derivative_cost_on_circle( + tp, u_values, v_values, u_cdf, v_cdf, p + ) + dCptm, dCmtm = 
derivative_cost_on_circle( + tm, u_values, v_values, u_cdf, v_cdf, p + ) + Ctm = ot_cost_on_circle(tm, u_values, v_values, u_cdf, v_cdf, p).reshape( + -1, 1 + ) + Ctp = ot_cost_on_circle(tp, u_values, v_values, u_cdf, v_cdf, p).reshape( + -1, 1 + ) + + # Avoid warning raised when dCptm - dCmtp == 0, for which + # tc is not updated as mask_end is False, + # see Issue #738 + with warnings.catch_warnings(): + warnings.simplefilter("ignore", category=RuntimeWarning) + mask_end = mask * (nx.abs(dCptm - dCmtp) > 0.001) + tc[mask_end > 0] = ( + (Ctp - Ctm + tm * dCptm - tp * dCmtp) / (dCptm - dCmtp) + )[mask_end > 0] + done[nx.prod(mask, axis=-1) > 0] = 1 + elif nx.any(1 - done): + tm[((1 - mask) * (dCp < 0)) > 0] = tc[((1 - mask) * (dCp < 0)) > 0] + tp[((1 - mask) * (dCp >= 0)) > 0] = tc[((1 - mask) * (dCp >= 0)) > 0] + tc[((1 - mask) * (1 - done)) > 0] = ( + tm[((1 - mask) * (1 - done)) > 0] + tp[((1 - mask) * (1 - done)) > 0] + ) / 2 + + w = ot_cost_on_circle(nx.detach(tc), u_values, v_values, u_cdf, v_cdf, p) + + if log: + return w, {"optimal_theta": tc[:, 0]} + return w + + +def wasserstein1_circle( + u_values, v_values, u_weights=None, v_weights=None, require_sort=True +): + r"""Computes the 1-Wasserstein distance on the circle using the level median [45]. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, first find the coordinates + using e.g. the atan2 function. + The function runs on backend but tensorflow and jax are not supported. + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) + samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float/array-like, shape (...) + Batched cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein1_circle(u.T, v.T) + array([0.1]) + + References + ---------- + .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + .. 
Code R: https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + """ + + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + n = u_values.shape[0] + m = v_values.shape[0] + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format( + u_values.shape[1], v_values.shape[1] + ) + ) + + u_values = u_values % 1 + v_values = v_values % 1 + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_values = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights = nx.take_along_axis(v_weights, v_sorter, 0) + + # Code inspired from https://gitlab.gwdg.de/shundri/circularOT/-/tree/master/ + values_sorted, values_sorter = nx.sort2(nx.concatenate((u_values, v_values), 0), 0) + + cdf_diff = nx.cumsum( + nx.take_along_axis( + nx.concatenate((u_weights, -v_weights), 0), values_sorter, 0 + ), + 0, + ) + cdf_diff_sorted, cdf_diff_sorter = nx.sort2(cdf_diff, axis=0) + + values_sorted = nx.zero_pad(values_sorted, pad_width=[(0, 1), (0, 0)], value=1) + delta = values_sorted[1:, ...] - values_sorted[:-1, ...] + weight_sorted = nx.take_along_axis(delta, cdf_diff_sorter, 0) + + sum_weights = nx.cumsum(weight_sorted, axis=0) - 0.5 + sum_weights[sum_weights < 0] = np.inf + inds = nx.argmin(sum_weights, axis=0) + + levMed = nx.take_along_axis(cdf_diff_sorted, nx.reshape(inds, (1, -1)), 0) + + return nx.sum(delta * nx.abs(cdf_diff - levMed), axis=0) + + +def wasserstein_circle( + u_values, + v_values, + u_weights=None, + v_weights=None, + p=1, + Lm=10, + Lp=10, + tm=-1, + tp=1, + eps=1e-6, + require_sort=True, +): + r"""Computes the Wasserstein distance on the circle using either :ref:`[45] ` for p=1 or + the binary search algorithm proposed in :ref:`[44] ` otherwise. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it requires to first find the coordinates + using e.g. the atan2 function. + + General loss returned: + + .. math:: + OT_{loss} = \inf_{\theta\in\mathbb{R}}\int_0^1 |cdf_u^{-1}(q) - (cdf_v-\theta)^{-1}(q)|^p\ \mathrm{d}q + + For p=1, [45] + + .. math:: + W_1(u,v) = \int_0^1 |F_u(t)-F_v(t)-LevMed(F_u-F_v)|\ \mathrm{d}t + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi} + + using e.g. ot.utils.get_coordinate_circle(x) + + The function runs on backend but tensorflow and jax are not supported. + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...) 
+ samples in the target domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + p : float, optional (default=1) + Power p used for computing the Wasserstein distance + Lm : int, optional + Lower bound dC. For p>1. + Lp : int, optional + Upper bound dC. For p>1. + tm: float, optional + Lower bound theta. For p>1. + tp: float, optional + Upper bound theta. For p>1. + eps: float, optional + Stopping condition. For p>1. + require_sort: bool, optional + If True, sort the values. + + Returns + ------- + loss: float/array-like, shape (...) + Batched cost associated to the optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> wasserstein_circle(u.T, v.T) + array([0.1]) + + + .. _references-wasserstein-circle: + References + ---------- + .. [44] Delon, Julie, Julien Salomon, and Andrei Sobolevski. "Fast transport optimization for Monge costs on the circle." SIAM Journal on Applied Mathematics 70.7 (2010): 2239-2258. + .. [45] Hundrieser, Shayan, Marcel Klatt, and Axel Munk. "The statistics of circular optimal transport." Directional Statistics for Innovative Applications: A Bicentennial Tribute to Florence Nightingale. Singapore: Springer Nature Singapore, 2022. 57-82. + """ + assert p >= 1, "The OT loss is only valid for p>=1, {p} was given".format(p=p) + + if p == 1: + return wasserstein1_circle( + u_values, v_values, u_weights, v_weights, require_sort=require_sort + ) + + return binary_search_circle( + u_values, + v_values, + u_weights, + v_weights, + p=p, + Lm=Lm, + Lp=Lp, + tm=tm, + tp=tp, + eps=eps, + require_sort=require_sort, + ) + + +def semidiscrete_wasserstein2_unif_circle(u_values, u_weights=None): + r"""Computes the closed-form for the 2-Wasserstein distance between samples and a uniform distribution on :math:`S^1` + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + .. math:: + W_2^2(\mu_n, \nu) = \sum_{i=1}^n \alpha_i x_i^2 - \left(\sum_{i=1}^n \alpha_i x_i\right)^2 + \sum_{i=1}^n \alpha_i x_i \left(1-\alpha_i-2\sum_{k=1}^{i-1}\alpha_k\right) + \frac{1}{12} + + where: + + - :math:`\nu=\mathrm{Unif}(S^1)` and :math:`\mu_n = \sum_{i=1}^n \alpha_i \delta_{x_i}` + + For values :math:`x=(x_1,x_2)\in S^1`, it is required to first get their coordinates with + + .. math:: + u = \frac{\pi + \mathrm{atan2}(-x_2,-x_1)}{2\pi}, + + using e.g. ot.utils.get_coordinate_circle(x). + + Parameters + ---------- + u_values : ndarray, shape (n, ...) + Samples + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + + Returns + ------- + loss: float/array-like, shape (...) + Batched cost associated to the optimal transportation + + Examples + -------- + >>> x0 = np.array([[0], [0.2], [0.4]]) + >>> semidiscrete_wasserstein2_unif_circle(x0) + array([0.02111111]) + + References + ---------- + .. [46] Bonet, C., Berg, P., Courty, N., Septier, F., Drumetz, L., & Pham, M. T. (2023). Spherical sliced-wasserstein. International Conference on Learning Representations.
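+
+    As a sanity check (an illustrative sketch; agreement holds only up to the
+    discretization of the uniform distribution), the closed form can be compared
+    with the generic circle solver evaluated against a fine uniform sample:
+
+    >>> u = np.array([0.1, 0.3, 0.7])
+    >>> v = np.linspace(0, 1, 1000, endpoint=False)
+    >>> w2_closed_form = semidiscrete_wasserstein2_unif_circle(u)
+    >>> w2_generic = wasserstein_circle(u, v, p=2)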
+ """ + + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + u_values = nx.sort(u_values, 0) + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + cpt1 = nx.sum(u_weights * u_values**2, axis=0) + u_mean = nx.sum(u_weights * u_values, axis=0) + + ns = 1 - u_weights - 2 * u_cdf[:-1] + cpt2 = nx.sum(u_values * u_weights * ns, axis=0) + + return cpt1 - u_mean**2 + cpt2 + 1 / 12 + + +def linear_circular_embedding(x, u_values, u_weights=None, require_sort=True): + r"""Returns the embedding :math:`\hat{\mu}(x)` of Linear Circular OT with reference + :math:`\eta=\mathrm{Unif}(S^1)` evaluated in :math:`x`. + + For any :math:`x\in [0,1[`, the embedding is given by (see :ref:`[78] `) + + .. math`` + \hat{\mu}(x) = F_{\mu}^{-1}\big(x - \int z\mathrm{d}\mu(z) + \frac12) - x. + + Parameters + ---------- + x : ndary, shape (m,) + Points in [0,1[ where to evaluate the embedding + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + + Returns + ------- + embedding: ndarray of shape (m, ...) + Embedding evaluated at :math:`x` + + .. _references-lcot: + References + ---------- + .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. + """ + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_values = nx.take_along_axis(u_values, u_sorter, 0) + u_weights = nx.take_along_axis(u_weights, u_sorter, 0) + + u_cdf = nx.cumsum(u_weights, 0) + u_cdf = nx.zero_pad(u_cdf, [(1, 0), (0, 0)]) + + q_s = ( + x[:, None] - nx.sum(u_values * u_weights, axis=0)[None] + 0.5 + ) # shape (m, ...) + + u_quantiles = quantile_function(q_s % 1, u_cdf, u_values) + + return (u_quantiles - x[:, None]) % 1 + + +def linear_circular_ot(u_values, v_values=None, u_weights=None, v_weights=None): + r"""Computes the Linear Circular Optimal Transport distance from :ref:`[78] ` using :math:`\eta=\mathrm{Unif}(S^1)` + as reference measure. + Samples need to be in :math:`S^1\cong [0,1[`. If they are on :math:`\mathbb{R}`, + takes the value modulo 1. + If the values are on :math:`S^1\subset\mathbb{R}^2`, it is required to first find the coordinates + using e.g. the atan2 function. + + General loss returned: + + .. math:: + \mathrm{LCOT}_2^2(\mu, \nu) = \int_0^1 d_{S^1}\big(\hat{\mu}(t), \hat{\nu}(t)\big)^2\ \mathrm{d}t + + where :math:`\hat{\mu}(x)=F_{\mu}^{-1}(x-\int z\mathrm{d}\mu(z)+\frac12) - x` for all :math:`x\in [0,1[`, + and :math:`d_{S^1}(x,y)=\min(|x-y|, 1-|x-y|)` for :math:`x,y\in [0,1[`. 
+ + Parameters + ---------- + u_values : ndarray, shape (n, ...) + samples in the source domain (coordinates on [0,1[) + v_values : ndarray, shape (n, ...), optional + samples in the target domain (coordinates on [0,1[), if None, compute distance against uniform distribution + u_weights : ndarray, shape (n, ...), optional + samples weights in the source domain + v_weights : ndarray, shape (n, ...), optional + samples weights in the target domain + + Returns + ------- + loss: float/array-like, shape (...) + Batched cost associated to the linear optimal transportation + + Examples + -------- + >>> u = np.array([[0.2,0.5,0.8]])%1 + >>> v = np.array([[0.4,0.5,0.7]])%1 + >>> linear_circular_ot(u.T, v.T) + array([0.0127]) + + + .. _references-lcot: + References + ---------- + .. [78] Martin, R. D., Medri, I., Bai, Y., Liu, X., Yan, K., Rohde, G. K., & Kolouri, S. (2024). LCOT: Linear Circular Optimal Transport. International Conference on Learning Representations. + """ + if u_weights is not None: + nx = get_backend(u_values, u_weights) + else: + nx = get_backend(u_values) + + n = u_values.shape[0] + u_values = u_values % 1 + + if len(u_values.shape) == 1: + u_values = nx.reshape(u_values, (n, 1)) + + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + unif_s1 = nx.linspace(0, 1, 101, type_as=u_values)[:-1] + + emb_u = linear_circular_embedding(unif_s1, u_values, u_weights) + + if v_values is None: + dist_u = nx.minimum(nx.abs(emb_u), 1 - nx.abs(emb_u)) + return nx.mean(dist_u**2, axis=0) + else: + m = v_values.shape[0] + if len(v_values.shape) == 1: + v_values = nx.reshape(v_values, (m, 1)) + + if u_values.shape[1] != v_values.shape[1]: + raise ValueError( + "u and v must have the same number of batchs {} and {} respectively given".format( + u_values.shape[1], v_values.shape[1] + ) + ) + + emb_v = linear_circular_embedding(unif_s1, v_values, v_weights) + + dist_uv = nx.minimum(nx.abs(emb_u - emb_v), 1 - nx.abs(emb_u - emb_v)) + return nx.mean(dist_uv**2, axis=0) diff --git a/ot/sliced.py b/ot/sliced.py index 81d0bd4a3..636432c2d 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -6,6 +6,7 @@ # Author: Adrien Corenflos # Nicolas Courty # Rémi Flamary +# Clément Bonet # # License: MIT License diff --git a/ot/unbalanced/__init__.py b/ot/unbalanced/__init__.py index 771452954..b7a526182 100644 --- a/ot/unbalanced/__init__.py +++ b/ot/unbalanced/__init__.py @@ -24,6 +24,10 @@ from ._lbfgs import lbfgsb_unbalanced, lbfgsb_unbalanced2 +from ._solver_1d import uot_1d + +from ._sliced import sliced_unbalanced_ot, unbalanced_sliced_ot + __all__ = [ "sinkhorn_knopp_unbalanced", "sinkhorn_unbalanced", @@ -38,4 +42,7 @@ "_get_loss_unbalanced", "lbfgsb_unbalanced", "lbfgsb_unbalanced2", + "uot_1d", + "sliced_unbalanced_ot", + "unbalanced_sliced_ot", ] diff --git a/ot/unbalanced/_sliced.py b/ot/unbalanced/_sliced.py new file mode 100644 index 000000000..f8ab86601 --- /dev/null +++ b/ot/unbalanced/_sliced.py @@ -0,0 +1,428 @@ +# -*- coding: utf-8 -*- +""" +Sliced Unbalanced OT solvers +""" + +# Author: Clément Bonet +# +# License: MIT License + +from ..backend import get_backend +from ..utils import get_parameter_pair, list_to_array +from ..sliced import get_random_projections +from ._solver_1d import rescale_potentials, uot_1d +from ..lp.solver_1d import emd_1d_dual_backprop, wasserstein_1d + + +def sliced_unbalanced_ot( + X_s, + X_t, + reg_m, + a=None, + b=None, + 
n_projections=50, + p=2, + projections=None, + seed=None, + numItermax=10, + log=False, +): + r""" + Compute the Sliced Unbalanced Optimal Transport (SUOT) between two empirical distributions. + The 1D UOT problem is computed with KL regularization and solved with a Frank-Wolfe algorithm in the dual, see :ref:`[82] `. + + The Sliced Unbalanced Optimal Transport (SUOT) is defined as + + .. math:: + \mathrm{SUOT}(\mu, \nu) = \int_{S^{d-1}} \mathrm{UOT}(P^\theta_\#\mu, P^\theta_\#\nu)\ \mathrm{d}\lambda(\theta) + + with :math:`P^\theta(x)=\langle x,\theta\rangle` and :math:`\lambda` the uniform distribution on the unit sphere. + + This function only works in pytorch or jax (but is not maintained in jax). + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. + For semi-relaxed case, use either + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b)`. + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional, by default =2 + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + numItermax: int, optional + Number of iterations of the Frank-Wolfe algorithm + log: bool, optional + if True, returns the projections used and their associated UOTs and reweighted marginals. + + Returns + ------- + loss: float/array-like, shape (...) + SUOT + log: dict, optional + If `log` is True, then returns a dictionary containing the projection directions used, the projected UOTs, and the reweighted marginals on each slice. + + + .. _references-uot: + References + ---------- + .. [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). + Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research.
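+
+    A minimal usage sketch (illustrative; torch is assumed, as the solver is
+    only implemented for the torch and jax backends):
+
+    >>> import torch
+    >>> rng = np.random.RandomState(0)
+    >>> X_s = torch.tensor(rng.randn(20, 2))
+    >>> X_t = torch.tensor(rng.randn(30, 2) + 1.0)
+    >>> suot = sliced_unbalanced_ot(X_s, X_t, reg_m=1.0, n_projections=16, seed=0)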
+ """ + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) + + assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax" + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) + + if a is None: + a = nx.full(n, 1 / n, type_as=X_s) + if b is None: + b = nx.full(m, 1 / m, type_as=X_s) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections( + d, n_projections, seed, backend=nx, type_as=X_s + ) + else: + n_projections = projections.shape[1] + + X_s_projections = nx.dot(X_s, projections) # shape (n, n_projs) + X_t_projections = nx.dot(X_t, projections) + + # Compute UOT on each slice + a_reweighted, b_reweighted, projected_uot = uot_1d( + X_s_projections, + X_t_projections, + reg_m, + a, + b, + p, + require_sort=True, + numItermax=numItermax, + returnCost="total", + ) + + res = nx.mean(projected_uot) + + if log: + dico = { + "projections": projections, + "projected_uots": projected_uot, + "a_reweighted": a_reweighted, + "b_reweighted": b_reweighted, + } + return res, dico + + return res + + +def get_reweighted_marginals_usot( + f, g, a, b, reg_m1, reg_m2, X_s_sorter, X_t_sorter, nx +): + r""" + One step of the FW algorithm for the Unbalanced Sliced OT problem, see Algorithm 1 and 3 in :ref:`[82] `. + This function computes the reweighted marginals given the current potentials and the translation term. + It returns the current potentials, and the reweighted marginals (normalized by the mass, so that they sum to 1). + + Parameters + ---------- + f: array-like shape (n, ...) + Current potential on the source samples + g: array-like shape (m, ...) + Current potential on the target samples + a: array-like shape (n, ...) + Current weights on the source samples + b: array-like shape (m, ...) + Current weights on the target samples + reg_m1: float + Marginal relaxation term for the source distribution + reg_m2: float + Marginal relaxation term for the target distribution + X_s_sorter: array-like shape (n_projs, n) + Sorter for the projected source samples + X_t_sorter: array-like shape (n_projs, m) + Sorter for the projected target samples + nx: module + backend module + + Returns + ------- + f: array-like shape (n, ...) + Current potential on the source samples + g: array-like shape (m, ...) + Current potential on the target samples + a_reweighted: array-like shape (n, ...) + Reweighted weights on the source samples (normalized by the mass) + b_reweighted: array-like shape (m, ...) + Reweighted weights on the target samples (normalized by the mass) + full_mass: array-like shape (...) + Mass of the reweighted measures + + + .. _references-uot: + References + ---------- + [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). + Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research. 
+ """ + # translate potentials + transl = rescale_potentials(f, g, a, b, reg_m1, reg_m2, nx) + + f = f + transl + g = g - transl + + # update measures + if reg_m1 != float("inf"): + a_reweighted = (a * nx.exp(-f / reg_m1))[..., X_s_sorter] + else: + a_reweighted = a[..., X_s_sorter] + + if reg_m2 != float("inf"): + b_reweighted = (b * nx.exp(-g / reg_m2))[..., X_t_sorter] + else: + b_reweighted = b[..., X_t_sorter] + + full_mass = nx.sum(a_reweighted, axis=1) + + # normalize the weights for compatibility with wasserstein_1d + a_reweighted = a_reweighted / nx.sum(a_reweighted, axis=1, keepdims=True) + b_reweighted = b_reweighted / nx.sum(b_reweighted, axis=1, keepdims=True) + + return f, g, a_reweighted, b_reweighted, full_mass + + +def unbalanced_sliced_ot( + X_s, + X_t, + reg_m, + a=None, + b=None, + n_projections=50, + p=2, + projections=None, + seed=None, + numItermax=10, + log=False, +): + r""" + Compute the Unbalanced Sliced Optimal Transpot (USOT) between two empirical distributions. + The USOT problem is computed with KL regularization and solved with a Frank-Wolfe algorithm in the dual, see :ref:`[82] `. + + The Unbalanced SOT problem reads as + + .. math:: + \mathrm{USOT}(\mu, \nu) = \inf_{\pi_1,\pi_2} \mathrm{SW}_2^2(\pi_1, \pi_2) + \lambda_1 \mathrm{KL}(\pi_1||\mu) + \lambda_2 \mathrm{KL}(\pi_2||\nu). + + This function only works in pytorch or jax (but is not maintained in jax). + + Parameters + ---------- + X_s : ndarray, shape (n_samples_a, dim) + samples in the source domain + X_t : ndarray, shape (n_samples_b, dim) + samples in the target domain + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. + For semi-relaxed case, use either + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b, M)`. + a : ndarray, shape (n_samples_a,), optional + samples weights in the source domain + b : ndarray, shape (n_samples_b,), optional + samples weights in the target domain + n_projections : int, optional + Number of projections used for the Monte-Carlo approximation + p: float, optional, by default =2 + Power p used for computing the sliced Wasserstein + projections: shape (dim, n_projections), optional + Projection matrix (n_projections and seed are not used in this case) + seed: int or RandomState or None, optional + Seed used for random number generator + numItermax: int, optional + log: bool, optional + if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + + Returns + ------- + a_reweighted: array-like shape (n, ...) + First marginal reweighted + b_reweighted: array-like shape (m, ...) + Second marginal reweighted + loss: float/array-like, shape (...) + USOT + log: dict, optional + If `log` is True, then returns a dictionary containing the projection directions used, the 1D OT losses, the SOT loss and the full mass of reweighted marginals. + + + .. _references-uot: + References + ---------- + .. [82] Bonet, C., Nadjahi, K., Séjourné, T., Fatras, K., & Courty, N. (2025). + Slicing Unbalanced Optimal Transport. Transactions on Machine Learning Research. 
+ """ + X_s, X_t = list_to_array(X_s, X_t) + + if a is not None and b is not None and projections is None: + nx = get_backend(X_s, X_t, a, b) + elif a is not None and b is not None and projections is not None: + nx = get_backend(X_s, X_t, a, b, projections) + elif a is None and b is None and projections is not None: + nx = get_backend(X_s, X_t, projections) + else: + nx = get_backend(X_s, X_t) + + assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax" + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + + n = X_s.shape[0] + m = X_t.shape[0] + + if X_s.shape[1] != X_t.shape[1]: + raise ValueError( + "X_s and X_t must have the same number of dimensions {} and {} respectively given".format( + X_s.shape[1], X_t.shape[1] + ) + ) + + if a is None: + a = nx.full(n, 1 / n, type_as=X_s) + if b is None: + b = nx.full(m, 1 / m, type_as=X_s) + + d = X_s.shape[1] + + if projections is None: + projections = get_random_projections( + d, n_projections, seed, backend=nx, type_as=X_s + ) + else: + n_projections = projections.shape[1] + + # Compute projections of the samples, and sort them for later use in the FW algorithm + X_s_projections = nx.dot(X_s, projections).T # shape (n_projs, n) + X_t_projections = nx.dot(X_t, projections).T + + X_s_sorter = nx.argsort(X_s_projections, -1) + X_s_rev_sorter = nx.argsort(X_s_sorter, -1) + X_s_sorted = nx.take_along_axis(X_s_projections, X_s_sorter, -1) + + X_t_sorter = nx.argsort(X_t_projections, -1) + X_t_rev_sorter = nx.argsort(X_t_sorter, -1) + X_t_sorted = nx.take_along_axis(X_t_projections, X_t_sorter, -1) + + # Initialize potentials - WARNING: They correspond to non-sorted samples + f = nx.zeros(a.shape, type_as=a) + g = nx.zeros(b.shape, type_as=b) + + for i in range(numItermax): + f, g, a_reweighted, b_reweighted, _ = get_reweighted_marginals_usot( + f, g, a, b, reg_m1, reg_m2, X_s_sorter, X_t_sorter, nx + ) + + fd, gd, _ = emd_1d_dual_backprop( + X_s_sorted.T, + X_t_sorted.T, + u_weights=a_reweighted.T, + v_weights=b_reweighted.T, + p=p, + require_sort=False, + ) + fd, gd = fd.T, gd.T + + # default step for FW + t = 2.0 / (2.0 + i) + + f = f + t * (nx.mean(nx.take_along_axis(fd, X_s_rev_sorter, 1), axis=0) - f) + g = g + t * (nx.mean(nx.take_along_axis(gd, X_t_rev_sorter, 1), axis=0) - g) + + f, g, a_reweighted, b_reweighted, full_mass = get_reweighted_marginals_usot( + f, g, a, b, reg_m1, reg_m2, X_s_sorter, X_t_sorter, nx + ) + + ot_loss = wasserstein_1d( + X_s_sorted.T, + X_t_sorted.T, + u_weights=a_reweighted.T, + v_weights=b_reweighted.T, + p=p, + require_sort=False, + ) + + sot_loss = nx.mean(ot_loss * full_mass) + + if reg_m1 != float("inf"): + a_reweighted = a * nx.exp(-f / reg_m1) + else: + a_reweighted = a + + if reg_m2 != float("inf"): + b_reweighted = b * nx.exp(-g / reg_m2) + else: + b_reweighted = b + + if reg_m1 == float("inf") and reg_m2 == float("inf"): + uot_loss = sot_loss + elif reg_m1 == float("inf"): + uot_loss = sot_loss + reg_m2 * nx.kl_div(b_reweighted, b, mass=True) + elif reg_m2 == float("inf"): + uot_loss = sot_loss + reg_m1 * nx.kl_div(a_reweighted, a, mass=True) + else: + uot_loss = ( + sot_loss + + reg_m1 * nx.kl_div(a_reweighted, a, mass=True) + + reg_m2 * nx.kl_div(b_reweighted, b, mass=True) + ) + + if log: + dico = { + "projections": projections, + "sot_loss": sot_loss, + "1d_losses": ot_loss, + "full_mass": full_mass, + } + return a_reweighted, b_reweighted, uot_loss, dico + + return a_reweighted, b_reweighted, uot_loss diff --git a/ot/unbalanced/_solver_1d.py b/ot/unbalanced/_solver_1d.py new file 
mode 100644 index 000000000..a9962516d --- /dev/null +++ b/ot/unbalanced/_solver_1d.py @@ -0,0 +1,315 @@ +# -*- coding: utf-8 -*- +""" +1D Unbalanced OT solvers +""" + +# Author: Clément Bonet +# +# License: MIT License + +from ..backend import get_backend +from ..utils import get_parameter_pair +from ..lp.solver_1d import emd_1d_dual_backprop, wasserstein_1d + + +def rescale_potentials(f, g, a, b, rho1, rho2, nx): + r""" + Find the optimal :math: `\lambda` in the translation invariant dual of UOT + with KL regularization and returns it, see Proposition 2 in :ref:`[73] `. + + Parameters + ---------- + f: array-like, shape (n, ...) + first dual potential + g: array-like, shape (m, ...) + second dual potential + a: array-like, shape (n, ...) + weights of the first empirical distribution + b: array-like, shape (m, ...) + weights of the second empirical distribution + rho1: float + Marginal relaxation term for the first marginal + rho2: float + Marginal relaxation term for the second marginal + nx: module + backend module + + Returns + ------- + transl: array-like, shape (...) + optimal translation + + .. _references-uot: + References + ---------- + .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR. + """ + if rho1 == float("inf") and rho2 == float("inf"): + return nx.zeros(shape=nx.sum(f, axis=0).shape, type_as=f) + + elif rho1 == float("inf"): + tau = rho2 + denom = nx.logsumexp(-g / rho2 + nx.log(b), axis=0) + num = nx.log(nx.sum(a, axis=0)) + + elif rho2 == float("inf"): + tau = rho1 + num = nx.logsumexp(-f / rho1 + nx.log(a), axis=0) + denom = nx.log(nx.sum(b, axis=0)) + + else: + tau = (rho1 * rho2) / (rho1 + rho2) + num = nx.logsumexp(-f / rho1 + nx.log(a), axis=0) + denom = nx.logsumexp(-g / rho2 + nx.log(b), axis=0) + + transl = tau * (num - denom) + + return transl + + +def get_reweighted_marginal_uot( + f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx +): + r""" + One step of the FW algorithm for the 1D UOT problem with KL regularization. + This function computes the reweighted marginals given the current dual potentials. + It returns the current potentials, and the reweighted marginals (normalized by the mass so that they sum to 1). + + Parameters + ---------- + f: array-like, shape (n, ...) + first dual potential + g: array-like, shape (m, ...) + second dual potential + u_weights_sorted: array-like, shape (n, ...) + weights of the first empirical distribution, sorted w.r.t. the support + v_weights_sorted: array-like, shape (m, ...) + weights of the second empirical distribution, sorted w.r.t. the support + reg_m1: float + Marginal relaxation term for the first marginal + reg_m2: float + Marginal relaxation term for the second marginal + nx: module + backend module + + Returns + ------- + f: array-like, shape (n, ...) + first dual potential + g: array-like, shape (m, ...) + second dual potential + u_rescaled: array-like, shape (n, ...) + reweighted first marginal, normalized by the mass + v_rescaled: array-like, shape (m, ...) + reweighted second marginal, normalized by the mass + full_mass: array-like, shape (...) 
+ mass of the reweighted marginals + """ + transl = rescale_potentials( + f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx + ) + + f = f + transl[None] + g = g - transl[None] + + if reg_m1 != float("inf"): + u_reweighted = u_weights_sorted * nx.exp(-f / reg_m1) + else: + u_reweighted = u_weights_sorted + + if reg_m2 != float("inf"): + v_reweighted = v_weights_sorted * nx.exp(-g / reg_m2) + else: + v_reweighted = v_weights_sorted + + full_mass = nx.sum(u_reweighted, axis=0) + + # Normalize weights + u_rescaled = u_reweighted / nx.sum(u_reweighted, axis=0, keepdims=True) + v_rescaled = v_reweighted / nx.sum(v_reweighted, axis=0, keepdims=True) + + return f, g, u_rescaled, v_rescaled, full_mass + + +def uot_1d( + u_values, + v_values, + reg_m, + u_weights=None, + v_weights=None, + p=2, + require_sort=True, + numItermax=10, + returnCost="linear", + log=False, +): + r""" + Solves the 1D unbalanced OT problem with KL regularization. + The function implements the Frank-Wolfe algorithm to solve the dual problem, + as proposed in :ref:`[73] `. + + The unbalanced OT problem reads + + .. math:: + \mathrm{UOT}(\mu,\nu) = \min_{\gamma \in \mathcal{M}_{+}(\mathbb{R}\times\mathbb{R})} W_2^2(\pi^1_\#\gamma,\pi^2_\#\gamma) + \mathrm{reg_{m}}_1 \mathrm{KL}(\pi^1_\#\gamma|\mu) + \mathrm{reg_{m}}_2 \mathrm{KL}(\pi^2_\#\gamma|\nu). + + This function only works in pytorch or jax (but is not maintained in jax). + + Parameters + ---------- + u_values: array-like, shape (n, ...) + locations of the first empirical distribution + v_values: array-like, shape (m, ...) + locations of the second empirical distribution + reg_m: float or indexable object of length 1 or 2 + Marginal relaxation term. + If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of length 1, + then the same :math:`\mathrm{reg_{m}}` is applied to both marginal relaxations. + The balanced OT can be recovered using :math:`\mathrm{reg_{m}}=float("inf")`. + For semi-relaxed case, use either + :math:`\mathrm{reg_{m}}=(float("inf"), scalar)` or + :math:`\mathrm{reg_{m}}=(scalar, float("inf"))`. + If :math:`\mathrm{reg_{m}}` is an array, + it must have the same backend as input arrays `(a, b)`. + u_weights: array-like, shape (n, ...), optional + weights of the first empirical distribution, if None then uniform weights are used + v_weights: array-like, shape (m, ...), optional + weights of the second empirical distribution, if None then uniform weights are used + p: int, optional + order of the ground metric used, should be at least 1, default is 2 + require_sort: bool, optional + sort the distributions atoms locations, if False we will consider they have been sorted prior to being passed to + the function, default is True + numItermax: int, optional + Number of iterations of the Frank-Wolfe algorithm + returnCost: string, optional (default = "linear") + If `returnCost` = "linear", then return the linear part of the unbalanced OT loss. + If `returnCost` = "total", then return the total unbalanced OT loss. + log: bool, optional + record log if True + + Returns + ------- + u_reweighted: array-like shape (n, ...) + First marginal reweighted + v_reweighted: array-like shape (m, ...) + Second marginal reweighted + loss: float/array-like, shape (...) + The batched 1D UOT + log: dict, optional + If `log` is True, then returns a dictionary containing the dual potentials, the total cost and the linear cost. + + + .. _references-uot: + References + --------- + .. [73] Séjourné, T., Vialard, F. X., & Peyré, G. (2022). + Faster unbalanced optimal transport: Translation invariant sinkhorn and 1-d frank-wolfe. + In International Conference on Artificial Intelligence and Statistics (pp. 4995-5021). PMLR.
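+
+    A minimal usage sketch (illustrative; torch is assumed, as the solver is
+    only implemented for the torch and jax backends):
+
+    >>> import torch
+    >>> u = torch.tensor([0.2, 0.5, 0.8], dtype=torch.float64)
+    >>> v = torch.tensor([0.4, 0.5, 0.7], dtype=torch.float64)
+    >>> u_rw, v_rw, cost = uot_1d(u, v, reg_m=1.0, returnCost="total")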
+ """ + if u_weights is not None and v_weights is not None: + nx = get_backend(u_values, v_values, u_weights, v_weights) + else: + nx = get_backend(u_values, v_values) + + assert nx.__name__ in ["torch", "jax"], "Function only valid in torch and jax" + + reg_m1, reg_m2 = get_parameter_pair(reg_m) + + n = u_values.shape[0] + m = v_values.shape[0] + + # Init weights or broadcast if necessary + if u_weights is None: + u_weights = nx.full(u_values.shape, 1.0 / n, type_as=u_values) + elif u_weights.ndim != u_values.ndim: + u_weights = nx.repeat(u_weights[..., None], u_values.shape[-1], -1) + + if v_weights is None: + v_weights = nx.full(v_values.shape, 1.0 / m, type_as=v_values) + elif v_weights.ndim != v_values.ndim: + v_weights = nx.repeat(v_weights[..., None], v_values.shape[-1], -1) + + # Sort w.r.t. support if not already done + if require_sort: + u_sorter = nx.argsort(u_values, 0) + u_rev_sorter = nx.argsort(u_sorter, 0) + u_values_sorted = nx.take_along_axis(u_values, u_sorter, 0) + + v_sorter = nx.argsort(v_values, 0) + v_rev_sorter = nx.argsort(v_sorter, 0) + v_values_sorted = nx.take_along_axis(v_values, v_sorter, 0) + + u_weights_sorted = nx.take_along_axis(u_weights, u_sorter, 0) + v_weights_sorted = nx.take_along_axis(v_weights, v_sorter, 0) + else: + u_values_sorted = u_values + v_values_sorted = v_values + u_weights_sorted = u_weights + v_weights_sorted = v_weights + + f = nx.zeros(u_weights.shape, type_as=u_weights) + fd = nx.zeros(u_weights.shape, type_as=u_weights) + g = nx.zeros(v_weights.shape, type_as=v_weights) + gd = nx.zeros(v_weights.shape, type_as=v_weights) + + for i in range(numItermax): + # FW steps + f, g, u_rescaled, v_rescaled, _ = get_reweighted_marginal_uot( + f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx + ) + + fd, gd, loss = emd_1d_dual_backprop( + u_values_sorted, + v_values_sorted, + u_weights=u_rescaled, + v_weights=v_rescaled, + p=p, + require_sort=False, + ) + + t = 2.0 / (2.0 + i) + f = f + t * (fd - f) + g = g + t * (gd - g) + + f, g, u_rescaled, v_rescaled, full_mass = get_reweighted_marginal_uot( + f, g, u_weights_sorted, v_weights_sorted, reg_m1, reg_m2, nx + ) + + loss = wasserstein_1d( + u_values_sorted, + v_values_sorted, + u_rescaled, + v_rescaled, + p=p, + require_sort=False, + ) + + if require_sort: + f = nx.take_along_axis(f, u_rev_sorter, 0) + g = nx.take_along_axis(g, v_rev_sorter, 0) + u_reweighted = nx.take_along_axis(u_rescaled, u_rev_sorter, 0) * full_mass + v_reweighted = nx.take_along_axis(v_rescaled, v_rev_sorter, 0) * full_mass + else: + u_reweighted = u_rescaled * full_mass + v_reweighted = v_rescaled * full_mass + + # rescale OT loss + linear_loss = loss * full_mass + + if reg_m1 == float("inf") and reg_m2 == float("inf"): + uot_loss = linear_loss + elif reg_m1 == float("inf"): + uot_loss = linear_loss + reg_m2 * nx.kl_div(v_reweighted, v_weights, mass=True, axis=0) + elif reg_m2 == float("inf"): + uot_loss = linear_loss + reg_m1 * nx.kl_div(u_reweighted, u_weights, mass=True, axis=0) + else: + uot_loss = ( + linear_loss + + reg_m1 * nx.kl_div(u_reweighted, u_weights, mass=True, axis=0) + + reg_m2 * nx.kl_div(v_reweighted, v_weights, mass=True, axis=0) + ) + + if returnCost == "linear": + out_loss = linear_loss + elif returnCost == "total": + out_loss = uot_loss + else: + raise ValueError("returnCost should be 'linear' or 'total'") + + if log: + dico = {"f": f, "g": g, "total_cost": uot_loss, "linear_cost": linear_loss} + return u_reweighted, v_reweighted, out_loss, dico + return u_reweighted, v_reweighted, out_loss diff --git a/test/test_1d_solver.py b/test/test_1d_solver.py index c2f377469..71e67f7a1 100644 --- a/test/test_1d_solver.py +++ b/test/test_1d_solver.py @@ -2,6 +2,7 @@ # Author: Adrien Corenflos # Nicolas
Courty +# Clément Bonet # # License: MIT License @@ -94,7 +95,7 @@ def test_wasserstein_1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) + # print(nx.dtype_device(tp)) xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) @@ -178,7 +179,7 @@ def test_emd1d_type_devices(nx): rho_v /= rho_v.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) + # print(nx.dtype_device(tp)) xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) @@ -218,17 +219,13 @@ def test_emd1d_device_tf(): assert nx.dtype_device(emd)[1].startswith("GPU") -def test_wasserstein_1d_circle(): - # test binary_search_circle and wasserstein_circle give similar results as emd +def test_emd1d_dual_with_weights(nx): + # test emd1d_dual gives similar results as emd n = 20 m = 30 rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) + u = rng.randn(n, 1) + v = rng.randn(m, 1) w_u = rng.uniform(0.0, 1.0, n) w_u = w_u / w_u.sum() @@ -236,207 +233,67 @@ def test_wasserstein_1d_circle(): w_v = rng.uniform(0.0, 1.0, m) w_v = w_v / w_v.sum() - M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) - - wass1 = ot.emd2(w_u, w_v, M1) + u, v, w_u, w_v = nx.from_numpy(u, v, w_u, w_v) - wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1) - w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1) + M = ot.dist(u, v, metric="sqeuclidean") + G, log = ot.emd(w_u, w_v, M, log=True) + wass = log["cost"] - M2 = M1**2 - wass2 = ot.emd2(w_u, w_v, M2) - wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2) - w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2) + if nx.__name__ in ["torch", "jax"]: + f, g, wass1d = ot.emd_1d_dual_backprop(u, v, w_u, w_v, p=2) - # check loss is similar - np.testing.assert_allclose(wass1, wass1_bsc) - np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2) - np.testing.assert_allclose(wass2, wass2_bsc) - np.testing.assert_allclose(wass2, w2_circle) + # check loss is similar + np.testing.assert_allclose(wass, wass1d) + np.testing.assert_allclose(wass, nx.sum(f[:, 0] * w_u) + nx.sum(g[:, 0] * w_v)) -@pytest.skip_backend("tf") -def test_wasserstein1d_circle_devices(nx): +@pytest.skip_backend("jax") # problem with jax on macOS +def test_emd1d_dual_backprop_batch(nx): rng = np.random.RandomState(0) - n = 10 - x = np.linspace(0, 1, n) + n = 100 + rho_u = np.abs(rng.randn(n)) rho_u /= rho_u.sum() rho_v = np.abs(rng.randn(n)) rho_v /= rho_v.sum() - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) - - w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1) - w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2) - - nx.assert_same_dtype_device(xb, w1) - nx.assert_same_dtype_device(xb, w2_bsc) - - -def test_wasserstein_1d_unif_circle(): - # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle - n = 20 - m = 1000 - - rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) + rho_ub, rho_vb = nx.from_numpy(rho_u, rho_v) - # w_u = rng.uniform(0., 1., n) - # w_u = w_u / w_u.sum() - - w_u = ot.utils.unif(n) - w_v = ot.utils.unif(m) - - M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None])) - wass2 = ot.emd2(w_u, w_v, M1**2) - - wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15) - wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u) - - # check loss is similar - np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2) - 
np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2) - - -def test_wasserstein1d_unif_circle_devices(nx): - rng = np.random.RandomState(0) - - n = 10 - x = np.linspace(0, 1, n) - rho_u = np.abs(rng.randn(n)) - rho_u /= rho_u.sum() - - for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - - xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp) - - w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub) - - nx.assert_same_dtype_device(xb, w2) - - -def test_binary_search_circle_log(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) - - wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True) - optimal_thetas = log["optimal_theta"] - - assert optimal_thetas.shape[0] == 1 + X = np.stack((np.linspace(0, 5, n), np.linspace(0, 5, n) * 10), -1) + Xb = nx.from_numpy(X) -def test_wasserstein_circle_bad_shape(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand(n, 2) - v = rng.rand(m, 1) + if nx.__name__ in ["torch", "jax"]: + f, g, res = ot.emd_1d_dual_backprop(Xb, Xb, rho_ub, rho_vb, p=2) - with pytest.raises(ValueError): - _ = ot.wasserstein_circle(u, v, p=2) + cost_dual = nx.sum(f * rho_ub[:, None], axis=0) + nx.sum( + g * rho_vb[:, None], axis=0 + ) - with pytest.raises(ValueError): - _ = ot.wasserstein_circle(u, v, p=1) + np.testing.assert_almost_equal(100 * res[0], res[1], decimal=4) + np.testing.assert_allclose(cost_dual, res) + else: + np.testing.assert_raises( + AssertionError, ot.emd_1d_dual_backprop, Xb, Xb, rho_ub, rho_vb, p=2 + ) -@pytest.skip_backend("tf") -def test_linear_circular_ot_devices(nx): +def test_emd1d_dual_type_devices(nx): rng = np.random.RandomState(0) n = 10 - x = np.linspace(0, 1, n) + x = np.linspace(0, 5, n) rho_u = np.abs(rng.randn(n)) rho_u /= rho_u.sum() rho_v = np.abs(rng.randn(n)) rho_v /= rho_v.sum() for tp in nx.__type_list__: - print(nx.dtype_device(tp)) - + # print(nx.dtype_device(tp)) xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp) - - lcot = ot.linear_circular_ot(xb, xb, rho_ub, rho_vb) - - nx.assert_same_dtype_device(xb, lcot) - - -def test_linear_circular_ot_bad_shape(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand(n, 2) - v = rng.rand(m, 1) - - with pytest.raises(ValueError): - _ = ot.linear_circular_ot(u, v) - - -def test_linear_circular_ot_same_dist(): - n = 20 - rng = np.random.RandomState(0) - u = rng.rand(n) - - lcot = ot.linear_circular_ot(u, u) - np.testing.assert_almost_equal(lcot, 0.0) - - -def test_linear_circular_ot_different_dist(): - n = 20 - m = 30 - rng = np.random.RandomState(0) - u = rng.rand(n) - v = rng.rand(m) - - lcot = ot.linear_circular_ot(u, v) - assert lcot > 0.0 - - -def test_linear_circular_embedding_shape(): - n = 20 - rng = np.random.RandomState(0) - u = rng.rand(n, 2) - - ts = np.linspace(0, 1, 101)[:-1] - - emb = ot.lp.solver_1d.linear_circular_embedding(ts, u) - assert emb.shape == (100, 2) - - emb = ot.lp.solver_1d.linear_circular_embedding(ts, u[:, 0]) - assert emb.shape == (100, 1) - - -def test_linear_circular_ot_unif_circle(): - n = 20 - m = 1000 - - rng = np.random.RandomState(0) - u = rng.rand( - n, - ) - v = rng.rand( - m, - ) - - lcot = ot.linear_circular_ot(u, v) - lcot_unif = ot.linear_circular_ot(u) - - # check loss is similar - np.testing.assert_allclose(lcot, lcot_unif, atol=1e-2) + if nx.__name__ == "torch" or nx.__name__ == "jax": + f, g, res = ot.emd_1d_dual_backprop(xb, xb, rho_ub, rho_vb, p=1) + nx.assert_same_dtype_device(xb, res) + nx.assert_same_dtype_device(xb, f) + 
diff --git a/test/test_backend.py b/test/test_backend.py
index efd696ef0..50e52eb73 100644
--- a/test/test_backend.py
+++ b/test/test_backend.py
@@ -139,6 +139,7 @@ def test_empty_backend():
     rnd = np.random.RandomState(0)
     M = rnd.randn(10, 3)
     v = rnd.randn(3)
+    inds = rnd.randint(10)

     nx = ot.backend.Backend()

@@ -321,6 +322,10 @@
         nx.slogdet(M)
     with pytest.raises(NotImplementedError):
         nx.unsqueeze(M, 0)
+    with pytest.raises(NotImplementedError):
+        nx.index_select(M, 0, inds)
+    with pytest.raises(NotImplementedError):
+        nx.nonzero(M)


 def test_func_backends(nx):

@@ -753,6 +758,14 @@
     lst_b.append(np.array([s, logabsd]))
     lst_name.append("slogdet")

+    vec = nx.index_select(vb, 0, nx.from_numpy(np.array([0, 1])))
+    lst_b.append(nx.to_numpy(vec))
+    lst_name.append("index_select")
+
+    vec = nx.nonzero(Mb)
+    lst_b.append(nx.to_numpy(vec))
+    lst_name.append("nonzero")
+
     assert not nx.array_equal(Mb, vb), "array_equal (shape)"
     assert nx.array_equal(Mb, Mb), "array_equal (elements) - expected true"
     assert not nx.array_equal(
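These hunks add two primitives, `index_select` and `nonzero`, to the backend abstraction. Going only by `test_func_backends`, `index_select` picks entries along a given axis from an index array (presumably with the semantics of `torch.index_select`) and `nonzero` returns the indices of nonzero entries. A rough usage sketch on the NumPy backend, with illustrative variable names:

```python
import numpy as np
import ot

M = np.arange(12.0).reshape(4, 3)
v = np.array([1.0, 0.0, 2.0])

nx = ot.backend.get_backend(M, v)  # resolves to the NumPy backend here

# rows 0 and 2 of M, selected along axis 0
rows = nx.index_select(M, 0, nx.from_numpy(np.array([0, 2])))

# indices of the nonzero entries of v
idx = nx.nonzero(v)
```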
diff --git a/test/test_circle_solver.py b/test/test_circle_solver.py
new file mode 100644
index 000000000..35097b1c0
--- /dev/null
+++ b/test/test_circle_solver.py
@@ -0,0 +1,234 @@
+"""Tests for module Circle Wasserstein solver"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import numpy as np
+import pytest
+
+import ot
+
+
+def test_wasserstein_1d_circle():
+    # test binary_search_circle and wasserstein_circle give similar results as emd
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    w_u = rng.uniform(0.0, 1.0, n)
+    w_u = w_u / w_u.sum()
+
+    w_v = rng.uniform(0.0, 1.0, m)
+    w_v = w_v / w_v.sum()
+
+    M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+
+    wass1 = ot.emd2(w_u, w_v, M1)
+
+    wass1_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=1)
+    w1_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=1)
+
+    M2 = M1**2
+    wass2 = ot.emd2(w_u, w_v, M2)
+    wass2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)
+    w2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass1, wass1_bsc)
+    np.testing.assert_allclose(wass1, w1_circle, rtol=1e-2)
+    np.testing.assert_allclose(wass2, wass2_bsc)
+    np.testing.assert_allclose(wass2, w2_circle)
+
+
+@pytest.skip_backend("tf")
+def test_wasserstein1d_circle_devices(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 1, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+        w1 = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=1)
+        w2_bsc = ot.wasserstein_circle(xb, xb, rho_ub, rho_vb, p=2)
+
+        nx.assert_same_dtype_device(xb, w1)
+        nx.assert_same_dtype_device(xb, w2_bsc)
+
+
+def test_wasserstein_1d_unif_circle():
+    # test semidiscrete_wasserstein2_unif_circle versus wasserstein_circle
+    n = 20
+    m = 1000
+
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    # w_u = rng.uniform(0., 1., n)
+    # w_u = w_u / w_u.sum()
+
+    w_u = ot.utils.unif(n)
+    w_v = ot.utils.unif(m)
+
+    M1 = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))
+    wass2 = ot.emd2(w_u, w_v, M1**2)
+
+    wass2_circle = ot.wasserstein_circle(u, v, w_u, w_v, p=2, eps=1e-15)
+    wass2_unif_circle = ot.semidiscrete_wasserstein2_unif_circle(u, w_u)
+
+    # check loss is similar
+    np.testing.assert_allclose(wass2, wass2_unif_circle, atol=1e-2)
+    np.testing.assert_allclose(wass2_circle, wass2_unif_circle, atol=1e-2)
+
+
+def test_wasserstein1d_unif_circle_devices(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 1, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub = nx.from_numpy(x, rho_u, type_as=tp)
+
+        w2 = ot.semidiscrete_wasserstein2_unif_circle(xb, rho_ub)
+
+        nx.assert_same_dtype_device(xb, w2)
+
+
+def test_binary_search_circle_log():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    wass2_bsc, log = ot.binary_search_circle(u, v, p=2, log=True)
+    optimal_thetas = log["optimal_theta"]
+
+    assert optimal_thetas.shape[0] == 1
+
+
+def test_wasserstein_circle_bad_shape():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(n, 2)
+    v = rng.rand(m, 1)
+
+    with pytest.raises(ValueError):
+        _ = ot.wasserstein_circle(u, v, p=2)
+
+    with pytest.raises(ValueError):
+        _ = ot.wasserstein_circle(u, v, p=1)
+
+
+@pytest.skip_backend("tf")
+def test_linear_circular_ot_devices(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = np.linspace(0, 1, n)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    for tp in nx.__type_list__:
+        # print(nx.dtype_device(tp))
+
+        xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v, type_as=tp)
+
+        lcot = ot.linear_circular_ot(xb, xb, rho_ub, rho_vb)
+
+        nx.assert_same_dtype_device(xb, lcot)
+
+
+def test_linear_circular_ot_bad_shape():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(n, 2)
+    v = rng.rand(m, 1)
+
+    with pytest.raises(ValueError):
+        _ = ot.linear_circular_ot(u, v)
+
+
+def test_linear_circular_ot_same_dist():
+    n = 20
+    rng = np.random.RandomState(0)
+    u = rng.rand(n)
+
+    lcot = ot.linear_circular_ot(u, u)
+    np.testing.assert_almost_equal(lcot, 0.0)
+
+
+def test_linear_circular_ot_different_dist():
+    n = 20
+    m = 30
+    rng = np.random.RandomState(0)
+    u = rng.rand(n)
+    v = rng.rand(m)
+
+    lcot = ot.linear_circular_ot(u, v)
+    assert lcot > 0.0
+
+
+def test_linear_circular_embedding_shape():
+    n = 20
+    rng = np.random.RandomState(0)
+    u = rng.rand(n, 2)
+
+    ts = np.linspace(0, 1, 101)[:-1]
+
+    emb = ot.lp.solver_circle.linear_circular_embedding(ts, u)
+    assert emb.shape == (100, 2)
+
+    emb = ot.lp.solver_circle.linear_circular_embedding(ts, u[:, 0])
+    assert emb.shape == (100, 1)
+
+
+def test_linear_circular_ot_unif_circle():
+    n = 20
+    m = 1000
+
+    rng = np.random.RandomState(0)
+    u = rng.rand(
+        n,
+    )
+    v = rng.rand(
+        m,
+    )
+
+    lcot = ot.linear_circular_ot(u, v)
+    lcot_unif = ot.linear_circular_ot(u)
+
+    # check loss is similar
+    np.testing.assert_allclose(lcot, lcot_unif, atol=1e-2)
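The new file above only relocates the circle-OT tests out of `test_1d_solver.py`. The identity they all rely on is worth stating once: for supports with coordinates in [0, 1) seen as points on the unit circle, the ground distance is d(u, v) = min(|u - v|, 1 - |u - v|), so the circle-specific solvers can be cross-checked against the generic LP solver on that cost matrix. A condensed sketch of the cross-check:

```python
import numpy as np
import ot

rng = np.random.RandomState(0)
u, v = rng.rand(20), rng.rand(30)               # positions on the circle, in [0, 1)
w_u, w_v = ot.utils.unif(20), ot.utils.unif(30)

# circular ground distance between all pairs of points
M = np.minimum(np.abs(u[:, None] - v[None]), 1 - np.abs(u[:, None] - v[None]))

w2_ref = ot.emd2(w_u, w_v, M**2)                       # generic LP solver
w2_bsc = ot.binary_search_circle(u, v, w_u, w_v, p=2)  # circle-specific solver
w2_qf = ot.wasserstein_circle(u, v, w_u, w_v, p=2)     # circle-specific solver
np.testing.assert_allclose(w2_ref, w2_bsc)
np.testing.assert_allclose(w2_ref, w2_qf)
```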
diff --git a/test/unbalanced/test_1d_solver.py b/test/unbalanced/test_1d_solver.py
new file mode 100644
index 000000000..1f908404e
--- /dev/null
+++ b/test/unbalanced/test_1d_solver.py
@@ -0,0 +1,270 @@
+"""Tests for module 1D Unbalanced OT"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import itertools
+import numpy as np
+import ot
+import pytest
+
+
+def test_uot_1d(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = 1.0
+
+    M = ot.dist(xs, xt)
+    a, b, M = nx.from_numpy(a_np, b_np, M)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    G, log = ot.unbalanced.mm_unbalanced(a, b, M, reg_m, div="kl", log=True)
+    loss_mm = log["cost"]
+
+    if nx.__name__ in ["jax", "torch"]:
+        f, g, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, p=2)
+        np.testing.assert_allclose(loss_1d, loss_mm, atol=1e-2)
+        np.testing.assert_allclose(G.sum(0), g[:, 0], atol=1e-2)
+        np.testing.assert_allclose(G.sum(1), f[:, 0], atol=1e-2)
+
+
+def test_uot_1d_convergence(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    reg_m = 1000
+
+    # wass1d = ot.wasserstein_1d(xs, xt, p=2)
+    G_1d, log = ot.emd_1d(xs, xt, metric="sqeuclidean", log=True)
+    wass1d = log["cost"]
+    u_w1d, v_w1d = nx.sum(G_1d, 1), nx.sum(G_1d, 0)
+
+    if nx.__name__ in ["jax", "torch"]:
+        u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, p=2)
+        np.testing.assert_allclose(loss_1d, wass1d, atol=1e-2)
+        np.testing.assert_allclose(v_w1d, v[:, 0], atol=1e-2)
+        np.testing.assert_allclose(u_w1d, u[:, 0], atol=1e-2)
+
+
+def test_uot_1d_batch(nx):
+    n_samples = 20  # nb samples
+    m_samples = 30
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(m_samples, 1)
+    xs = np.concatenate([xs, xs], axis=1)
+    xt = np.concatenate([xt, xt], axis=1)
+
+    a_np = rng.uniform(0, 1, n_samples)  # unbalanced
+    b_np = ot.utils.unif(m_samples)
+
+    xs, xt, a, b = nx.from_numpy(xs, xt, a_np, b_np)
+
+    reg_m = 1
+
+    if nx.__name__ in ["jax", "torch"]:
+        u1, v1, uot_1d = ot.unbalanced.uot_1d(xs[:, 0], xt[:, 0], reg_m, a, b, p=2)
+        u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, a, b, p=2)
+
+        np.testing.assert_allclose(loss_1d[0], loss_1d[1], atol=1e-5)
+        np.testing.assert_allclose(loss_1d[0], uot_1d, atol=1e-5)
+
+        u1, v1, uot_1d = ot.unbalanced.uot_1d(
+            xs[:, 0], xt[:, 0], reg_m, a, b, p=2, returnCost="total"
+        )
+        u, v, loss_1d = ot.unbalanced.uot_1d(
+            xs, xt, reg_m, a, b, p=2, returnCost="total"
+        )
+
+        np.testing.assert_allclose(loss_1d[0], loss_1d[1], atol=1e-5)
+        np.testing.assert_allclose(loss_1d[0], uot_1d, atol=1e-5)
+
+
+def test_uot_1d_inf_reg_m_backprop(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = float("inf")
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    if nx.__name__ in ["jax", "torch"]:
+        f_w1d, g_w1d, wass1d = ot.emd_1d_dual_backprop(xs, xt, a, b, p=2)
+        u, v, loss_1d, log = ot.unbalanced.uot_1d(xs, xt, reg_m, a, b, p=2, log=True)
+
+        # Check right loss
+        np.testing.assert_allclose(loss_1d, wass1d)
+
+        # Check right marginals
+        np.testing.assert_allclose(a, u[:, 0])
+        np.testing.assert_allclose(b, v[:, 0])
+
+        # Check potentials
+        np.testing.assert_allclose(f_w1d, log["f"])
+        np.testing.assert_allclose(g_w1d, log["g"])
+
+
+def test_semi_uot_1d_backprop(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    reg_m = (float("inf"), 1.0)
+
+    if nx.__name__ in ["jax", "torch"]:
+        u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, p=2)
+
+        # Check right marginals
+        np.testing.assert_allclose(a, u[:, 0])
+        np.testing.assert_allclose(v[:, 0].sum(), 1)
+
+    reg_m = (1.0, float("inf"))
+
+    if nx.__name__ in ["jax", "torch"]:
+        u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, p=2)
+
+        # Check right marginals
+        np.testing.assert_allclose(b, v[:, 0])
+        np.testing.assert_allclose(u[:, 0].sum(), 1)
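Taken together, the tests above specify `ot.unbalanced.uot_1d` (the Frank-Wolfe solver announced in the release notes): it takes the two 1D supports (optionally batched column-wise), the marginal relaxation `reg_m` (a scalar, a pair, or `float("inf")` to enforce a marginal exactly) and optional weights, and returns the two reweighted marginals of the UOT plan together with the loss; it requires an autodiff backend (`torch` or `jax`). A sketch under those assumptions:

```python
# Sketch of ot.unbalanced.uot_1d as exercised by the tests above (torch backend).
import numpy as np
import torch
import ot

rng = np.random.RandomState(42)
xs = torch.from_numpy(rng.randn(20, 1))
xt = torch.from_numpy(rng.randn(20, 1))
a = torch.from_numpy(ot.utils.unif(20))
b = torch.from_numpy(ot.utils.unif(20))

# u, v: reweighted marginals (one column per batch); loss: the UOT cost
u, v, loss = ot.unbalanced.uot_1d(xs, xt, 1.0, u_weights=a, v_weights=b, p=2)

# reg_m = inf recovers balanced 1D OT: the marginals stay equal to a and b
u_inf, v_inf, w2 = ot.unbalanced.uot_1d(xs, xt, float("inf"), a, b, p=2)
```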
float("inf")) + + if nx.__name__ in ["jax", "torch"]: + u, v, loss_1d = ot.unbalanced.uot_1d(xs, xt, reg_m, p=2) + + # Check right marginals + np.testing.assert_allclose(b, v[:, 0]) + np.testing.assert_allclose(u[:, 0].sum(), 1) + + +@pytest.skip_backend("jax") # problem with jax on macOS +@pytest.mark.parametrize( + "reg_m", + itertools.product( + [1, float("inf")], + ), +) +def test_unbalanced_relaxation_parameters_backprop(nx, reg_m): + n = 100 + rng = np.random.RandomState(50) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + a = ot.utils.unif(n) + + # make dists unbalanced + b = rng.rand(n, 2) + + a, b, x, y = nx.from_numpy(a, b, x, y) + + reg_m = reg_m[0] + + # options for reg_m + full_list_reg_m = [reg_m, reg_m] + full_tuple_reg_m = (reg_m, reg_m) + tuple_reg_m, list_reg_m = (reg_m), [reg_m] + nx_reg_m = reg_m * nx.ones(1) + + list_options = [ + nx_reg_m, + full_tuple_reg_m, + tuple_reg_m, + full_list_reg_m, + list_reg_m, + ] + + if nx.__name__ in ["jax", "torch"]: + u, v, loss = ot.unbalanced.uot_1d(x, y, reg_m, u_weights=a, v_weights=b, p=2) + + for opt in list_options: + u, v, loss_opt = ot.unbalanced.uot_1d( + x, y, opt, u_weights=a, v_weights=b, p=2 + ) + + np.testing.assert_allclose( + nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05 + ) + + +@pytest.skip_backend("jax") # problem with jax on macOS +@pytest.mark.parametrize( + "reg_m1, reg_m2", + itertools.product( + [1, float("inf")], + [1, float("inf")], + ), +) +def test_unbalanced_relaxation_parameters_pair_backprop(nx, reg_m1, reg_m2): + # test generalized sinkhorn for unbalanced OT + n = 100 + rng = np.random.RandomState(50) + + x = rng.randn(n, 2) + y = rng.randn(n, 2) + a = ot.utils.unif(n) + b = ot.utils.unif(n) + + a, b, x, y = nx.from_numpy(a, b, x, y) + + # options for reg_m + full_list_reg_m = [reg_m1, reg_m2] + full_tuple_reg_m = (reg_m1, reg_m2) + list_options = [full_tuple_reg_m, full_list_reg_m] + + if nx.__name__ in ["jax", "torch"]: + _, _, loss = ot.unbalanced.uot_1d( + x, y, (reg_m1, reg_m2), u_weights=a, v_weights=b, p=2 + ) + + for opt in list_options: + _, _, loss_opt = ot.unbalanced.uot_1d( + x, y, opt, u_weights=a, v_weights=b, p=2 + ) + + np.testing.assert_allclose( + nx.to_numpy(loss), nx.to_numpy(loss_opt), atol=1e-05 + ) + + +def test_uot_1d_type_devices_backprop(nx): + rng = np.random.RandomState(0) + + n = 10 + x = np.linspace(0, 5, n) + rho_u = np.abs(rng.randn(n)) + rho_u /= rho_u.sum() + rho_v = np.abs(rng.randn(n)) + rho_v /= rho_v.sum() + + reg_m = 1.0 + + xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v) + + if nx.__name__ in ["torch", "jax"]: + f, g, _ = ot.unbalanced.uot_1d(xb, xb, reg_m, rho_ub, rho_vb, p=2) + + nx.assert_same_dtype_device(xb, f) + nx.assert_same_dtype_device(xb, g) + else: + np.testing.assert_raises( + AssertionError, ot.unbalanced.uot_1d, xb, xb, reg_m, rho_ub, rho_vb, p=2 + ) diff --git a/test/unbalanced/test_sliced.py b/test/unbalanced/test_sliced.py new file mode 100644 index 000000000..16b9cbce9 --- /dev/null +++ b/test/unbalanced/test_sliced.py @@ -0,0 +1,387 @@ +"""Tests for module sliced Unbalanced OT""" + +# Author: Clément Bonet +# +# License: MIT License + +import itertools +import numpy as np +import ot +import pytest + + +def test_sliced_uot_same_dist(nx): + n = 100 + rng = np.random.RandomState(0) + + x = rng.randn(n, 2) + u = ot.utils.unif(n) + + x, u = nx.from_numpy(x, u) + + if nx.__name__ in ["torch", "jax"]: + res = ot.sliced_unbalanced_ot(x, x, 1, u, u, 10, seed=42) + np.testing.assert_almost_equal(res, 0.0) + + _, _, res = 
diff --git a/test/unbalanced/test_sliced.py b/test/unbalanced/test_sliced.py
new file mode 100644
index 000000000..16b9cbce9
--- /dev/null
+++ b/test/unbalanced/test_sliced.py
@@ -0,0 +1,387 @@
+"""Tests for module sliced Unbalanced OT"""
+
+# Author: Clément Bonet
+#
+# License: MIT License
+
+import itertools
+import numpy as np
+import ot
+import pytest
+
+
+def test_sliced_uot_same_dist(nx):
+    n = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    u = ot.utils.unif(n)
+
+    x, u = nx.from_numpy(x, u)
+
+    if nx.__name__ in ["torch", "jax"]:
+        res = ot.sliced_unbalanced_ot(x, x, 1, u, u, 10, seed=42)
+        np.testing.assert_almost_equal(res, 0.0)
+
+        _, _, res = ot.unbalanced_sliced_ot(x, x, 1, u, u, 10, seed=42)
+        np.testing.assert_almost_equal(res, 0.0)
+
+
+def test_sliced_uot_bad_shapes(nx):
+    n = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 2)
+    y = rng.randn(n, 4)
+    u = ot.utils.unif(n)
+
+    if nx.__name__ in ["torch", "jax"]:
+        x, y, u = nx.from_numpy(x, y, u)
+
+        with pytest.raises(ValueError):
+            _ = ot.sliced_unbalanced_ot(x, y, 1, u, u, 10, seed=42)
+
+        with pytest.raises(ValueError):
+            _ = ot.unbalanced_sliced_ot(x, y, 1, u, u, 10, seed=42)
+
+
+def test_sliced_uot_log(nx):
+    n = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 4)
+    y = rng.randn(n, 4)
+    u = ot.utils.unif(n)
+
+    if nx.__name__ in ["torch", "jax"]:
+        x, y, u = nx.from_numpy(x, y, u)
+
+        res, log = ot.sliced_unbalanced_ot(x, y, 1, u, u, 10, p=1, seed=42, log=True)
+        assert len(log) == 4
+        projections = log["projections"]
+        projected_uots = log["projected_uots"]
+        a_reweighted = log["a_reweighted"]
+        b_reweighted = log["b_reweighted"]
+
+        assert projections.shape[1] == len(projected_uots) == 10
+
+        for emd in projected_uots:
+            assert emd > 0
+
+        assert res > 0
+        assert a_reweighted.shape == b_reweighted.shape == (n, 10)
+
+
+def test_unbalanced_sot_log(nx):
+    n = 100
+    rng = np.random.RandomState(0)
+
+    x = rng.randn(n, 4)
+    y = rng.randn(n, 4)
+    u = ot.utils.unif(n)
+
+    if nx.__name__ in ["torch", "jax"]:
+        x, y, u = nx.from_numpy(x, y, u)
+
+        f, g, res, log = ot.unbalanced_sliced_ot(
+            x, y, 1, u, u, 10, p=1, seed=42, log=True
+        )
+        assert len(log) == 4
+
+        projections = log["projections"]
+        sot_loss = log["sot_loss"]
+        ot_loss = log["1d_losses"]
+        full_mass = log["full_mass"]
+
+        assert projections.shape[1] == 10
+        assert res > 0
+
+        assert f.shape == g.shape == u.shape
+        np.testing.assert_almost_equal(f.sum(), g.sum())
+        np.testing.assert_equal(sot_loss, nx.mean(ot_loss * full_mass))
+
+
+def test_1d_sliced_equals_uot(nx):
+    n = 100
+    m = 120
+    rng = np.random.RandomState(42)
+
+    x = rng.randn(n, 1)
+    y = rng.randn(m, 1)
+
+    a = rng.uniform(0, 1, n) / 10  # unbalanced
+    u = ot.utils.unif(m)
+
+    reg_m = 1
+
+    if nx.__name__ in ["torch", "jax"]:
+        x, y, a, u = nx.from_numpy(x, y, a, u)
+
+        res, log = ot.sliced_unbalanced_ot(
+            x, y, reg_m, a, u, 10, seed=42, p=2, log=True
+        )
+        a_exp, u_exp, expected = ot.unbalanced.uot_1d(
+            x.squeeze(), y.squeeze(), reg_m, a, u, returnCost="total", p=2
+        )
+        np.testing.assert_almost_equal(res, expected)
+        np.testing.assert_allclose(log["a_reweighted"][:, 0], a_exp)
+        np.testing.assert_allclose(log["b_reweighted"][:, 0], u_exp)
+
+        f, g, res, log = ot.unbalanced_sliced_ot(
+            x, y, reg_m, a, u, 10, seed=42, p=2, log=True
+        )
+        np.testing.assert_almost_equal(res, expected)
+        np.testing.assert_allclose(f, a_exp)
+        np.testing.assert_allclose(g, u_exp)
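These tests separate the two estimators of [82]: `ot.sliced_unbalanced_ot` (SUOT) solves an independent 1D UOT problem on each projection and aggregates the losses, exposing per-projection reweighted marginals in its log, while `ot.unbalanced_sliced_ot` (USOT) produces a single pair of reweighted marginals shared across all projections, together with the loss. A sketch of the two calls on a torch backend, with illustrative data:

```python
import numpy as np
import torch
import ot

rng = np.random.RandomState(0)
x = torch.from_numpy(rng.randn(100, 2))
y = torch.from_numpy(rng.randn(120, 2))
a = torch.from_numpy(ot.utils.unif(100))
b = torch.from_numpy(ot.utils.unif(120))

# SUOT: a scalar loss; per-projection quantities available via log=True
suot, log = ot.sliced_unbalanced_ot(x, y, 1.0, a, b, 50, seed=42, p=2, log=True)

# USOT: one global reweighting of a and b, plus the loss
a_rw, b_rw, usot = ot.unbalanced_sliced_ot(x, y, 1.0, a, b, 50, seed=42, p=2)
```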
+
+
+def test_sliced_projections(nx):
+    n = 100
+    m = 120
+    rng = np.random.RandomState(42)
+
+    x = rng.randn(n, 4)
+    y = rng.randn(m, 4)
+
+    a = rng.uniform(0, 1, n) / 10  # unbalanced
+    u = ot.utils.unif(m)
+
+    reg_m = 1
+
+    if nx.__name__ in ["torch", "jax"]:
+        x, y, a, u = nx.from_numpy(x, y, a, u)
+
+        res, log = ot.sliced_unbalanced_ot(
+            x, y, reg_m, a, u, 10, seed=42, p=2, log=True
+        )
+
+        projections = log["projections"]
+
+        res2 = ot.sliced_unbalanced_ot(x, y, reg_m, a, u, 10, seed=42, p=2)
+        np.testing.assert_almost_equal(res, res2)
+
+        res3 = ot.sliced_unbalanced_ot(
+            x, y, reg_m, a, u, 10, projections=projections, p=2
+        )
+        np.testing.assert_almost_equal(res, res3)
+
+        _, _, res = ot.unbalanced_sliced_ot(x, y, reg_m, a, u, 10, seed=42, p=2)
+
+        _, _, res2 = ot.unbalanced_sliced_ot(
+            x, y, reg_m, a, u, 10, projections=projections, p=2
+        )
+        np.testing.assert_almost_equal(res, res2)
+
+
+def test_sliced_inf_reg_m(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 4)
+    xt = rng.randn(n_samples, 4)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    reg_m = float("inf")
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    if nx.__name__ in ["jax", "torch"]:
+        suot = ot.sliced_unbalanced_ot(xs, xt, reg_m, a, b, 10, seed=42, p=2)
+
+        a_reweighted, b_reweighted, usot = ot.unbalanced_sliced_ot(
+            xs, xt, reg_m, a, b, 10, seed=42, p=2
+        )
+
+        sw = ot.sliced_wasserstein_distance(xs, xt, n_projections=10, seed=42, p=2)
+
+        # Check right loss
+        np.testing.assert_almost_equal(suot, sw**2)
+        np.testing.assert_almost_equal(usot, sw**2)
+        np.testing.assert_allclose(a_reweighted, a)
+        np.testing.assert_allclose(b_reweighted, b)
+
+
+def test_semi_usot_1d(nx):
+    n_samples = 20  # nb samples
+
+    rng = np.random.RandomState(42)
+    xs = rng.randn(n_samples, 1)
+    xt = rng.randn(n_samples, 1)
+
+    a_np = ot.utils.unif(n_samples)
+    b_np = ot.utils.unif(n_samples)
+
+    a, b = nx.from_numpy(a_np, b_np)
+    xs, xt = nx.from_numpy(xs, xt)
+
+    reg_m = (float("inf"), 1.0)
+
+    if nx.__name__ in ["jax", "torch"]:
+        a_reweighted, b_reweighted, usot = ot.unbalanced_sliced_ot(
+            xs, xt, reg_m, a, b, 10, seed=42, p=2
+        )
+        # Check right marginals
+        np.testing.assert_allclose(a, a_reweighted)
+        np.testing.assert_allclose(b_reweighted.sum(), 1)
+
+    reg_m = (1.0, float("inf"))
+
+    if nx.__name__ in ["jax", "torch"]:
+        a_reweighted, b_reweighted, usot = ot.unbalanced_sliced_ot(
+            xs, xt, reg_m, a, b, 10, seed=42, p=2
+        )
+        # Check right marginals
+        np.testing.assert_allclose(b, b_reweighted)
+        np.testing.assert_allclose(a_reweighted.sum(), 1)
+
+
+@pytest.mark.parametrize(
+    "reg_m",
+    itertools.product(
+        [1, float("inf")],
+    ),
+)
+def test_sliced_unbalanced_relaxation_parameters(nx, reg_m):
+    n = 100
+    rng = np.random.RandomState(50)
+
+    x = rng.randn(n, 2)
+    a = ot.utils.unif(n)
+
+    # make dists unbalanced
+    b = rng.rand(n)
+
+    a, b, x = nx.from_numpy(a, b, x)
+
+    reg_m = reg_m[0]
+
+    # options for reg_m
+    full_list_reg_m = [reg_m, reg_m]
+    full_tuple_reg_m = (reg_m, reg_m)
+    tuple_reg_m, list_reg_m = (reg_m,), [reg_m]
+    nx_reg_m = reg_m * nx.ones(1)
+
+    list_options = [
+        nx_reg_m,
+        full_tuple_reg_m,
+        tuple_reg_m,
+        full_list_reg_m,
+        list_reg_m,
+    ]
+
+    if nx.__name__ in ["jax", "torch"]:
+        _, _, usot = ot.unbalanced_sliced_ot(x, x, reg_m, a, b, 10, seed=42, p=2)
+
+        suot = ot.sliced_unbalanced_ot(x, x, reg_m, a, b, 10, seed=42, p=2)
+
+        for opt in list_options:
+            _, _, usot_opt = ot.unbalanced_sliced_ot(x, x, opt, a, b, 10, seed=42, p=2)
+            np.testing.assert_allclose(
+                nx.to_numpy(usot), nx.to_numpy(usot_opt), atol=1e-05
+            )
+
+            suot_opt = ot.sliced_unbalanced_ot(x, x, opt, a, b, 10, seed=42, p=2)
+            np.testing.assert_allclose(
+                nx.to_numpy(suot), nx.to_numpy(suot_opt), atol=1e-05
+            )
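One detail from `test_sliced_projections` above worth emphasizing: both estimators accept precomputed projection directions through `projections=`, which is the reliable way to make two calls comparable without sharing a seed. Continuing the previous sketch:

```python
# Directions sampled by a seeded call are exposed in the log...
res, log = ot.sliced_unbalanced_ot(x, y, 1.0, a, b, 50, seed=42, p=2, log=True)
projections = log["projections"]

# ...and reusing them reproduces the loss, without passing a seed
res_again = ot.sliced_unbalanced_ot(x, y, 1.0, a, b, 50, projections=projections, p=2)
```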
+
+
+@pytest.mark.parametrize(
+    "reg_m1, reg_m2",
+    itertools.product(
+        [1, float("inf")],
+        [1, float("inf")],
+    ),
+)
+def test_sliced_unbalanced_relaxation_parameters_pair(nx, reg_m1, reg_m2):
+    n = 100
+    rng = np.random.RandomState(50)
+
+    x = rng.randn(n, 2)
+    a = ot.utils.unif(n)
+
+    # make dists unbalanced
+    b = rng.rand(n)
+
+    a, b, x = nx.from_numpy(a, b, x)
+
+    # options for reg_m
+    full_list_reg_m = [reg_m1, reg_m2]
+    full_tuple_reg_m = (reg_m1, reg_m2)
+    list_options = [full_tuple_reg_m, full_list_reg_m]
+
+    if nx.__name__ in ["jax", "torch"]:
+        _, _, usot = ot.unbalanced_sliced_ot(
+            x, x, (reg_m1, reg_m2), a, b, 10, seed=42, p=2
+        )
+
+        suot = ot.sliced_unbalanced_ot(x, x, (reg_m1, reg_m2), a, b, 10, seed=42, p=2)
+
+        for opt in list_options:
+            _, _, usot_opt = ot.unbalanced_sliced_ot(x, x, opt, a, b, 10, seed=42, p=2)
+            np.testing.assert_allclose(
+                nx.to_numpy(usot), nx.to_numpy(usot_opt), atol=1e-05
+            )
+
+            suot_opt = ot.sliced_unbalanced_ot(x, x, opt, a, b, 10, seed=42, p=2)
+            np.testing.assert_allclose(
+                nx.to_numpy(suot), nx.to_numpy(suot_opt), atol=1e-05
+            )
+
+
+def test_sliced_uot_type_devices(nx):
+    rng = np.random.RandomState(0)
+
+    n = 10
+    x = rng.randn(n, 2)
+    rho_u = np.abs(rng.randn(n))
+    rho_u /= rho_u.sum()
+    rho_v = np.abs(rng.randn(n))
+    rho_v /= rho_v.sum()
+
+    reg_m = 1.0
+
+    xb, rho_ub, rho_vb = nx.from_numpy(x, rho_u, rho_v)
+
+    if nx.__name__ in ["torch", "jax"]:
+        f, g, usot = ot.unbalanced_sliced_ot(
+            xb, xb, reg_m, rho_ub, rho_vb, 10, seed=42, p=2
+        )
+
+        nx.assert_same_dtype_device(xb, f)
+        nx.assert_same_dtype_device(xb, g)
+        nx.assert_same_dtype_device(xb, usot)
+    else:
+        np.testing.assert_raises(
+            AssertionError,
+            ot.unbalanced_sliced_ot,
+            xb,
+            xb,
+            reg_m,
+            rho_ub,
+            rho_vb,
+            10,
+            seed=42,
+            p=2,
+        )
+
+    if nx.__name__ in ["torch", "jax"]:
+        suot = ot.sliced_unbalanced_ot(xb, xb, reg_m, rho_ub, rho_vb, 10, seed=42, p=2)
+
+        nx.assert_same_dtype_device(xb, suot)
+    else:
+        np.testing.assert_raises(
+            AssertionError,
+            ot.sliced_unbalanced_ot,
+            xb,
+            xb,
+            reg_m,
+            rho_ub,
+            rho_vb,
+            10,
+            seed=42,
+            p=2,
+        )