
Meta Flow Matching: Integrating Vector Fields on the Wasserstein Manifold

Lazar Atanackovic, Xi Zhang, Brandon Amos, Mathieu Blanchette, Leo J. Lee, Yoshua Bengio, Alexander Tong, Kirill Neklyudov

TL;DR

Meta Flow Matching (MFM) extends Flow Matching to learn time-evolving distributions on the Wasserstein manifold by amortizing the conditional velocity field over initial populations. It conditions the velocity field on a population embedding $\varphi(p_0)$ produced by a Graph Neural Network (GNN), enabling generalization to unseen initial distributions and perturbation contexts. The approach is validated on synthetic letter denoising and on a large-scale organoid drug-screen dataset, where it improves prediction of treatment responses across replicates and unseen patients, outperforming non-amortized baselines and remaining competitive with OT-enabled variants. By unifying interacting-particle dynamics with density-based generative modeling, MFM offers a path toward personalized predictions in biology and other natural-science domains.
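The population embedding $\varphi(p_0)$ has to summarize a whole set of cells without depending on the order in which they are listed. A minimal sketch of such an embedding, assuming a hypothetical one-layer mean-aggregation message-passing network over a k-nearest-neighbour graph with random, untrained weights; the paper's actual GNN architecture and training are not reproduced here, and the names `knn_adjacency`, `embed_population`, `w_self`, `w_msg` and `k` are illustrative only.

```python
import numpy as np

def knn_adjacency(x, k):
    """Symmetric k-nearest-neighbour adjacency matrix for a point cloud x of shape (n, d)."""
    d2 = ((x[:, None, :] - x[None, :, :]) ** 2).sum(axis=-1)  # pairwise squared distances
    np.fill_diagonal(d2, np.inf)                               # no self-edges
    nbrs = np.argsort(d2, axis=1)[:, :k]                       # indices of the k closest points
    a = np.zeros((len(x), len(x)))
    a[np.arange(len(x))[:, None], nbrs] = 1.0
    return np.maximum(a, a.T)                                  # make the graph undirected

def embed_population(x, w_self, w_msg, k=5):
    """Hypothetical one-layer GNN: mean-aggregate neighbour features, then mean-pool all nodes."""
    a = knn_adjacency(x, k)
    msg = (a @ x) / a.sum(axis=1, keepdims=True)   # mean over neighbours (every node has >= k)
    h = np.tanh(x @ w_self + msg @ w_msg)          # node embeddings, shape (n, m)
    return h.mean(axis=0)                          # population embedding, shape (m,)

rng = np.random.default_rng(0)
d, m = 2, 8
w_self, w_msg = rng.normal(size=(d, m)), rng.normal(size=(d, m))  # random, untrained weights
p0 = rng.normal(size=(100, d))                                    # one population of 100 "cells"
e = embed_population(p0, w_self, w_msg)
e_perm = embed_population(p0[rng.permutation(100)], w_self, w_msg)
print(np.allclose(e, e_perm))  # True: reordering the cells leaves the embedding unchanged
```

Mean pooling is what makes the output a function of the set of cells rather than of their order; in MFM the embedding weights would be learned jointly with the velocity field instead of being fixed at random.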

Abstract

Numerous biological and physical processes can be modeled as systems of interacting entities evolving continuously over time, e.g. the dynamics of communicating cells or physical particles. Learning the dynamics of such systems is essential for predicting the temporal evolution of populations across novel samples and unseen environments. Flow-based models allow for learning these dynamics at the population level: they model the evolution of the entire distribution of samples. However, current flow-based models are limited to a single initial population and a set of predefined conditions which describe different dynamics. We argue that many processes in the natural sciences have to be represented as vector fields on the Wasserstein manifold of probability densities. That is, the change of the population at any moment in time depends on the population itself due to the interactions between samples. In particular, this is crucial for personalized medicine, where the development of diseases and their respective treatment responses depend on the microenvironment of cells specific to each patient. We propose Meta Flow Matching (MFM), a practical approach to integrate along these vector fields on the Wasserstein manifold by amortizing the flow model over the initial populations. Namely, we embed the population of samples using a Graph Neural Network (GNN) and use these embeddings to train a Flow Matching model. This gives MFM the ability to generalize over the initial distributions, unlike previously proposed methods. We demonstrate the ability of MFM to improve the prediction of individual treatment responses on a large-scale multi-patient single-cell drug screen dataset.
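The recipe in the abstract (embed $p_0$, then regress a velocity field conditioned on that embedding) can be illustrated on a toy problem. This is a minimal sketch under strong simplifying assumptions that are ours, not the paper's: the embedding is a fixed, hand-crafted stand-in for the GNN (the empirical mean of $p_0$) rather than being learned jointly; the velocity model is linear and fit by least squares; paths are linear interpolations $x_t = (1-t)x_0 + t x_1$ with regression target $x_1 - x_0$; and the toy "treatment response" (re-centring each population at the origin) is invented for illustration. All function names are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_population(n=200):
    """Toy population: a Gaussian cloud with a random centre; its response re-centres it at 0."""
    mu = rng.uniform(-3, 3, size=2)
    x0 = mu + rng.normal(size=(n, 2))
    x1 = x0 - x0.mean(axis=0)        # target depends on the whole population, not one cell
    return x0, x1

def embed(x0):
    """Hand-crafted stand-in for the GNN embedding phi(p0): the empirical mean."""
    return x0.mean(axis=0)

def features(xt, t, e=None):
    """Inputs of a linear velocity model: [x_t, t, (phi(p0)), 1]."""
    cols = [xt, t[:, None], np.ones((len(xt), 1))]
    if e is not None:                # MFM-style: condition on the population embedding
        cols.insert(2, np.broadcast_to(e, xt.shape))
    return np.hstack(cols)

def regression_data(pops, conditioned):
    F, U = [], []
    for x0, x1 in pops:
        t = rng.uniform(size=len(x0))
        xt = (1 - t)[:, None] * x0 + t[:, None] * x1   # linear interpolation path
        F.append(features(xt, t, embed(x0) if conditioned else None))
        U.append(x1 - x0)                              # target velocity dx_t/dt
    return np.vstack(F), np.vstack(U)

def fit(pops, conditioned):
    """Least-squares minimiser of the squared regression loss; returns weights and train loss."""
    F, U = regression_data(pops, conditioned)
    W = np.linalg.lstsq(F, U, rcond=None)[0]
    return W, np.mean(np.sum((F @ W - U) ** 2, axis=1))

def sample(W, x0, conditioned, steps=10):
    """Euler-integrate the learned field from t=0 to t=1; phi(p0) is fixed by the initial population."""
    e = embed(x0) if conditioned else None
    x, dt = x0.copy(), 1.0 / steps
    for k in range(steps):
        t = np.full(len(x), k * dt)
        x = x + dt * features(x, t, e) @ W
    return x

train = [make_population() for _ in range(20)]
W_fm, loss_fm = fit(train, conditioned=False)
W_mfm, loss_mfm = fit(train, conditioned=True)

x0_new, x1_new = make_population()   # an initial population never seen in training
err_fm = np.abs(sample(W_fm, x0_new, False) - x1_new).mean()
err_mfm = np.abs(sample(W_mfm, x0_new, True) - x1_new).mean()
print(f"train loss FM={loss_fm:.3f} MFM={loss_mfm:.2e}; test error FM={err_fm:.3f} MFM={err_mfm:.2e}")
```

Because the target velocity depends on the population mean, the unconditioned fit can only learn an aggregate field, while the conditioned fit transfers to a population it never saw. This is a caricature of the FM vs. MFM contrast, not a reproduction of the paper's experiments.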

Paper Structure (43 sections, 3 theorems, 34 equations, 9 figures, 9 tables, 2 algorithms)

Key Result

Proposition 1

Meta Flow Matching recovers Conditional Generation via Flow Matching when the conditional dependence of the marginals $p_0(x_0 \,|\, c) = \int dx_1\, \pi(x_0,x_1\,|\, c)$ and $p_1(x_1 \,|\, c) = \int dx_0\, \pi(x_0,x_1\,|\, c)$ and the distribution $p(c)$ are known, i.e. there exist $\varphi: \ldots$
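For reference, the three objectives compared in Figure 1 (FM, CGFM, MFM) can be written in a standard form. This is a sketch assuming linear interpolation paths $x_t = (1-t)x_0 + t x_1$ with regression target $x_1 - x_0$; the paper's exact notation and path choice may differ, and the labels $\theta$ (embedding/GNN parameters) and $\omega$ (vector-field parameters) are assumed here.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{FM}}(\omega) &= \mathbb{E}_{t \sim U[0,1]}\,\mathbb{E}_{(x_0,x_1) \sim \pi}\,
  \big\| v_t(x_t;\omega) - (x_1 - x_0) \big\|^2,\\
\mathcal{L}_{\mathrm{CGFM}}(\omega) &= \mathbb{E}_{c \sim p(c)}\,\mathbb{E}_{t}\,\mathbb{E}_{(x_0,x_1) \sim \pi(\cdot\,|\,c)}\,
  \big\| v_t(x_t \,|\, c;\omega) - (x_1 - x_0) \big\|^2,\\
\mathcal{L}_{\mathrm{MFM}}(\omega,\theta) &= \mathbb{E}_{c \sim p(c)}\,\mathbb{E}_{t}\,\mathbb{E}_{(x_0,x_1) \sim \pi(\cdot\,|\,c)}\,
  \big\| v_t\big(x_t \,|\, \varphi(p_0(\cdot\,|\,c);\theta);\omega\big) - (x_1 - x_0) \big\|^2 .
\end{aligned}
```

Intuitively, if the embedding separates the conditions, i.e. $c$ can be recovered from $\varphi(p_0(\cdot\,|\,c))$, then any conditional field $v_t(\cdot\,|\,c)$ can be rewritten as a field conditioned on $\varphi(p_0(\cdot\,|\,c))$, so the two objectives share the same minimizers. This paraphrases the idea behind Proposition 1 rather than restating it exactly.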

Figures (9)

  • Figure 1: Illustration of Meta Flow Matching (MFM; see the MFM loss). (a) Comparison between Flow Matching (FM; see the FM loss) and MFM. (b) Depiction of differences between MFM and FM generated predictions. Given a point $x_t$, a vector field (flow) model trained with MFM can generate different points $\hat{x}_1$ for different initial distributions $p_0$ (represented by red, green, and purple). FM-trained models can only predict an aggregate response over populations (shown in gray). At best, FM can incorporate known (seen) conditional information available in the training data, denoted as Conditional Generative Flow Matching (CGFM; see the CGFM loss). In contrast, MFM jointly learns a population embedding model $\varphi(p_0)$ and a vector field $v_t$, allowing generalization to unseen populations.
  • Figure 2: Illustration of flow matching methods on the 2-Wasserstein manifold, $\mathcal{P}_2(\mathcal{X})$, depicted as a two-dimensional sphere. Flow Matching learns the tangent vectors to a single curve on the manifold. Conditional generation corresponds to learning a finite set of curves on the manifold, e.g. classes $c_1$ and $c_2$ on the plot. Meta Flow Matching learns to integrate a vector field on $\mathcal{P}_2(\mathcal{X})$, i.e. for every starting density $p_0$, MFM defines a push-forward measure that integrates along the underlying vector field.
  • Figure 3: Synthetic letters experiment visualizations. Examples of model-generated samples from the source distribution ($t=0$) to the predicted target distribution ($t=1$). See the full letters figure in the extended-results appendix for further examples.
  • Figure 4: Organoid drug-screen dataset overview. (Left) A given replicate consists of a control distribution $p_0$ and the corresponding treatment response distribution $p_1$ for treatment condition $c_i$. (Right) Train and test data splits for replicates (top) and patients (bottom). Experiments are conducted for 11 treatments, 10 patients, and 3 culture conditions, and are repeated (replicated) numerous times, resulting in a dataset of many control and treated population pairs.
  • Figure 5: Synthetic letters ablation over the number of training populations. Here we fix the test sets (X's and Y's) with 10 random rotations (the same across all experiments). We then ablate the number of populations used for training FM (red), CGFM (green), and MFM (blue) by changing the number of random rotations/orientations used for each letter silhouette. We observe that for MFM the distributional errors on the test sets consistently decrease as we increase the number of training populations. In contrast, since FM and CGFM cannot generalize across novel populations, increasing the number of training populations does not improve their overall performance on the test populations. For a large number of training populations, CGFM has prohibitive memory requirements, since it requires one-hot encodings as input conditions to denote the population index.
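The memory remark in Figure 5 can be made concrete with simple arithmetic. A minimal sketch, assuming a hypothetical dense first layer with 256 hidden units that takes $[x_t, t, \text{condition}]$ as input, 2-D points, and a hypothetical 64-dimensional MFM population embedding; none of these sizes come from the paper. The one-hot condition grows with the number of training populations, while the embedding size stays fixed, and an unseen population has no one-hot index at all.

```python
import numpy as np

def cgfm_condition(pop_index, n_train_pops):
    """CGFM-style condition: a one-hot vector over the training populations."""
    if not 0 <= pop_index < n_train_pops:
        raise ValueError("an unseen population has no one-hot index")
    c = np.zeros(n_train_pops)
    c[pop_index] = 1.0
    return c

def first_layer_params(cond_dim, x_dim=2, hidden=256):
    """Weights plus biases of a hypothetical dense layer taking [x_t, t, condition]."""
    return (x_dim + 1 + cond_dim) * hidden + hidden

embed_dim = 64  # hypothetical, fixed size of the MFM population embedding
for n_pops in [10, 1_000, 100_000]:
    print(n_pops, "CGFM:", first_layer_params(n_pops), "MFM:", first_layer_params(embed_dim))
```

Only the input layer is counted here; the point is the scaling, since the one-hot width (and the layer that consumes it) grows linearly with the number of training populations.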

Theorems & Definitions (8)

  • Example 1: Mean-field limit of interacting particles
  • Example 2: Diffusion
  • Proposition 1
  • proof
  • Theorem 1
  • proof
  • Theorem 1
  • proof
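Examples 1 and 2 are listed only by name here, but both describe velocities that depend on the current population, i.e. vector fields on the Wasserstein manifold. A minimal sketch, with modelling choices that are ours rather than the paper's: for Example 1, a mean-field interaction with a linear attraction kernel, so every particle drifts toward the population mean; for Example 2, the probability-flow velocity $v_t(x) = -D\,\nabla \log p_t(x)$ of the heat equation, evaluated for a Gaussian fitted to 1-D particles. The names `mean_field_velocity`, `diffusion_velocity`, `integrate` and the constant `D` are hypothetical.

```python
import numpy as np

def mean_field_velocity(x):
    """Example-1-style field (assumed linear attraction kernel): move toward the population mean."""
    return x.mean(axis=0) - x

def diffusion_velocity(x, D=0.5):
    """Example-2-style field: -D * grad log p for a Gaussian fitted to the 1-D population."""
    mu, var = x.mean(), x.var()
    return D * (x - mu) / var

def integrate(field, x, t_end=1.0, steps=1000):
    """Forward Euler; the velocity is re-evaluated on the *current* population at every step."""
    dt = t_end / steps
    for _ in range(steps):
        x = x + dt * field(x)
    return x

rng = np.random.default_rng(0)
x0 = rng.normal(size=500)
x_mf = integrate(mean_field_velocity, x0)
x_diff = integrate(diffusion_velocity, x0)
print(x_mf.std() / x0.std())     # ≈ 0.368, i.e. about exp(-1): the cloud contracts
print(x_diff.var() - x0.var())   # ≈ 1.0 = 2 * D * t: the cloud spreads like the heat equation
```

Because each velocity is computed from the whole population, the same particle would move differently if it were placed in a different population; this population dependence is what a single-population flow model cannot express and what MFM amortizes over.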