ENOT: Expectile Regularization for Fast and Accurate Training of Neural Optimal Transport

Nazar Buzun, Maksim Bobrin, Dmitry V. Dylov

TL;DR

The proposed method, Expectile-Regularised Neural Optimal Transport (ENOT), outperforms previous state-of-the-art approaches on the established Wasserstein-2 benchmark tasks by a large margin (up to a 3-fold improvement in quality and up to a 10-fold improvement in runtime).

Abstract

We present a new training procedure for Neural Optimal Transport (NOT) that accurately and efficiently estimates the optimal transportation plan via a specific regularization on the dual Kantorovich potentials. The main bottleneck of existing NOT solvers is finding a near-exact approximation of the conjugate operator (i.e., the c-transform), which is done either by optimizing over non-convex max-min objectives or by computationally intensive fine-tuning of the initial approximated prediction. We resolve both issues by proposing a new, theoretically justified loss in the form of expectile regularization, which enforces binding conditions on the learning process of the dual potentials. This regularization provides an upper-bound estimate over the distribution of possible conjugate potentials and makes learning stable, completely eliminating the need for additional extensive fine-tuning. The proposed method, called Expectile-Regularised Neural Optimal Transport (ENOT), outperforms previous state-of-the-art approaches on the established Wasserstein-2 benchmark tasks by a large margin (up to a 3-fold improvement in quality and up to a 10-fold improvement in runtime). Moreover, we showcase the performance of ENOT for varying cost functions on different tasks, such as image generation, demonstrating the robustness of the proposed algorithm. The OTT-JAX library includes our implementation of the ENOT algorithm: https://ott-jax.readthedocs.io/en/latest/tutorials/ENOT.html
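
To make the conjugate-operator bottleneck concrete, the following is a minimal sketch of the c-transform on finite samples. It assumes the convention $f^c(y) = \min_x [c(x, y) - f(x)]$ (sign conventions differ between papers), a squared Euclidean cost in one dimension, and uniform empirical measures. The names `sq_cost`, `c_transform`, and `dual_objective` are illustrative and are not taken from the ENOT or OTT-JAX code.

```python
def sq_cost(x, y):
    """Squared Euclidean cost between two scalars (assumed for illustration)."""
    return (x - y) ** 2


def c_transform(f_vals, xs, ys, cost=sq_cost):
    """Discrete c-transform: f^c(y) = min over x of [c(x, y) - f(x)]."""
    return [min(cost(x, y) - f for x, f in zip(xs, f_vals)) for y in ys]


def dual_objective(f_vals, xs, ys, cost=sq_cost):
    """Kantorovich dual value for uniform empirical measures:
    mean of f over source samples plus mean of f^c over target samples."""
    g_vals = c_transform(f_vals, xs, ys, cost)
    return sum(f_vals) / len(xs) + sum(g_vals) / len(ys)


xs = [0.0, 1.0]      # source samples
ys = [0.5, 2.0]      # target samples
f_vals = [0.0, 0.0]  # a trivial potential f = 0 evaluated on xs

g_vals = c_transform(f_vals, xs, ys)
value = dual_objective(f_vals, xs, ys)
```

On finite samples the inner minimum is an exhaustive search over the source points. For continuous measures and a neural potential no such closed form is available, so the minimum has to be approximated, which is the step the abstract identifies as the bottleneck.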

Paper Structure (25 sections, 2 theorems, 23 equations, 10 figures, 10 tables, 1 algorithm)

Key Result

Lemma D.1

Let the random vector $\xi$ have compact support $\varOmega$, and let $(f_n)$ be a sequence of continuous functions such that $f_{n+1}(x) \geq f_{n}(x)$ for all $x \in \varOmega$. Then the functional convergence $f_n \to f$ implies that $f_n(\xi)$ converges to $f(\xi)$ with probability $1$.

Figures (10)

  • Figure 1: Fitting of three different transport maps $T_\theta$ between source and target measures in $\mathbb{R}^2$ with Euclidean cost function $c(x, y)= \| x - y \|$. We use the same number of iterations and MLP architecture for each method. Left: Sinkhorn divergence; Middle: Monge gap; Right: ENOT.
  • Figure 2: Recovered OT maps $T_\theta$ between synthetic measures on the 2-sphere with geodesic cost $c(x,y)=\arccos (x^T y)$. All models are MLPs with outputs normalized to lie on the unit sphere. Blue dots are the empirical source measure, red crosses are the empirical target measure, and orange crosses are the output of the learned transport map. Left: Sinkhorn; Middle: Monge; Right: ENOT.
  • Figure 3: Left: Handbags to Shoes; Middle: FFHQ to Comics; Right: CelebA(f) to Anime. All images are 128x128; the 1st row contains the source images and the 2nd row contains the generative mapping predicted by ENOT. Cost function: $L^2$ divided by the image size.
  • Figure 4: Contour plots of $\mathcal{L}_2^{\text{UV}}$ dependence on the values of $\lambda$ and $\tau$ in Algorithm 1 for the dimensions of $D = 256$ (Left, NaN values are greyed out), $D = 128$ (Middle), and $D = 64$ (Right).
  • Figure 5: Expectile regression. Left: the asymmetric squared loss $L_{\tau}$. The value $\tau = 0.5$ corresponds to the standard MSE loss, while $\tau = 0.9$ and $\tau = 0.99$ give more weight to the positive differences. Right: expectile models $f_{\tau}(x)$. The value $\tau = 0.5$ corresponds to the conditional statistical mean of the distribution, and when $\tau \to 1$ it approximates the maximum operator over the corresponding values of $y$.
  • ...and 5 more figures
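
The expectile regression described in the Figure 5 caption can be sketched as follows. This is a minimal pure-Python illustration, assuming the common form $L_\tau(u) = |\tau - \mathbb{1}[u < 0]| \, u^2$ of the asymmetric squared loss (which equals the MSE up to a constant factor at $\tau = 0.5$). The helper names `expectile_loss` and `expectile` are hypothetical, and the fixed-point iteration is a generic way to compute a sample expectile, not the training procedure of the paper.

```python
def expectile_loss(u, tau):
    """Asymmetric squared loss: weight tau for u >= 0, (1 - tau) for u < 0."""
    weight = tau if u >= 0 else 1.0 - tau
    return weight * u * u


def expectile(ys, tau, iters=200):
    """tau-expectile of a sample: the m minimizing sum(expectile_loss(y - m, tau)).

    Uses fixed-point iteration on the first-order condition, under which
    m is a weighted mean of the sample with weights tau or (1 - tau).
    """
    m = sum(ys) / len(ys)
    for _ in range(iters):
        weights = [tau if y >= m else 1.0 - tau for y in ys]
        m = sum(w * y for w, y in zip(weights, ys)) / sum(weights)
    return m


ys = [0.0, 1.0, 2.0, 10.0]
mean_like = expectile(ys, 0.5)    # coincides with the sample mean
near_max = expectile(ys, 0.999)   # moves toward max(ys) as tau -> 1
```

With $\tau = 0.5$ all weights are equal and the result is the sample mean; as $\tau$ approaches $1$, points above the current estimate dominate and the expectile approaches the sample maximum, matching the behaviour described in the caption.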

Theorems & Definitions (3)

  • Lemma D.1 (Rudin, Principles of Mathematical Analysis, 1976)
  • Theorem D.2
  • proof