On Fitting Flow Models with Large Sinkhorn Couplings

Stephen Zhang, Alireza Mousavi-Hosseini, Michal Klein, Marco Cuturi

TL;DR

This work advances flow-based generative modeling by using large-scale entropic optimal transport (OT) couplings to pair source and target samples when training velocity fields. It unifies independent flow matching (I-FM) and Batch-OT FM through a continuous OT framework controlled by the entropic regularization parameter $\varepsilon$, and introduces a scale-free renormalized coupling entropy $\mathcal{E}$ to benchmark coupling sharpness. Practical innovations (dot-product costs, warm-starts, PCA acceleration, and multi-GPU sharding) enable fitting and sampling with couplings on millions of points, revealing that large-batch OT (with modest $\mathcal{E}$ around 0.1) consistently improves curvature, reconstruction, and FID in synthetic and real-data tasks. The results provide actionable guidance for applying OT-based flow training at scale, suggesting that precomputing large Sinkhorn couplings can yield faster, straighter flows and better generation with manageable computational overhead.

Abstract

Flow models transform data gradually from one modality (e.g. noise) onto another (e.g. images). Such models are parameterized by a time-dependent velocity field, trained to fit segments connecting pairs of source and target points. When the pairing between source and target points is given, training flow models boils down to a supervised regression problem. When no such pairing exists, as is the case when generating data from noise, training flows is much harder. A popular approach is to pick source and target points independently. This can, however, lead to velocity fields that are not only slow to train but also costly to integrate at inference time. In theory, one would greatly benefit from training flow models by sampling pairs from an optimal transport (OT) measure coupling source and target, since this would lead to a highly efficient flow solving the Benamou and Brenier dynamical OT problem. In practice, recent works have proposed to sample mini-batches of $n$ source and $n$ target points and reorder them using an OT solver to form better pairs. These works have advocated using batches of size $n\approx 256$, and considered OT solvers that return couplings that are either sharp (using e.g. the Hungarian algorithm) or blurred (using e.g. entropic regularization, a.k.a. Sinkhorn). We follow in the footsteps of these works by exploring the benefits of increasing $n$ by three to four orders of magnitude, and look more carefully at the effect of the entropic regularization $\varepsilon$ used in the Sinkhorn algorithm. Our analysis is facilitated by new scale-invariant quantities to report the sharpness of a coupling, while our sharded computations across multiple GPUs or GPU nodes allow scaling up $n$. We show that in both synthetic and image generation tasks, flow models greatly benefit when fitted with large Sinkhorn couplings, with a low entropic regularization $\varepsilon$.
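To make the batch-pairing step concrete, here is a minimal NumPy sketch of how a Sinkhorn coupling could be computed on one batch and then used to draw source/target pairs for the flow matching regression. It is an illustration under stated assumptions, not the authors' implementation: the function names (`sinkhorn_coupling`, `sample_pairs`, `flow_matching_batch`, `shannon_entropy`) are hypothetical, the cost is half the squared Euclidean distance, `eps` is expressed in raw cost units, and the plain Shannon entropy computed here is only a stand-in for coupling sharpness, not the renormalized entropy $\mathcal{E}$ of the paper. It also runs on a single device with a small dense cost matrix, without the sharding, warm-starts, or PCA acceleration mentioned above.

```python
import numpy as np

def logsumexp(a, axis):
    # Numerically stable log(sum(exp(a))) along one axis.
    m = np.max(a, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(a - m), axis=axis))

def sinkhorn_coupling(x, y, eps, n_iters=200):
    """Log-domain Sinkhorn between uniform point clouds x (n, d) and y (m, d).

    Assumption: cost is 0.5 * squared Euclidean distance, eps in raw cost units.
    Returns the (n, m) coupling matrix P with P_ij = exp((f_i + g_j - C_ij) / eps).
    """
    n, m = x.shape[0], y.shape[0]
    cost = 0.5 * np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    log_a = np.full(n, -np.log(n))  # uniform source weights
    log_b = np.full(m, -np.log(m))  # uniform target weights
    f, g = np.zeros(n), np.zeros(m)
    for _ in range(n_iters):
        # Alternate dual updates that enforce row, then column, marginals.
        f = eps * (log_a - logsumexp((g[None, :] - cost) / eps, axis=1))
        g = eps * (log_b - logsumexp((f[:, None] - cost) / eps, axis=0))
    return np.exp((f[:, None] + g[None, :] - cost) / eps)

def shannon_entropy(coupling):
    # Plain Shannon entropy; zero entries contribute nothing.
    p = coupling[coupling > 0]
    return -np.sum(p * np.log(p))

def sample_pairs(coupling, k, rng):
    # Draw k index pairs (i, j) with probability proportional to P_ij.
    p = coupling.ravel() / coupling.sum()
    flat = rng.choice(coupling.size, size=k, p=p)
    return np.unravel_index(flat, coupling.shape)

def flow_matching_batch(x0, x1, rng):
    # Straight-line interpolant x_t and its constant velocity target x1 - x0.
    t = rng.uniform(size=(x0.shape[0], 1))
    x_t = (1.0 - t) * x0 + t * x1
    return t, x_t, x1 - x0

rng = np.random.default_rng(0)
x = rng.normal(size=(64, 2))        # source batch (e.g. noise)
y = rng.normal(size=(64, 2)) + 3.0  # target batch (stand-in for data)

P_sharp = sinkhorn_coupling(x, y, eps=0.05)   # low eps: close to a matching
P_blur = sinkhorn_coupling(x, y, eps=100.0)   # high eps: close to independent pairing

i, j = sample_pairs(P_sharp, 32, rng)
t, x_t, v_target = flow_matching_batch(x[i], y[j], rng)
```

In this toy setting a very large `eps` yields a nearly uniform coupling, so sampling from it behaves like independent pairing, while a small `eps` concentrates mass on few pairs and approaches a hard assignment. A velocity network would then be regressed onto `v_target` at inputs `(t, x_t)`.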

Paper Structure

This paper contains 22 sections, 5 theorems, 21 equations, 16 figures, 3 tables, 2 algorithms.

Key Result

Theorem 1

If $\mu \in {\mathcal{P}}_{2}({\mathbb{R}}^d)$ has an absolutely continuous density, then the Monge problem (eq:monge) is solved by a map $T^\star$ of the form $T^\star=\nabla u$, where $u:{\mathbb{R}}^d\rightarrow {\mathbb{R}}$ is convex. Moreover, if $u$ is a convex potential such that $(\nabla u)_\#\mu=\nu$ then

Figures (16)

  • Figure 1: Samples generated from models trained on ImageNet-64. $n$ denotes the total OT batch size. We use $\varepsilon=0.1$ and the Euler solver (Dopri5 for adaptive with NFE $\approx 270$). More samples are provided in the figure labeled fig:imagenet64_grid.
  • Figure 2: Results on the piecewise affine OT Map benchmark. The three top rows present (in that order) curvature, reconstruction and BPD metrics. Below, we provide compute times associated with running the Sinkhorn64 algorithm as a per-example cost. This per-example cost is the total time needed to run Sinkhorn64 to get the $n\times n$ coupling, divided by $n$. That cost would be 0 when using I-FM. We observe improvements in all metrics across all dimensions.
  • Figure 3: Results on the Korotin benchmark. As with Figure 2, we compute curvature and reconstruction metrics, and compute times below. Some of the runs for the largest OT batch sizes $n$ are provided in the supplementary. These runs suggest that, to train OT models in these dimensions, increasing $n$ is beneficial across the board.
  • Figure 4: Experiment metrics for CIFAR-10 image generation. We evaluate the trained models using the Euler solver with three different numbers of steps, and with the Dopri5 solver and adaptive steps. The plots demonstrate the benefits of a larger OT batch size to achieve significantly smaller curvature, and moderately smaller FID at a low number of integration steps. CIFAR-10 is not necessarily the best setup to evaluate the performance of OT-based FM, since the number of points is relatively low (the batch sizes we consider in fact involve resampling data). Our experiments also suggest that in this setting, lower renormalized entropy generally benefits performance.
  • Figure 5: ImageNet-32 experiment metrics. We observe that both FID and curvature are smaller when using a larger OT batch size, and that smaller renormalized entropy tends to result in better metrics.
  • ...and 11 more figures

Theorems & Definitions (8)

  • Theorem 1: Bre91
  • Proposition 2
  • Lemma 3
  • proof
  • Proposition 6
  • Lemma 7
  • proof
  • proof: Proof of prop:var_lower_bound