A Truncated Newton Method for Optimal Transport

Mete Kemertas, Amir-massoud Farahmand, Allan D. Jepson

TL;DR

The paper addresses scalable, high-precision discrete OT with entropic regularization by developing a GPU-friendly truncated Newton method for the entropic OT (EOT) dual. It introduces a discounted Hessian formulation and a hybrid Newton-Sinkhorn projection that achieves superlinear local convergence without requiring a Lipschitz Hessian, aided by adaptive temperature annealing within the MDOT framework. Key contributions include a specialized linear conjugate gradient routine for the dual, a TruncatedNewtonProject routine that combines Newton solves with Sinkhorn projections, and an adaptive initialization strategy that eliminates hyperparameter tuning. Empirical results on large-scale, high-dimensional problems show orders-of-magnitude wall-clock speedups at high precision, and a memory-efficient variant handles problems with up to $n \approx 10^6$. The work offers a practical route to high-accuracy OT on modern GPUs and lays groundwork for future extensions toward global convergence guarantees and stochastic, memory-efficient variants.
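The Sinkhorn projections referred to above can be made concrete with a small example. The sketch below is a minimal NumPy illustration on the CPU, not the authors' GPU implementation. It assumes the common dual parametrization $P_{ij} = \exp(\gamma(u_i + v_j - C_{ij}))$ with inverse temperature $\gamma$, under which $\nabla_{\bm{u}} g = P\mathbf{1} - \bm{r}$ and $\nabla_{\bm{v}} g = P^\top\mathbf{1} - \bm{c}$, so $\lVert \nabla g \rVert_1$ is the total marginal violation. Each row (column) projection solves exactly for $\bm{u}$ ($\bm{v}$) so that the row (column) marginal matches. All function names are hypothetical.

```python
import numpy as np


def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))


def plan(u, v, C, gamma):
    """Plan P_ij = exp(gamma * (u_i + v_j - C_ij)) (assumed parametrization)."""
    return np.exp(gamma * (u[:, None] + v[None, :] - C))


def dual_grad(u, v, C, gamma, r, c):
    """Dual gradient = row and column marginal residuals of P."""
    P = plan(u, v, C, gamma)
    return P.sum(axis=1) - r, P.sum(axis=0) - c


def project_rows(v, C, gamma, r):
    """Choose u so that the row sums of P equal r exactly."""
    return (np.log(r) - logsumexp(gamma * (v[None, :] - C), axis=1)) / gamma


def project_cols(u, C, gamma, c):
    """Choose v so that the column sums of P equal c exactly."""
    return (np.log(c) - logsumexp(gamma * (u[:, None] - C), axis=0)) / gamma


def sinkhorn(C, r, c, gamma, tol=1e-9, max_iter=10_000):
    """Alternate row/column projections until ||grad g||_1 < tol."""
    u, v = np.zeros(len(r)), np.zeros(len(c))
    err = np.inf
    for it in range(1, max_iter + 1):
        u = project_rows(v, C, gamma, r)
        v = project_cols(u, C, gamma, c)
        g_u, g_v = dual_grad(u, v, C, gamma, r, c)
        err = np.abs(g_u).sum() + np.abs(g_v).sum()
        if err < tol:
            break
    return u, v, err, it
```

For example, `sinkhorn(C, r, c, gamma=5.0)` on a small random problem returns duals whose plan matches both marginals to within `tol`. The log-sum-exp form avoids building $e^{-\gamma C}$ directly: with $\gamma = 2^{12}$ (the final value used in Figure 4) and costs near 1, $e^{-\gamma C_{ij}}$ underflows to zero in double precision.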

Abstract

Developing a contemporary optimal transport (OT) solver requires navigating trade-offs among several critical requirements: GPU parallelization, scalability to high-dimensional problems, theoretical convergence guarantees, empirical performance in terms of precision versus runtime, and numerical stability in practice. With these challenges in mind, we introduce a specialized truncated Newton algorithm for entropic-regularized OT. In addition to proving that locally quadratic convergence is possible without assuming a Lipschitz Hessian, we provide strategies to maximally exploit the high rate of local convergence in practice. Our GPU-parallel algorithm exhibits exceptionally favorable runtime performance, achieving high precision orders of magnitude faster than many existing alternatives. This is evidenced by wall-clock time experiments on 24 problem sets (12 datasets $\times$ 2 cost functions). The scalability of the algorithm is showcased on an extremely large OT problem with $n \approx 10^6$, solved approximately under weak entropic regularization.
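For a sense of scale, the $1024 \times 1024$ color transfer instance in Figure 3 has $n = 2^{20}$. Assuming a dense $n \times n$ cost or plan matrix stored in 32-bit floats (the storage precision is our assumption), a back-of-the-envelope count gives:

```latex
n = 2^{20} = 1{,}048{,}576, \qquad
n^2 = 2^{40} \approx 1.10 \times 10^{12} \ \text{entries}, \qquad
4 \ \text{bytes} \times 2^{40} = 2^{42} \ \text{bytes} = 4 \ \text{TiB}.
```

That is far beyond the memory of a single GPU, which is why a memory-efficient variant is needed at this problem size.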

Paper Structure

This paper contains 27 sections, 17 theorems, 82 equations, 19 figures, 3 tables, 4 algorithms.

Key Result

Theorem 3.1

Assuming $\bm{c} = \bm{c}(P)$ and $\bm{d}_{\bm{v}} = -P_{\bm{c}} \bm{d}_{\bm{u}}$, define residuals $\bm{e}_{\bm{u}}(\rho) \coloneqq F_{\bm{r}}(\rho)\, \bm{d}_{\bm{u}} + \nabla_{\bm{u}} g$ (cf. Eq. (bellman2)) and $\bm{e} \coloneqq \nabla^2 g \, \bm{d} + \nabla g$ (i.e., the Ne
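The elimination $\bm{d}_{\bm{v}} = -P_{\bm{c}} \bm{d}_{\bm{u}}$ can be reproduced under an assumed parametrization. Taking $P_{ij} = \exp(\gamma(u_i + v_j - C_{ij}))$ and the dual $g(\bm{u}, \bm{v}) = \gamma^{-1}\sum_{ij} P_{ij} - \langle \bm{r}, \bm{u} \rangle - \langle \bm{c}, \bm{v} \rangle$, the Hessian is $\gamma \begin{bmatrix} \mathrm{diag}(P\mathbf{1}) & P \\ P^\top & \mathrm{diag}(P^\top\mathbf{1}) \end{bmatrix}$. When $\bm{c} = P^\top\mathbf{1}$, so that $\nabla_{\bm{v}} g = 0$, the second block row of the Newton system gives $\bm{d}_{\bm{v}} = -\mathrm{diag}(\bm{c})^{-1} P^\top \bm{d}_{\bm{u}}$, and the first reduces to $\gamma(\mathrm{diag}(P\mathbf{1}) - P\,\mathrm{diag}(\bm{c})^{-1} P^\top)\,\bm{d}_{\bm{u}} = -\nabla_{\bm{u}} g$. That matrix is singular ($\mathbf{1}$ lies in its null space), so the sketch below uses a hypothetical discounted version that multiplies the second term by $\rho < 1$; it is then strictly diagonally dominant with a positive diagonal, hence positive definite, and can be solved matrix-free by truncated conjugate gradient. This matrix is our stand-in for $F_{\bm{r}}(\rho)$; the paper's exact definition may differ, and all names are hypothetical.

```python
import numpy as np


def reduced_matvec(P, x, gamma, rho):
    """F(rho) @ x for F(rho) = gamma * (diag(P 1) - rho * P diag(1/c) P^T), c = P^T 1.
    Hypothetical discounted coefficient matrix (an assumption, not the paper's definition)."""
    r_P = P.sum(axis=1)
    c_P = P.sum(axis=0)
    return gamma * (r_P * x - rho * (P @ ((P.T @ x) / c_P)))


def truncated_cg(matvec, b, eta, max_iter=None):
    """Conjugate gradient on A x = b, stopped once ||b - A x|| <= eta * ||b||."""
    x = np.zeros_like(b)
    res = b.copy()            # residual b - A x at x = 0
    p = res.copy()
    rs = res @ res
    tol = eta * np.linalg.norm(b)
    max_iter = max_iter or 10 * len(b)
    for _ in range(max_iter):
        if np.sqrt(rs) <= tol:
            break
        Ap = matvec(p)
        alpha = rs / (p @ Ap)
        x += alpha * p
        res -= alpha * Ap
        rs_new = res @ res
        p = res + (rs_new / rs) * p
        rs = rs_new
    return x


def newton_direction(P, grad_u, gamma, rho=0.99, eta=1e-6):
    """Approximately solve F(rho) d_u = -grad_u, then recover d_v = -diag(1/c) P^T d_u."""
    d_u = truncated_cg(lambda x: reduced_matvec(P, x, gamma, rho), -grad_u, eta)
    d_v = -(P.T @ d_u) / P.sum(axis=0)
    return d_u, d_v
```

The tolerance `eta` acts as a forcing term in the inexact-Newton sense: looser values stop CG earlier at the cost of a less accurate direction, while only matrix-vector products with $P$ and $P^\top$ are ever formed.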

Figures (19)

  • Figure 1: Ratio $\delta_k$ of actual to theoretically predicted reduction in $\left\lVert\nabla g_k\right\rVert_1$ per step for fixed (left) and adaptive (right) temperature decay (initialized with $q^{(1)} = q$). Each $\delta_k$ is the median at iteration $t$ of the MDOT algorithm. Shaded areas show 80% confidence intervals around the median over 100 random problems from the upsampled MNIST dataset ($n=4096$) with normalized $L_1$ distance cost ($\max_{i,j} |C_{i,j}| = 1$).
  • Figure 2: Comparison of median and 90th percentile performance for varying smoothing weight $w_{\bm{r}}$.
  • Figure 2: Error vs. wall-clock time for various algorithms. Each marker shows the optimality gap and time taken (median across 18 problems) until termination at a given hyperparameter setting, followed by rounding of the output onto $\mathcal{U}(\bm{r}, \bm{c})$ via Alg. 2 of Altschuler et al. (2017). Upsampled MNIST (top) and color transfer (bottom) problem sets ($n=4096$) using $L_1$ (left) and $L_2^2$ (right) distance costs. MDOT-TruncatedNewton outperforms the other methods by orders of magnitude at high precision and exhibits a much better practical dependence on the error than the best known theoretical rate of $\widetilde{O}(n^2 \varepsilon^{-1})$.
  • Figure 3: The MDOT-TruncatedNewton algorithm applied to a large-scale color transfer problem on $1024 \times 1024$ images ($n=2^{20}$). For this visualization, the cost matrix is given by the $L_2^2$ distance in RGB color space, normalized so that $\max_{ij} |C_{ij}| = 1$. Final temperature is $1/\gamma_{\mathrm{f}} = 2^{-10}$. Source images (top row) were generated with DALL-E 2. This figure is best viewed digitally.
  • Figure 4: Log-log plot of wall-clock time for MDOT-TruncatedNewton vs. problem size $n$. Each marker shows the median over 60 random problems from the MNIST (top) and color transfer (bottom) problem sets with normalized $L_1$ (left) and $L_2^2$ (right) distance costs. Error bars show the 10th and 90th percentiles. For all problems, $\gamma_{\mathrm{f}} = 2^{12}$. Dashed lines show a polynomial $f(n) = an^2$, where $a$ is chosen so that $an^2$ equals the median time taken at the largest $n$ considered. Over the range of $n$ shown, the measured runtime grows no faster than $O(n^2)$.
  • ...and 14 more figures
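The rounding step in the Figure 2 caption is Alg. 2 of Altschuler et al. (2017): scale rows down so that no row sum exceeds $\bm{r}$, scale columns down so that no column sum exceeds $\bm{c}$, then add a nonnegative rank-one correction that restores both marginals exactly. Below is a minimal NumPy sketch; the function name is our own, and it assumes an input plan with strictly positive row and column sums (true for entropic plans) and $\sum_i r_i = \sum_j c_j$.

```python
import numpy as np


def round_to_feasible(F, r, c):
    """Round a positive plan F onto U(r, c) (Alg. 2 of Altschuler et al., 2017)."""
    # Scale rows down so that no row sum exceeds r.
    x = np.minimum(r / F.sum(axis=1), 1.0)
    F1 = x[:, None] * F
    # Scale columns down so that no column sum exceeds c.
    y = np.minimum(c / F1.sum(axis=0), 1.0)
    F2 = F1 * y[None, :]
    # Remaining (nonnegative) marginal deficits.
    err_r = r - F2.sum(axis=1)
    err_c = c - F2.sum(axis=0)
    mass = err_c.sum()  # equals ||err_c||_1 because err_c >= 0
    if mass <= 0:
        return F2
    # Rank-one correction restores both marginals exactly.
    return F2 + np.outer(err_r, err_c) / mass
```

Because both deficits are nonnegative, the correction keeps the plan nonnegative, so each optimality gap in Figure 2 is evaluated on an exactly feasible plan.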

Theorems & Definitions (28)

  • Theorem 3.1: Forcing sequence under discounting
  • Theorem 3.2: Convergence of Algorithm [alg:newton-solve]
  • Lemma 3.2: Convergence of Algorithm [alg:partial-sinkhorn]
  • Corollary 3.2: Per-step Cost of Algorithm [alg:tns-project]
  • Theorem 3.3: Per-step Improvement of Algorithm [alg:tns-project]
  • Lemma A.0: Properties of $\Prc$
  • proof
  • Lemma A.0: Properties of the coefficient matrix $\Frr$
  • proof
  • Lemma A.0: Convergence to the stationary distribution under $\Prc$
  • ...and 18 more
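The appendix lemmas above refer to a stationary distribution under a matrix whose definition is not reproduced in this section. As a hedged illustration only, assume a Markov chain built from the same elimination as Theorem 3.1, $M = \mathrm{diag}(P\mathbf{1})^{-1} P\, \mathrm{diag}(P^\top\mathbf{1})^{-1} P^\top$; this is our assumption and not necessarily the paper's matrix. $M$ is row-stochastic, and $\bm{r}(P) = P\mathbf{1}$ satisfies $\bm{r}(P)^\top M = \bm{r}(P)^\top$, so for a strictly positive $P$, power iteration converges to $\bm{r}(P)$ normalized. The sketch checks this numerically; function names are hypothetical.

```python
import numpy as np


def two_step_chain(P):
    """M = diag(P 1)^-1 P diag(P^T 1)^-1 P^T (hypothetical stand-in, row-stochastic)."""
    r_P = P.sum(axis=1)
    c_P = P.sum(axis=0)
    return (P / r_P[:, None]) @ (P.T / c_P[:, None])


def stationary_by_power_iteration(M, iters=500):
    """Left power iteration pi <- pi M, started from the uniform distribution."""
    pi = np.full(M.shape[0], 1.0 / M.shape[0])
    for _ in range(iters):
        pi = pi @ M
    return pi
```

The chain is also reversible with respect to $\bm{r}(P)$, since $\mathrm{diag}(\bm{r}(P))\,M = P\,\mathrm{diag}(\bm{c})^{-1} P^\top$ is symmetric.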