Oh SnapMMD! Forecasting Stochastic Dynamics Beyond the Schrödinger Bridge's End

Renato Berlinghieri, Yunyi Shen, Jialong Jiang, Tamara Broderick

TL;DR

The paper tackles forecasting stochastic dynamics from snapshot data, where individual trajectories are unobserved. It introduces SnapMMD, a framework that learns SDEs by matching the joint state-time distribution directly with a Maximum Mean Discrepancy (MMD) loss. This approach infers unknown, state-dependent volatility, handles incomplete observations, and yields an interpretable velocity field plus an RKHS-based $R^2$ diagnostic for model fit. Across synthetic and real datasets (Lotka–Volterra, repressilator variants, Gulf of Mexico currents, and PBMC immune activation), the method forecasts more accurately than Schrödinger-bridge baselines and is competitive at interpolation. It is aimed at time-course data in biology and related dynamical systems, and code is available for reproducibility and further application.
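
To make the snapshot setting concrete, here is a minimal sketch of how such data could be simulated. It uses a standard Euler-Maruyama discretization of an SDE with state-dependent volatility. The drift, the volatility form, and all parameter values (`a`, `b`, `c`, `d`, `sigma`) are illustrative assumptions, not the paper's experimental configuration. A fresh population is drawn for every observation time, so no particle is followed across snapshots.

```python
import numpy as np

def euler_maruyama(x0, drift, volatility, t_end, n_steps, rng):
    """Simulate dX = drift(X) dt + volatility(X) dW for a batch of particles."""
    x = np.array(x0, dtype=float)
    dt = t_end / n_steps
    for _ in range(n_steps):
        noise = rng.standard_normal(x.shape)
        x = x + drift(x) * dt + volatility(x) * np.sqrt(dt) * noise
    return x

# Hypothetical Lotka-Volterra-style drift (parameter values are assumptions).
def lv_drift(x, a=1.0, b=0.4, c=0.4, d=0.1):
    prey, pred = x[:, 0], x[:, 1]
    return np.stack([a * prey - b * prey * pred,
                     d * prey * pred - c * pred], axis=1)

# Hypothetical diagonal volatility that grows with the state magnitude.
def lv_volatility(x, sigma=0.1):
    return sigma * np.abs(x)

rng = np.random.default_rng(0)
times = [0.5, 1.0, 1.5]

# Each snapshot uses its own initial population, mimicking destructive
# measurement: samples at different times are not linked into trajectories.
snapshots = {
    t: euler_maruyama(rng.uniform(1.0, 2.0, size=(200, 2)),
                      lv_drift, lv_volatility,
                      t_end=t, n_steps=int(t * 100), rng=rng)
    for t in times
}
```

Forecasting in this setting means simulating a fitted drift and volatility past the last observed time and comparing the resulting population with a held-out snapshot.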

Abstract

Scientists often want to make predictions beyond the observed time horizon of "snapshot" data following latent stochastic dynamics. For example, in time course single-cell mRNA profiling, scientists have access to cellular transcriptional state measurements (snapshots) from different biological replicates at different time points, but they cannot access the trajectory of any one cell because measurement destroys the cell. Researchers want to forecast (e.g.) differentiation outcomes from early state measurements of stem cells. Recent Schrödinger-bridge (SB) methods are natural for interpolating between snapshots. But past SB papers have not addressed forecasting -- likely since existing methods either (1) reduce to following pre-set reference dynamics (chosen before seeing data) or (2) require the user to choose a fixed, state-independent volatility since they minimize a Kullback-Leibler divergence. Either case can lead to poor forecasting quality. In the present work, we propose a new framework, SnapMMD, that learns dynamics by directly fitting the joint distribution of both state measurements and observation time with a maximum mean discrepancy (MMD) loss. Unlike past work, our method allows us to infer unknown and state-dependent volatilities from the observed data. We show in a variety of real and synthetic experiments that our method delivers accurate forecasts. Moreover, our approach allows us to learn in the presence of incomplete state measurements and yields an $R^2$-style statistic that diagnoses fit. We also find that our method's performance at interpolation (and general velocity-field reconstruction) is at least as good as (and often better than) state-of-the-art in almost all of our experiments.
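
As a rough illustration of fitting the joint distribution of state and observation time with an MMD loss, the sketch below computes a biased (V-statistic) estimate of squared MMD between observed and simulated (state, time) pairs. The single Gaussian kernel on the concatenated (state, time) vector, the `bandwidth` value, and the toy data are assumptions made here for illustration. The paper's actual kernel and estimator may differ.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth):
    """Gaussian kernel matrix between the rows of a and the rows of b."""
    sq = (np.sum(a**2, axis=1)[:, None]
          + np.sum(b**2, axis=1)[None, :]
          - 2.0 * a @ b.T)
    return np.exp(-np.maximum(sq, 0.0) / (2.0 * bandwidth**2))

def mmd2(x, y, bandwidth=1.0):
    """Biased (V-statistic) estimate of squared MMD between samples x and y."""
    kxx = gaussian_kernel(x, x, bandwidth).mean()
    kyy = gaussian_kernel(y, y, bandwidth).mean()
    kxy = gaussian_kernel(x, y, bandwidth).mean()
    return kxx + kyy - 2.0 * kxy

def with_time(states, t):
    """Append the observation time as an extra coordinate of each sample."""
    return np.hstack([states, np.full((states.shape[0], 1), float(t))])

rng = np.random.default_rng(0)
n = 300

# Toy "observed" snapshots: the population mean drifts from 0 to 1.
observed = np.vstack([
    with_time(rng.normal(0.0, 1.0, size=(n, 2)), 0),
    with_time(rng.normal(1.0, 1.0, size=(n, 2)), 1),
])

# A model that reproduces the drift, and one that ignores it.
model_good = np.vstack([
    with_time(rng.normal(0.0, 1.0, size=(n, 2)), 0),
    with_time(rng.normal(1.0, 1.0, size=(n, 2)), 1),
])
model_bad = np.vstack([
    with_time(rng.normal(0.0, 1.0, size=(n, 2)), 0),
    with_time(rng.normal(0.0, 1.0, size=(n, 2)), 1),
])

loss_good = mmd2(observed, model_good)
loss_bad = mmd2(observed, model_bad)
```

In a learning loop, the simulated samples would come from an SDE with trainable drift and volatility, and the loss would be minimized with respect to those parameters. Here the two fixed "models" only show that the loss is smaller when the simulated joint distribution matches the observed one.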

Paper Structure

This paper contains 61 sections, 2 theorems, 24 equations, 26 figures, 21 tables.

Key Result

Proposition 3.1

Let $f(\bm{y},t) = f(\bm{y}\mid t) h(t)$ and $g(\bm{y},t) = g(\bm{y}\mid t) h(t)$ be joint distributions over $\bm{y}\in \mathbb{R}^d$ (for dimension $d$) and discrete time $t \in \mathcal{T}$, where $h(t)$ is a probability mass function and $f(\bm{y}\mid t), g(\bm{y}\mid t)$ are conditional distri

Figures (26)

  • Figure 1: Lotka-Volterra results (\ref{sec:lv-main}). We show 200 samples at each of 10 training times and 1 forecast time (red). Forecast points overlap with the training points at time 0 (blue).
  • Figure 2: Repressilator results: mRNA-only (upper, \ref{sec:repr-main}) and mRNA and protein (lower, \ref{sec:repr-protein}). We show 200 samples at each of 10 training times and 1 forecast time (red).
  • Figure 3: Gulf of Mexico results (\ref{sec:gom-main}). We show 200 samples at each of 10 training times and 1 forecast time (red).
  • Figure 4: PBMC results (\ref{sec:pbmc}). The axes in every plot are the same three principal components, computed over the full data: i.e., 41 time steps of the 30-dimensional gene programs. Leftmost four panels: evolution of the training data at time steps 1, 7, 14, and 20. "Truth" panel: ground truth snapshot at time step 21. Rightmost three panels: model forecasts at time step 21.
  • Figure A5: Experimental results for the Lotka-Volterra system. Top row: forecast prediction task. A method is successful if the forecast predicted points (in red) match the red points in the ground truth figure. Middle row: ground truth vector field (left) and reconstructed vector fields with the three methods. Bottom row: Difference between reconstructed vector fields and ground truth. For each point of interest on the grid, we represent the difference between the two vectors with an arrow and color it according to the magnitude of the difference (colorbar to the right).
  • ...and 21 more figures

Theorems & Definitions (5)

  • Proposition 3.1
  • Definition 3.1: RKHS-based $R^2$ metric
  • Proposition B.1: Forecasting limitation of fixed-reference SB methods
  • proof
  • proof: Proof of \ref{prop:MMDdecomposition}