A Statistical Learning Perspective on Semi-dual Adversarial Neural Optimal Transport Solvers

Roman Tarasov, Petr Mokrov, Milena Gazdieva, Evgeny Burnaev, Alexander Korotin

TL;DR

This work establishes upper bounds on the generalization error of an approximate OT map recovered by the minimax quadratic OT solver. The authors believe that similar bounds could be derived for the general OT case, which points to a promising direction for future research.

Abstract

Neural network-based optimal transport (OT) is a recent and fruitful direction in the generative modeling community. It finds applications in various fields such as domain translation, image super-resolution, and computational biology. Among the existing OT approaches, adversarial minimax solvers based on semi-dual formulations of OT problems are of considerable interest. While promising, these methods lack theoretical investigation from a statistical learning perspective. Our work fills this gap by establishing upper bounds on the generalization error of an approximate OT map recovered by the minimax quadratic OT solver. Importantly, the bounds we derive depend solely on standard statistical and mathematical properties of the considered functional classes (neural nets). While our analysis focuses on quadratic OT, we believe that similar bounds could be derived for the general OT case, which points to a promising direction for future research. Our experimental illustrations are available online at https://github.com/milenagazdieva/StatOT.
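
For orientation, the semi-dual minimax objective that solvers of this kind typically optimize for the quadratic cost can be sketched as follows. This is written in generic notation and is an assumption for illustration: the source and target distributions $p$ and $q$, the potential $f$, and the transport map $T$ are not defined in the text above, and the paper's own notation, sign conventions, and function classes may differ. Under suitable regularity assumptions,

$$\mathrm{Cost}(p, q) = \max_{f} \min_{T} \left\{ \int \left[ \tfrac{1}{2}\|x - T(x)\|^2 - f(T(x)) \right] dp(x) + \int f(y)\, dq(y) \right\}.$$

In practice, $f$ and $T$ are restricted to neural network classes and the two integrals are replaced by empirical averages over finite samples from $p$ and $q$. The generalization error studied here concerns the gap between the map recovered in that restricted, finite-sample setting and the true OT map.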

Paper Structure

This paper contains 29 sections, 24 theorems, 120 equations, 7 figures.

Key Result

Theorem 4.1

Let $\mathcal{F}$ be a class of $\beta$-strongly convex functions, then

Figures (7)

  • Figure 1: Monge's formulation of optimal transport.
  • Figure 2: Continuous setup of OT problem.
  • Figure 3: Convergence rates of the OT solver learned with the quadratic transport cost and a limited number of empirical training samples.
  • Figure 4: Empirical approximation error of the OT solver learned with the quadratic transport cost and using shallow NN architectures.
  • Figure 5: Estimation errors of the baseline solvers using a limited number of available empirical training samples $N,M$.
  • ...and 2 more figures

Theorems & Definitions (48)

  • Theorem 4.1: Error decomposition
  • Theorem 4.2: Rademacher Bound on the Estimation Error
  • Theorem 4.3: Inner approximation error
  • Proposition 4.4: Class $\mathcal{F}$ in practice
  • Remark 4.5
  • Theorem 4.6: Outer Approximation Error
  • Corollary 4.7
  • Remark 4.8
  • Theorem 4.9: Bound on the Generalization Error
  • Corollary 4.10: Generalization Error for the Specific Classes of Neural Networks
  • ...and 38 more