Multimarginal generative modeling with stochastic interpolants

Michael S. Albergo, Nicholas M. Boffi, Michael Lindsey, Eric Vanden-Eijnden

TL;DR

The paper addresses learning joint distributions over $K+1$ marginals that reveal multi-way correspondences, by extending stochastic interpolants to a barycentric multimarginal framework on the simplex $\Delta^{K}$. It introduces the barycentric interpolant $x(\alpha)=\sum_{k=0}^K\alpha_k x_k$, whose density $\rho(\alpha,x)$ satisfies $K+1$ continuity equations driven by the conditional-mean fields $g_k(\alpha,x)=\mathbb{E}[x_k\mid x(\alpha)=x]$; these fields, and hence the velocity fields, are learned via simple quadratic objectives. Key contributions are decoupling the path $\alpha(t)$ from velocity learning and showing, within a unified multimarginal framework, how to minimize transport cost and extract multi-way correspondences, including all-to-all image translation and emergent style transfer. The framework also yields deterministic couplings and practical transport on the simplex; numerical demonstrations show reduced transport cost and rich cross-marginal interactions, suggesting applications in data decorruption and algorithmic fairness. Overall, the work broadens generative modeling by enabling joint generation across multiple marginals with controllable, cost-efficient transport and interpretable multi-way correspondences.
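The barycentric interpolant and the quadratic objective can be illustrated in a toy setting. The sketch below assumes three hypothetical, independent 1D Gaussian marginals $x_k\sim N(m_k,s_k^2)$ and draws $\alpha$ uniformly on the simplex with a Dirichlet$(1,\ldots,1)$ sample; neither choice comes from the paper. The conditional expectation $g_k$ minimizes $\mathbb{E}|\hat g_k(\alpha,x(\alpha))-x_k|^2$, and for Gaussians it is affine in $x$, namely $g_k(\alpha,x)=m_k+\alpha_k s_k^2\,(x-\mu(\alpha))/V(\alpha)$ with $\mu(\alpha)=\sum_j\alpha_j m_j$ and $V(\alpha)=\sum_j\alpha_j^2 s_j^2$. A least-squares fit over affine functions should therefore recover it. With real data, a neural network conditioned on $\alpha$ would replace the affine model.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy marginals rho_0, rho_1, rho_2: independent 1D Gaussians N(m_k, s_k^2).
means = np.array([0.0, 2.0, -1.0])
stds = np.array([1.0, 0.5, 0.8])
n_marg = len(means)  # K + 1


def sample_interpolant(alpha, n):
    """Draw x_k ~ N(m_k, s_k^2) independently and form x(alpha) = sum_k alpha_k x_k."""
    xs = means[:, None] + stds[:, None] * rng.standard_normal((n_marg, n))
    return xs, alpha @ xs


def fit_conditional_means(alpha, n=200_000):
    """Minimize the empirical E|ghat_k(x(alpha)) - x_k|^2 over affine ghat_k = a_k + c_k x.

    Returns shape (2, K+1): row 0 holds intercepts a_k, row 1 holds slopes c_k.
    """
    xs, x = sample_interpolant(alpha, n)
    design = np.stack([np.ones_like(x), x], axis=1)  # columns [1, x(alpha)]
    coef, *_ = np.linalg.lstsq(design, xs.T, rcond=None)
    return coef


def closed_form(alpha):
    """Exact intercepts and slopes of g_k(alpha, x) for this Gaussian toy model."""
    mu = alpha @ means
    var = (alpha ** 2) @ (stds ** 2)
    slope = alpha * stds ** 2 / var
    return np.stack([means - slope * mu, slope])


alpha = rng.dirichlet(np.ones(n_marg))  # uniformly random point on the simplex
print(alpha)
print(fit_conditional_means(alpha))     # should be close to the closed form below
print(closed_form(alpha))
```

As a consistency check, $\sum_k\alpha_k g_k(\alpha,x)=x$, so the slopes weighted by $\alpha$ sum to one and the weighted intercepts sum to zero.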

Abstract

Given a set of $K$ probability densities, we consider the multimarginal generative modeling problem of learning a joint distribution that recovers these densities as marginals. The structure of this joint distribution should identify multi-way correspondences among the prescribed marginals. We formalize an approach to this task within a generalization of the stochastic interpolant framework, leading to efficient learning algorithms built upon dynamical transport of measure. Our generative models are defined by velocity and score fields that can be characterized as the minimizers of simple quadratic objectives, and they are defined on a simplex that generalizes the time variable in the usual dynamical transport framework. The resulting transport on the simplex is influenced by all marginals, and we show that multi-way correspondences can be extracted. The identification of such correspondences has applications to style transfer, algorithmic fairness, and data decorruption. In addition, the multimarginal perspective enables an efficient algorithm for reducing the dynamical transport cost in the ordinary two-marginal setting. We demonstrate these capacities with several numerical examples.
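The dynamical transport cost in the two-marginal setting can be made concrete in a toy case. This sketch assumes the cost of a path $\alpha(t)$ with $\alpha(0)=e_0$, $\alpha(1)=e_1$ is the kinetic energy $\int_0^1\mathbb{E}|b(t,x(\alpha(t)))|^2\,dt$ of the velocity $b(t,x)=\sum_k\dot\alpha_k(t)\,g_k(\alpha(t),x)$, and it takes $x_0\sim N(0,1)$ and $x_1\sim N(m,s^2)$ independent in 1D. Then $b$ is affine and the integrand is $\mathbb{E}[\dot x]^2+\mathrm{Cov}(\dot x,x)^2/\mathrm{Var}(x)$. The two paths compared (linear and trigonometric) are hand-picked illustrations, not the learned $\hat\alpha$ of Figure 1, and the sketch does not impose $\alpha_0+\alpha_1=1$.

```python
import numpy as np


def transport_cost(path, dpath, m=0.0, s=1.0, n_t=2001):
    """Kinetic cost int_0^1 E|b(t, x)|^2 dt for x(t) = a0(t) x0 + a1(t) x1 in 1D,
    with x0 ~ N(0, 1) and x1 ~ N(m, s^2) independent (toy assumption).

    Everything is jointly Gaussian, so b(t, x) = E[xdot | x(t) = x] is affine in x and
    E|b|^2 = E[xdot]^2 + Cov(xdot, x)^2 / Var(x).
    """
    t = np.linspace(0.0, 1.0, n_t)
    a0, a1 = path(t)
    da0, da1 = dpath(t)
    mean_xdot = da1 * m                  # E[xdot] = a0' E[x0] + a1' E[x1]
    cov = a0 * da0 + a1 * da1 * s ** 2   # Cov(xdot, x)
    var = a0 ** 2 + a1 ** 2 * s ** 2     # Var(x)
    integrand = mean_xdot ** 2 + cov ** 2 / var
    # Trapezoidal rule on the uniform time grid.
    return float(np.sum(0.5 * (integrand[1:] + integrand[:-1]) * np.diff(t)))


linear = (lambda t: (1.0 - t, t),
          lambda t: (-np.ones_like(t), np.ones_like(t)))
trig = (lambda t: (np.cos(np.pi * t / 2), np.sin(np.pi * t / 2)),
        lambda t: (-np.pi / 2 * np.sin(np.pi * t / 2), np.pi / 2 * np.cos(np.pi * t / 2)))

print(transport_cost(*linear))  # analytic value 2 - pi/2 (about 0.429) when m = 0, s = 1
print(transport_cost(*trig))    # 0 up to round-off when m = 0, s = 1
```

With $m=0$, $s=1$ the two endpoints coincide, so zero cost is attainable; the trigonometric path keeps $x(t)\sim N(0,1)$ with zero velocity, while the linear path does not. Lowering such a cost by changing $\alpha(t)$ alone, without relearning $g_k$, is the kind of optimization Figure 1 depicts.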

Paper Structure (17 sections, 7 theorems, 48 equations, 7 figures, 2 tables)

Key Result

Theorem 1

For all $\alpha \in \Delta^{K}$, the probability distribution of the barycentric stochastic interpolant $x(\alpha)$ has a density $\rho(\alpha,x)$ which satisfies the $K+1$ equations

$$\partial_{\alpha_k}\rho(\alpha,x) + \nabla_x\cdot\big(g_k(\alpha,x)\,\rho(\alpha,x)\big) = 0, \qquad k = 0,\ldots,K.$$

Above, each $g_k(\alpha,x)$ is defined as the conditional expectation

$$g_k(\alpha,x) = \mathbb{E}\big[x_k \mid x(\alpha) = x\big],$$

where $\mathbb E [ x_k | x(\alpha) = x]$ denotes an expectation over $\rho_0(x_0)\rho(x_1,\ldots,x_K)$ conditioned on the event $x(\alpha) = x$.
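The continuity equations $\partial_{\alpha_k}\rho+\nabla_x\cdot(g_k\rho)=0$ follow from differentiating $\mathbb{E}[\varphi(x(\alpha))]$ with respect to $\alpha_k$, since $\partial_{\alpha_k}x(\alpha)=x_k$, and they can be checked numerically in a toy case where every quantity is explicit. The sketch below assumes hypothetical independent 1D Gaussian marginals $x_k\sim N(m_k,s_k^2)$, so $\rho(\alpha,\cdot)$ is Gaussian with mean $\sum_k\alpha_k m_k$ and variance $\sum_k\alpha_k^2 s_k^2$. It reads $\partial_{\alpha_k}$ as the partial derivative of this expression with $\alpha$ perturbed off the simplex, which is an assumption of the sketch. It also checks the identity $\sum_k\alpha_k g_k(\alpha,x)=x$, which holds because $\mathbb{E}[x(\alpha)\mid x(\alpha)=x]=x$.

```python
import numpy as np

# Hypothetical toy marginals: independent 1D Gaussians N(m_k, s_k^2), k = 0, 1, 2.
means = np.array([0.0, 2.0, -1.0])
stds = np.array([1.0, 0.5, 0.8])


def rho(alpha, x):
    """Density of x(alpha) = sum_k alpha_k x_k: Gaussian, mean alpha.m, variance sum alpha_k^2 s_k^2."""
    mu = alpha @ means
    var = (alpha ** 2) @ (stds ** 2)
    return np.exp(-((x - mu) ** 2) / (2 * var)) / np.sqrt(2 * np.pi * var)


def g(alpha, x):
    """Conditional means g_k(alpha, x) = E[x_k | x(alpha) = x]; shape (K+1, len(x))."""
    mu = alpha @ means
    var = (alpha ** 2) @ (stds ** 2)
    return means[:, None] + (alpha * stds ** 2)[:, None] * ((x - mu) / var)[None, :]


def continuity_residual(alpha, x, eps=1e-5):
    """max_k max_x |d rho/d alpha_k + d/dx (g_k rho)|, via central finite differences."""
    dx = x[1] - x[0]
    flux = g(alpha, x) * rho(alpha, x)[None, :]
    dflux_dx = np.gradient(flux, dx, axis=1)
    worst = 0.0
    for k in range(len(alpha)):
        e_k = np.zeros_like(alpha)
        e_k[k] = eps
        drho = (rho(alpha + e_k, x) - rho(alpha - e_k, x)) / (2 * eps)
        worst = max(worst, float(np.max(np.abs(drho + dflux_dx[k]))))
    return worst


alpha = np.array([0.5, 0.3, 0.2])
x = np.linspace(-8.0, 8.0, 16001)
print(continuity_residual(alpha, x))            # small: finite-difference error only
print(np.max(np.abs(alpha @ g(alpha, x) - x)))  # round-off level
```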

Figures (7)

  • Figure 1: Direct optimization of $\alpha(t)$ over a parametric class to reduce transport cost in the 2-marginal learning problem from a Gaussian to the checkerboard density. Left: The initial and final $\hat{\alpha}_0, \hat{\alpha}_1$ learned over 300 optimization steps on the transport-cost objective. Center: The reduction in the path length over training. Right: Time slices of the probability density $\bar{\rho}(t,x)$ corresponding to the learned interpolant with learned $\hat{\alpha}$, compared to the linear interpolant $\alpha = [1-t, t]$.
  • Figure 2: Left: Generated MNIST digits from the same Gaussian sample $x_0 \sim \rho_0$, with $K=7$ marginals ($\rho_0$ and 6 digit classes). $x_0$ is visualized in the center of the image collection at time $t=0$, and the perimeter corresponds to transport to the edge of the simplex at time $t=1$, with vertices color-coded. A Petrie polygon representing the 6-simplex, with arrows denoting transport from the Gaussian along edges to the color-coded marginals, clarifies the marginal endpoints. Right: The impact of learning over the larger simplex. Top row: learning only on the simplex edge from $0$ to $3$. Middle row: learning on all the simplex edges among $0$ through $5$. Bottom row: learning on the entire simplex spanned by $0$ through $5$, not just its edges.
  • Figure 3: Left: An illustration of how different transport paths on the simplex can reach final samples with similar content. Top row: A cat sample is transformed into a celebrity. Middle row: The same cat is pushed to the flower marginal. Bottom row: The resulting flower sample is then pushed to a celebrity that retains meaningful semantic structure from the celebrity generated along the other path on the simplex.
  • Figure 4: Left: Marginally sampling a $K=4$ multimarginal model composed of the AFHQ, flowers, and CelebA datasets at resolution $64\times64$. Shown to the right of the images is the corresponding path taken on the simplex, with $\alpha(0) = e_0$ starting at the Gaussian $\rho_0$ and ending at one of $\rho_1$, $\rho_2$, or $\rho_3$. Right: Demonstration of style transfer that emerges naturally when learning a multimarginal interpolant. With a single shared interpolant, we flow from the AFHQ vertex $\rho_2$ to the flowers vertex $\rho_1$ or to the CelebA vertex $\rho_3$. The learned flow connects images with stylistic similarities.
  • Figure 5: The output of the probability flow ODE for the learned interpolant on the checkerboard problem discussed in the transport-cost section and its accompanying appendix experiment. For a fixed number of function evaluations (5 steps of the midpoint integrator), the learned $\alpha$ yields a more accurate solution.
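Figure 5 integrates the probability flow with 5 midpoint steps. Along a path $\alpha(t)$, the chain rule applied to the continuity equations gives $\partial_t\rho(\alpha(t),x)+\nabla_x\cdot\big(b(t,x)\,\rho(\alpha(t),x)\big)=0$ with $b(t,x)=\sum_k\dot\alpha_k(t)\,g_k(\alpha(t),x)$, so integrating $\dot x=b(t,x)$ from a sample $x_0\sim\rho_0$ at $t=0$ transports it to the marginal at the path's end vertex. The sketch below replaces learned networks with the closed-form $g_k$ of a hypothetical Gaussian toy model ($\rho_0=N(0,1)$, $\rho_1=N(2,0.5^2)$, independent) and uses the linear path $\alpha(t)=[1-t,t]$; it is not the checkerboard experiment.

```python
import numpy as np

# Hypothetical toy model: rho_0 = N(0, 1) and rho_1 = N(2, 0.5^2), independent.
means = np.array([0.0, 2.0])
stds = np.array([1.0, 0.5])


def g(alpha, x):
    """Closed-form E[x_k | x(alpha) = x] for the Gaussian toy; shape (K+1, N)."""
    mu = alpha @ means
    var = (alpha ** 2) @ (stds ** 2)
    return means[:, None] + (alpha * stds ** 2)[:, None] * ((x - mu) / var)[None, :]


def path(t):
    return np.array([1.0 - t, t])  # linear path from vertex e_0 to vertex e_1


def dpath(t):
    return np.array([-1.0, 1.0])   # alpha'(t) for the linear path


def velocity(t, x):
    """b(t, x) = sum_k alpha_k'(t) g_k(alpha(t), x)."""
    return dpath(t) @ g(path(t), x)


def midpoint_flow(x, n_steps):
    """Integrate dx/dt = b(t, x) from t = 0 to t = 1 with the explicit midpoint rule."""
    h = 1.0 / n_steps
    for n in range(n_steps):
        t = n * h
        x_half = x + 0.5 * h * velocity(t, x)
        x = x + h * velocity(t + 0.5 * h, x_half)
    return x


rng = np.random.default_rng(0)
x0 = rng.standard_normal(100_000)  # samples from rho_0
x1 = midpoint_flow(x0, n_steps=5)  # 5 midpoint steps = 10 velocity evaluations
print(x1.mean(), x1.std())         # close to the target mean 2.0 and std 0.5
```

Because $b$ is affine in $x$ here, the midpoint map is affine too, so the pushed samples stay exactly Gaussian and the discretization error shows up only as a slight mis-scaling of the spread.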

Theorems & Definitions (12)

  • Definition 1: Barycentric stochastic interpolant
  • Theorem 1: Continuity equations
  • Corollary 1: Transport equations
  • Corollary 1
  • Theorem 2
  • proof
  • Theorem 2: Continuity equations
  • proof
  • Corollary 2: Transport equations
  • proof