Learning diffusion at lightspeed

Antonio Terpin, Nicolas Lanzetti, Martin Gadea, Florian Dörfler

TL;DR

JKOnet* is a new, simple model that bypasses the complexity of existing architectures while offering significantly greater representational capacity: it recovers the potential, interaction, and internal energy components of the underlying diffusion process.

Abstract

Diffusion regulates numerous natural processes and the dynamics of many successful generative models. Existing models to learn the diffusion terms from observational data rely on complex bilevel optimization problems and model only the drift of the system. We propose a new simple model, JKOnet*, which bypasses the complexity of existing architectures while presenting significantly enhanced representational capabilities: JKOnet* recovers the potential, interaction, and internal energy components of the underlying diffusion process. JKOnet* minimizes a simple quadratic loss and outperforms other baselines in terms of sample efficiency, computational complexity, and accuracy. Additionally, JKOnet* provides a closed-form optimal solution for linearly parametrized functionals, and, when applied to predict the evolution of cellular processes from real-world data, it achieves state-of-the-art accuracy at a fraction of the computational cost of all existing methods. Our methodology is based on the interpretation of diffusion processes as energy-minimizing trajectories in the probability space via the so-called JKO scheme, which we study via its first-order optimality conditions.
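
For reference, the JKO scheme mentioned above is usually written as the recursion below. The step size $\tau > 0$ and the symbol $W_2$ for the 2-Wasserstein distance are standard notation assumed here, not taken from this page; $J$ is the energy functional whose parametrized version appears as $J_\theta$ in Figure 1.

$$
\mu_{t+1} \in \operatorname*{argmin}_{\mu} \; J(\mu) + \frac{1}{2\tau} W_2^2(\mu, \mu_t)
$$

Each snapshot therefore trades off lowering the energy $J$ against staying close, in Wasserstein distance, to the previous snapshot $\mu_t$.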

Paper Structure (93 sections, 3 theorems, 71 equations, 11 figures, 1 table)

Key Result

Proposition 3.1

Assume $V$ is continuously differentiable, lower bounded, and has a bounded Hessian. Then, the JKO scheme has an optimal solution $\mu_{t+1}$ and, if $\mu_{t+1}$ is optimal for the JKO scheme, then there is an optimal transport plan $\gamma_t$ between $\mu_t$ and $\mu_{t+1}$ such that

Figures (11)

  • Figure 1: Given a sequence of snapshots $(\mu_0, \ldots, \mu_T)$ of a population of particles undergoing diffusion, we want to find the parameters $\theta$ of the parametrized energy function $J_\theta$ that best explains the particles' evolution. Given $\theta$, the effects mismatch is the Wasserstein distance between the observed trajectory and the predicted trajectory obtained by iteratively solving the JKO step with $J_\theta$. The first-order optimality condition in [lanzetti2024] applied to the JKO step suggests that the "gradient" of $J_\theta$ with respect to each $\hat{\mu}_t$ vanishes at optimality, i.e., for $\hat{\mu}_t = \mu_t$. For $J_\theta(\mu) = \int_{\mathbb{R}^d} V_\theta(x)\,\mathrm{d}\mu(x)$, this condition is depicted on the right. The gradient (dashed blue arrows) of the true $V$ (level curves in dashed blue) at each observed particle $x_i^{t+1}$ (blue circles) in the next snapshot $\mu_{t+1}$ opposes the displacement (dotted red arrows) from a particle $x_i^t$ (red triangles) in the previous snapshot $\mu_t$. In contrast, the gradient (solid green arrows) of the estimated $V_\theta$ (level curves in solid green) at each observed particle $x_i^{t+1}$ (squares) does not oppose the displacement from a particle $x_i^t$ in the previous snapshot $\mu_t$. This mismatch in the causes of the diffusion process is what JKOnet* minimizes.
  • Figure 2: Level curves of the true (green-colored) and estimated (blue-colored) potentials for the Styblinski-Tang, flowers, Ishigami, and Friedman functionals, defined in the appendix on functionals. See also the additional level-set and prediction figures in the appendix.
  • Figure 3: Numerical results of the lightspeed experiments. The scatter plot displays points $(x_i, y_i)$ where $x_i$ indexes the potentials in the appendix on functionals and $y_i$ are the errors (earth mover's distance, EMD, normalized so that the maximum error among all models and all potentials is $1$) obtained with the different models. We mark with NaN each method that has diverged during training. The plot on the bottom-left shows the EMD error trajectory during training (normalized such that $0$ and $1$ are the minimum and maximum EMD), averaged over all the experiments. The shaded area represents the standard deviation. The box plot shows the time per epoch required by each method. The statistics are across all epochs and all potential energies.
  • Figure 4: Numerical results of the scaling experiments, reported in full in the additional scaling figure in the appendix. The colors represent the EMD error, which appears to scale sublinearly with the dimension $d$.
  • Figure 5: Visualizations for the scRNA-seq experiments. The top row shows the two principal components of the scRNA-seq data, ground truth (green, days 1-3, 6-9, 12-15, 18-21, 24-27) and interpolated (blue, days 4-5, 10-11, 16-17, 22-23). The bottom row displays the estimated potential level curves over time. The bottom left plot superimposes the same three level curves for days 1-3 (solid), 12-15 (dashed), and 24-27 (dashed with larger spaces) to highlight the time-dependency.
  • ...and 6 more figures
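
The first-order condition described in the Figure 1 caption can be illustrated with a minimal sketch in plain Python. It assumes the condition takes the form $\nabla V(x_i^{t+1}) + (x_i^{t+1} - x_i^t)/\tau = 0$ for a step size $\tau$, which is an assumed form consistent with the caption's statement that the gradient opposes the displacement. It also assumes a hypothetical one-dimensional potential $V_\theta(x) = \theta_1 x + \theta_2 x^2$ and a known pairing of particles across snapshots. The names `jko_step` and `fit_potential` are illustrative and do not come from the paper's code.

```python
# Sketch (assumptions: 1-D particles, known pairing, quadratic potential
# V_theta(x) = theta1*x + theta2*x**2, step size tau).

def jko_step(x, theta, tau):
    """Exact minimizer of V_theta(y) + (y - x)**2 / (2*tau) over y."""
    theta1, theta2 = theta
    # Setting theta1 + 2*theta2*y + (y - x)/tau = 0 and solving for y.
    return (x - tau * theta1) / (1.0 + 2.0 * tau * theta2)


def fit_potential(pairs, tau):
    """Closed-form least-squares fit of theta from pairs (x_t, x_next).

    Residual per pair: grad V_theta(x_next) + (x_next - x_t) / tau,
    where grad V_theta(x) = theta1 + 2*theta2*x is linear in theta.
    """
    a11 = a12 = a22 = b1 = b2 = 0.0
    for x_t, x_next in pairs:
        g1, g2 = 1.0, 2.0 * x_next          # gradient features at x_next
        target = -(x_next - x_t) / tau      # gradient must oppose displacement
        a11 += g1 * g1
        a12 += g1 * g2
        a22 += g2 * g2
        b1 += g1 * target
        b2 += g2 * target
    # Solve the 2x2 normal equations with Cramer's rule.
    det = a11 * a22 - a12 * a12
    theta1 = (a22 * b1 - a12 * b2) / det
    theta2 = (a11 * b2 - a12 * b1) / det
    return theta1, theta2


# Synthetic data: one snapshot pushed through one step of the true potential.
true_theta = (0.5, 1.5)
tau = 0.1
particles = [-2.0, -1.0, 0.0, 0.7, 1.9]
pairs = [(x, jko_step(x, true_theta, tau)) for x in particles]

theta_hat = fit_potential(pairs, tau)
print(theta_hat)
```

Because the residual is linear in $\theta$, the minimizer comes from a small linear system, which mirrors the closed-form solution for linearly parametrized functionals mentioned in the abstract. In this toy setting `theta_hat` matches `true_theta` up to floating-point error, since the data were generated by the same assumed step; with real snapshots the pairing would first have to be estimated, for example from an optimal transport plan.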

Theorems & Definitions (8)

  • Example 2.1: Fokker-Planck
  • Example 2.2: Fokker-Planck as a Wasserstein gradient flow
  • Proposition 3.1: Potential energy
  • Proposition 3.2: General case
  • Remark 3.3
  • Proposition 3.4
  • Remark 3.5
  • Remark 3.6