On amortizing convex conjugates for optimal transport

Brandon Amos

TL;DR

This work tackles the difficulty of computing the convex conjugate in Euclidean Wasserstein-2 OT by proposing amortized conjugation combined with a fine-tuning solver. By predicting approximate conjugate solutions with a neural amortization model and refining them with a numerical conjugate solver, the method achieves state-of-the-art transport maps on the Wasserstein-2 benchmark and successfully models diverse synthetic 2D settings. The key contributions are a systematic amortization framework for the conjugate, a taxonomy of amortization losses (objective-based, cycle-consistency, and regression-based), and practical insights on solver choices, including parallel Armijo line searches. The approach improves stability and performance, reduces computational bottlenecks, and suggests broader applicability to other $c$-transforms and OT-related problems, albeit with limitations such as non-convex optimization challenges and limited convergence guarantees.
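To make the three loss families concrete, here is a minimal NumPy sketch on a toy problem. It assumes a hypothetical quadratic potential $f(x) = \frac{1}{2}x^\top A x + b^\top x$ with symmetric positive definite $A$, chosen only because its conjugate maximizer $x^\star(y) = A^{-1}(y - b)$ is known in closed form; the names `objective_loss`, `cycle_loss`, and `regression_loss` are illustrative and not taken from the w2ot repository. The loss forms are one straightforward reading of the categories: the negated conjugate objective $-(\langle \hat x, y\rangle - f(\hat x))$, the squared residual of the optimality condition $\nabla f(\hat x) = y$, and the squared distance to a fine-tuned solution. In a real implementation the predictions would come from a trained model and the losses would be differentiated with respect to its parameters.

```python
import numpy as np

# Hypothetical potential for illustration: f(x) = 0.5 x^T A x + b^T x with A symmetric
# positive definite, so the exact conjugate maximizer is x*(y) = A^{-1}(y - b).
A = np.array([[2.0, 0.5], [0.5, 1.0]])
b = np.array([0.3, -0.2])


def f(X):
    return 0.5 * np.sum((X @ A) * X, axis=-1) + X @ b  # X has shape (n, d)


def grad_f(X):
    return X @ A + b  # row-wise A x + b, valid because A is symmetric


def objective_loss(X_hat, Y):
    # Objective-based: maximize J(x_hat; y) = <x_hat, y> - f(x_hat), i.e. minimize -J.
    return -np.mean(np.sum(X_hat * Y, axis=-1) - f(X_hat))


def cycle_loss(X_hat, Y):
    # Cycle-consistency: the maximizer satisfies grad f(x*(y)) = y; penalize the residual.
    return np.mean(np.sum((grad_f(X_hat) - Y) ** 2, axis=-1))


def regression_loss(X_hat, X_fine):
    # Regression-based: pull the prediction toward a fine-tuned solution, which is
    # treated as a fixed target (no gradient would flow through X_fine in training).
    return np.mean(np.sum((X_hat - X_fine) ** 2, axis=-1))


rng = np.random.default_rng(0)
Y = rng.normal(size=(5, 2))
X_exact = np.linalg.solve(A, (Y - b).T).T                  # closed-form maximizers
X_noisy = X_exact + 0.1 * rng.normal(size=X_exact.shape)  # stand-in for an imperfect model
print(objective_loss(X_exact, Y), objective_loss(X_noisy, Y))
print(cycle_loss(X_exact, Y), cycle_loss(X_noisy, Y))
print(regression_loss(X_noisy, X_exact))
```

Here the closed-form maximizer stands in for the output of a fine-tuning solver. With exact predictions the cycle and regression losses vanish, and perturbing the predictions raises the objective-based loss because $J(\cdot\,; y)$ is strictly concave for this $f$.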

Abstract

This paper focuses on computing the convex conjugate (also known as the Legendre-Fenchel conjugate or c-transform) that appears in Euclidean Wasserstein-2 optimal transport. This conjugation is considered difficult to compute, and in practice, methods are limited by not being able to exactly conjugate the dual potentials in continuous space. To overcome this, the computation of the conjugate can be approximated with amortized optimization, which learns a model to predict the conjugate. I show that combining amortized approximations to the conjugate with a solver for fine-tuning significantly improves the quality of transport maps learned for the Wasserstein-2 benchmark by Korotin et al. (2021a) and is able to model many 2-dimensional couplings and flows considered in the literature. All baselines, methods, and solvers are publicly available at http://github.com/facebookresearch/w2ot.
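The amortize-then-fine-tune pattern can be sketched in a few lines. This is a minimal NumPy illustration under stated assumptions, not the w2ot implementation: the potential is a hypothetical quadratic $f(x) = \frac{1}{2}x^\top A x + b^\top x$ so the answer can be checked against $A^{-1}(y - b)$, `amortized_guess` is a deliberately crude linear stand-in for a learned model, and the fine-tuning solver is plain gradient ascent on $J(x; y) = \langle x, y\rangle - f(x)$. The line search evaluates a fixed grid of step sizes in one batched call and keeps the largest one that satisfies the Armijo sufficient-increase condition; treating this as what the paper means by a parallel Armijo search is an assumption.

```python
import numpy as np

# Hypothetical potential for illustration: f(x) = 0.5 x^T A x + b^T x with A symmetric
# positive definite, so the conjugate maximizer A^{-1}(y - b) is known in closed form.
A = np.array([[2.0, 0.5], [0.5, 1.0]])
b = np.array([0.3, -0.2])


def f(X):
    # Works on a single point (d,) or a batch of points (k, d).
    return 0.5 * np.sum((X @ A) * X, axis=-1) + X @ b


def grad_f(x):
    return x @ A + b  # equals A x + b because A is symmetric


def conj_objective(X, y):
    # J(x; y) = <x, y> - f(x); the conjugate is f*(y) = max_x J(x; y).
    return X @ y - f(X)


def amortized_guess(y):
    # Hypothetical stand-in for a learned amortization model (deliberately inexact).
    return 0.4 * y


def parallel_armijo_step(x, y, steps=0.5 ** np.arange(10), c=1e-4):
    """One gradient-ascent step on J(.; y). All candidate step sizes are evaluated in a
    single batched call; the largest one meeting the Armijo condition is kept."""
    g = y - grad_f(x)                     # gradient of J(.; y) at x
    candidates = x + steps[:, None] * g   # shape (num_steps, dim)
    values = conj_objective(candidates, y)
    ok = values >= conj_objective(x, y) + c * steps * (g @ g)
    if not ok.any():
        return None                       # no step accepted (e.g. round-off near the optimum)
    return candidates[np.argmax(ok)]      # steps decrease, so the first True is the largest


def conjugate(y, num_iters=100, tol=1e-8):
    """Fine-tune the amortized prediction; returns (argmax estimate, f*(y) estimate)."""
    x = amortized_guess(y)
    for _ in range(num_iters):
        if np.linalg.norm(y - grad_f(x)) < tol:
            break
        x_next = parallel_armijo_step(x, y)
        if x_next is None:
            break
        x = x_next
    return x, conj_objective(x, y)


y = np.array([1.0, -0.5])
x_hat, f_star = conjugate(y)
print(x_hat, np.linalg.solve(A, y - b))  # fine-tuned vs. closed-form maximizer
```

The warm start ties the two pieces together: the solver only has to correct the model's prediction, and batching the candidate step sizes replaces a sequential backtracking loop with a single vectorized evaluation.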

Paper Structure (31 sections, 13 equations, 13 figures, 8 tables, 5 algorithms)

Figures (13)

  • Figure 1: Conjugate amortization losses.
  • Figure 2: Conjugate solver convergence on the HD benchmarks with an ICNN potential.
  • Figure 3: Learned transport maps on synthetic settings from Rout et al. (2021).
  • Figure 4: Learned potentials on settings considered in Makkuva et al. (2020).
  • Figure 5: Mesh grid $\mathcal{G}$ warped by the conjugate potential flow $\nabla f^\star$ from the top setting of Figure 4.
  • ...and 8 more figures

Theorems & Definitions (24)

  • Remark 1
  • Remark 2
  • Remark 3
  • Remark 4
  • Remark 5
  • Remark 6
  • Remark 7
  • Remark 8
  • Remark 9
  • Remark 10
  • ...and 14 more