Multimarginal generative modeling with stochastic interpolants
Michael S. Albergo, Nicholas M. Boffi, Michael Lindsey, Eric Vanden-Eijnden
TL;DR
The paper addresses the problem of learning a joint distribution over $K+1$ prescribed marginals whose structure reveals multi-way correspondences among them. It does so by extending stochastic interpolants to a barycentric, multimarginal framework indexed by the simplex $\Delta^{K}$. The central object is the interpolant $x(\alpha)=\sum_{k=0}^K\alpha_k x_k$. Its density $\rho(\alpha,x)$ satisfies $K+1$ continuity equations whose velocities are the conditional means $g_k(\alpha,x)=\mathbb{E}[x_k\mid x(\alpha)=x]$, and each $g_k$ is learned by minimizing a simple quadratic objective. A key contribution is that the path $\alpha(t)$ on the simplex is decoupled from learning. Once the $g_k$ are learned, transport along any path uses the velocity $\sum_k \dot\alpha_k(t)\,g_k(\alpha(t),x)$, so the path can be chosen afterwards, for example to reduce transport cost. The framework also yields deterministic couplings from which multi-way correspondences can be extracted, with examples including all-to-all image translation and emergent style transfer. Numerical experiments show reduced transport cost and rich cross-marginal interactions, and the authors suggest applications in data decorruption and algorithmic fairness. Overall, the work extends generative modeling from two marginals to many, with transports that are controllable, cost-efficient, and interpretable in terms of multi-way correspondences.
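The mechanism summarized above can be written out in a few lines. The display below is a reconstruction from the definitions of $x(\alpha)$ and $g_k$, not a quotation of the paper's numbered equations. It treats $\alpha$ as a point of $\mathbb{R}^{K+1}$ so that each partial derivative $\partial_{\alpha_k}$ makes sense, and it assumes $\rho$ is smooth. It also assumes that $\alpha$ in the objective is averaged over $\Delta^K$ (uniformly, say). The score formula in the last line additionally assumes $x_0\sim\mathsf{N}(0,\mathrm{Id})$, independent of the other $x_k$, with $\alpha_0>0$.

```latex
% Barycentric interpolant and conditional-mean fields, k = 0, ..., K
x(\alpha) = \sum_{k=0}^{K} \alpha_k x_k, \qquad
g_k(\alpha, x) = \mathbb{E}\left[ x_k \mid x(\alpha) = x \right].

% For any test function \phi, differentiate in \alpha_k and condition on x(\alpha):
\partial_{\alpha_k} \int \phi(x)\, \rho(\alpha, x)\, dx
  = \mathbb{E}\left[ \nabla \phi(x(\alpha)) \cdot x_k \right]
  = \int \nabla \phi(x) \cdot g_k(\alpha, x)\, \rho(\alpha, x)\, dx.

% Integrating by parts gives one continuity equation per marginal:
\partial_{\alpha_k} \rho(\alpha, x) + \nabla \cdot \big( g_k(\alpha, x)\, \rho(\alpha, x) \big) = 0,
  \qquad k = 0, \dots, K.

% Chain rule along any path \alpha(t) on \Delta^K:
\partial_t \rho(\alpha(t), x) + \nabla \cdot \big( b(t, x)\, \rho(\alpha(t), x) \big) = 0,
  \qquad b(t, x) = \sum_{k=0}^{K} \dot{\alpha}_k(t)\, g_k(\alpha(t), x).

% Each g_k minimizes a quadratic objective, with \alpha averaged over the simplex:
L_k[\hat{g}_k] = \int_{\Delta^K} \mathbb{E}\left[ \big| \hat{g}_k(\alpha, x(\alpha)) - x_k \big|^2 \right] d\alpha.

% Score, if x_0 ~ N(0, Id) is independent of x_1, ..., x_K and \alpha_0 > 0:
\nabla \log \rho(\alpha, x) = -\frac{g_0(\alpha, x)}{\alpha_0}.
```

The velocity $b$ is the field that gets integrated to generate samples. Because it is assembled from the fixed fields $g_k$, changing the path $\alpha(t)$ requires no retraining.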
Abstract
Given a set of $K$ probability densities, we consider the multimarginal generative modeling problem of learning a joint distribution that recovers these densities as marginals. The structure of this joint distribution should identify multi-way correspondences among the prescribed marginals. We formalize an approach to this task within a generalization of the stochastic interpolant framework, leading to efficient learning algorithms built upon dynamical transport of measure. Our generative models are defined by velocity and score fields that can be characterized as the minimizers of simple quadratic objectives, and they are defined on a simplex that generalizes the time variable in the usual dynamical transport framework. The resulting transport on the simplex is influenced by all marginals, and we show that multi-way correspondences can be extracted. The identification of such correspondences has applications to style transfer, algorithmic fairness, and data decorruption. In addition, the multimarginal perspective enables an efficient algorithm for reducing the dynamical transport cost in the ordinary two-marginal setting. We demonstrate these capacities with several numerical examples.
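The following is a minimal NumPy sketch of the two ingredients named in the abstract. It rests on strong simplifying assumptions that are not the paper's setup. There are $K=2$ and three independent one-dimensional Gaussian marginals, with $x_0\sim\mathsf{N}(0,1)$ as the base, so each $g_k$ has a closed form. The sketch does two things. First, it fits $g_k$ at a fixed $\alpha$ by least squares, which is the quadratic objective restricted to affine functions. Second, it integrates $b=\sum_k\dot\alpha_k g_k$ along a curved path from the vertex $e_0$ to $e_1$ through the interior of the simplex. The marginal parameters, the path, and all names (`MEANS`, `STDS`, `g_exact`, `fit_g_least_squares`, `alpha_path`, `alpha_dot`, `velocity`, `transport`) are hypothetical.

```python
import numpy as np

# Hypothetical illustration; names and parameters are not from the paper.
# K = 2: three independent 1-D Gaussian marginals, x_0 ~ N(0, 1) is the base.
MEANS = np.array([0.0, 2.0, -1.0])
STDS = np.array([1.0, 0.5, 1.5])


def g_exact(alpha, x):
    """Closed-form E[x_k | x(alpha) = x] for independent Gaussians, all k at once."""
    mu = alpha @ MEANS                  # mean of x(alpha)
    var = (alpha**2) @ (STDS**2)        # variance of x(alpha)
    # Jointly Gaussian: the conditional mean is affine in x. Shape (3, n).
    return MEANS[:, None] + (alpha * STDS**2)[:, None] / var * (x - mu)[None, :]


def fit_g_least_squares(alpha, n=200_000, seed=0):
    """Minimize the empirical E|g(x(alpha)) - x_k|^2 over affine g, for each k."""
    rng = np.random.default_rng(seed)
    xs = MEANS[:, None] + STDS[:, None] * rng.standard_normal((3, n))
    x_alpha = alpha @ xs                # barycentric interpolant x(alpha)
    design = np.stack([np.ones(n), x_alpha], axis=1)
    coef, *_ = np.linalg.lstsq(design, xs.T, rcond=None)
    return coef                         # shape (2, 3): row 0 intercepts, row 1 slopes


def alpha_path(t):
    """A curved path on the simplex from e_0 (t = 0) to e_1 (t = 1) via the interior."""
    return np.array([(1 - t) ** 2, t**2, 2 * t * (1 - t)])


def alpha_dot(t):
    """Time derivative of alpha_path; its entries sum to zero."""
    return np.array([-2 * (1 - t), 2 * t, 2 - 4 * t])


def velocity(t, x):
    """b(t, x) = sum_k alpha_k'(t) g_k(alpha(t), x)."""
    return alpha_dot(t) @ g_exact(alpha_path(t), x)


def transport(x0, n_steps=200):
    """Integrate dx/dt = b(t, x) from t = 0 to t = 1 with classical RK4."""
    x = np.atleast_1d(np.asarray(x0, dtype=float))
    h = 1.0 / n_steps
    for i in range(n_steps):
        t = i * h
        k1 = velocity(t, x)
        k2 = velocity(t + h / 2, x + h / 2 * k1)
        k3 = velocity(t + h / 2, x + h / 2 * k2)
        k4 = velocity(t + h, x + h * k3)
        x = x + h / 6 * (k1 + 2 * k2 + 2 * k3 + k4)
    return x
```

In this Gaussian setting, the exact flow map from the base to the marginal at $e_1$ is $x\mapsto 2+0.5x$. So `transport(np.array([0.0, 2.0]))` should return values close to `[2.0, 3.0]`. As `n` grows, `fit_g_least_squares(np.array([0.2, 0.5, 0.3]))` should approach the intercepts and slopes implied by `g_exact`. With non-Gaussian data, the affine regression would be replaced by a flexible model taking $(\alpha, x)$ as input, trained with $\alpha$ sampled across the simplex.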
