Dynamical Measure Transport and Neural PDE Solvers for Sampling

Jingtong Sun, Julius Berner, Lorenz Richter, Marius Zeinhofer, Johannes Müller, Kamyar Azizzadenesheli, Anima Anandkumar

TL;DR

The paper addresses the challenge of sampling from unnormalized densities with intractable normalizing constants by casting the task as dynamical measure transport governed by PDEs. It introduces a PDE-based framework in which either SDEs or ODEs transport a prior density $p_{\mathrm{prior}}$ to a target $p_{\mathrm{target}}$, and it uses physics-informed neural networks (PINNs) to minimize the residuals of the Fokker–Planck (FP) equation or the continuity equation (CE) in a simulation-free manner. The approach offers constrained evolution strategies (e.g., annealing, time-reversal, Schrödinger-bridge/OT regularization) and integrates Gauss-Newton optimization for PINNs. It recovers existing trajectory-based methods as special cases and achieves improved mode coverage on high-dimensional multimodal targets. The framework provides a flexible, scalable toolkit for sampling in scientific computing and points to extensions such as mean-field games and simulation-free dynamics learning.
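The FP and CE residuals can be written directly in terms of log-densities, for orientation. The derivation below is a standard one. It assumes an SDE $dX_t = \mu(X_t,t)\,dt + \sigma(t)\,dW_t$ with a scalar diffusion coefficient $\sigma(t)$, and an ODE with velocity field $v$. The symbols $\mu$, $\sigma$, $v$, $\tilde p_t$ and $Z_t$ are our own notation and may not match the paper's equations.

```latex
% Fokker-Planck equation divided by p_t (uses \Delta p / p = \Delta\log p + \|\nabla\log p\|^2)
\partial_t p_t = -\nabla\cdot(\mu\,p_t) + \tfrac{\sigma^2}{2}\,\Delta p_t
\;\Longrightarrow\;
\partial_t \log p_t = -\nabla\cdot\mu - \mu\cdot\nabla\log p_t
  + \tfrac{\sigma^2}{2}\bigl(\Delta\log p_t + \|\nabla\log p_t\|^2\bigr)

% Continuity equation (sigma = 0, drift = v)
\partial_t \log p_t = -\nabla\cdot v - v\cdot\nabla\log p_t

% Unnormalized densities: the unknown normalizer enters only as a function of time
\log p_t = \log\tilde p_t - \log Z_t
\;\Longrightarrow\;
\partial_t \log p_t = \partial_t \log\tilde p_t - \tfrac{d}{dt}\log Z_t,
\qquad \nabla\log p_t = \nabla\log\tilde p_t

% SDE-to-ODE: with this velocity the CE line reproduces the FP line, so both share p_t
v = \mu - \tfrac{\sigma^2}{2}\,\nabla\log p_t
```

A PINN loss then penalizes the squared residual of the chosen equation at sampled points $(x, t)$. No trajectories are simulated, which is why such losses are called simulation-free.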

Abstract

The task of sampling from a probability density can be approached as transporting a tractable density function to the target, known as dynamical measure transport. In this work, we tackle it through a principled unified framework using deterministic or stochastic evolutions described by partial differential equations (PDEs). This framework incorporates prior trajectory-based sampling methods, such as diffusion models or Schrödinger bridges, without relying on the concept of time-reversals. Moreover, it allows us to propose novel numerical methods for solving the transport task and thus sampling from complicated targets without the need for the normalization constant or data samples. We employ physics-informed neural networks (PINNs) to approximate the respective PDE solutions, which offers both conceptual and computational advantages. In particular, PINNs allow for simulation- and discretization-free optimization and can be trained very efficiently, leading to significantly better mode coverage in the sampling task compared to alternative methods. Moreover, they can readily be fine-tuned with Gauss-Newton methods to achieve high accuracy in sampling.
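The sketch below is a heavily simplified illustration of a simulation-free PINN loss and a Gauss-Newton fit. Its toy problem is our own assumption, not one of the paper's experiments:

- a 1-D prior $\propto e^{-x^2/2}$;
- an unnormalized target $\propto e^{-(x-m)^2/2}$;
- the geometric annealing path between them, whose marginals are $\mathcal N(tm, 1)$.

The velocity $v_\theta(x,t)=\theta_0+\theta_1 x$ and the model $\theta_2+\theta_3 t$ for $\tfrac{d}{dt}\log Z_t$ are hypothetical linear stand-ins for neural networks. The helper names `residual`, `pinn_loss` and `gauss_newton` are illustrative. Derivatives are written analytically, or by finite differences, where a real implementation would use automatic differentiation.

```python
import numpy as np

M = 2.0  # target mean m (toy assumption)


def residual(theta, x, t):
    """Log-CE residual d/dt log p_t + d/dx v + v * d/dx log p_t at points (x, t).

    Toy path (assumption): log p~_t(x) = (1 - t)(-x^2/2) + t(-(x - M)^2/2), so
    d/dt log p~_t = x*M - M^2/2 and d/dx log p~_t = -x + t*M.
    """
    v = theta[0] + theta[1] * x          # hypothetical velocity model v_theta(x, t)
    div_v = theta[1]                     # d/dx of v_theta
    dlogZ = theta[2] + theta[3] * t      # hypothetical model of d/dt log Z_t
    dt_log_p = x * M - 0.5 * M**2 - dlogZ  # d/dt log p_t = d/dt log p~_t - d/dt log Z_t
    score = -x + t * M                   # d/dx log p_t (Z_t does not depend on x)
    return dt_log_p + div_v + v * score


def pinn_loss(theta, x, t):
    """Mean squared residual over collocation points; no SDE/ODE is simulated."""
    return np.mean(residual(theta, x, t) ** 2)


def gauss_newton(theta, x, t, steps=5, damping=1e-10, eps=1e-6):
    """Damped Gauss-Newton on the residual vector with a finite-difference Jacobian."""
    theta = np.asarray(theta, dtype=float)
    for _ in range(steps):
        r = residual(theta, x, t)
        J = np.empty((r.size, theta.size))
        for j in range(theta.size):
            e = np.zeros_like(theta)
            e[j] = eps
            J[:, j] = (residual(theta + e, x, t) - r) / eps
        # Solve (J^T J + damping * I) delta = J^T r and take the step.
        delta = np.linalg.solve(J.T @ J + damping * np.eye(theta.size), J.T @ r)
        theta = theta - delta
    return theta


rng = np.random.default_rng(0)
x = rng.normal(0.0, 2.0, size=256)   # collocation points in space
t = rng.uniform(0.0, 1.0, size=256)  # collocation points in time
theta_fit = gauss_newton(np.zeros(4), x, t)
print(theta_fit, pinn_loss(theta_fit, x, t))
```

This toy residual is linear in $\theta$. Gauss-Newton therefore reaches the exact solution $v \equiv m$, $\tfrac{d}{dt}\log Z_t = -m^2/2 + m^2 t$ essentially in one step. With neural networks the residual is nonlinear, and the damping term matters.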

Paper Structure (37 sections, 1 theorem, 74 equations, 5 figures, 4 tables)

Key Result

Proposition 3.1

The BSDE versions of our losses are equivalent to previously existing losses in the following sense.
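The following standard Itô computation indicates what a BSDE version of a PDE residual looks like. It is our illustration, not the statement of Proposition 3.1. It assumes $dX_t = \mu(X_t,t)\,dt + \sigma(t)\,dW_t$ and that $\log p_t$ satisfies the log-space Fokker–Planck equation $\partial_t \log p_t = -\nabla\cdot\mu - \mu\cdot\nabla\log p_t + \tfrac{\sigma^2}{2}\bigl(\Delta\log p_t + \|\nabla\log p_t\|^2\bigr)$. Itô's formula applied to $\log p_t(X_t)$ then yields an identity that can be checked along simulated paths rather than at collocation points:

```latex
\log p_T(X_T) - \log p_0(X_0)
 = \int_0^T \Bigl(-\nabla\cdot\mu + \sigma^2\,\Delta\log p_t
   + \tfrac{\sigma^2}{2}\,\|\nabla\log p_t\|^2\Bigr)(X_t,t)\,dt
 + \int_0^T \sigma(t)\,\nabla\log p_t(X_t)\cdot dW_t
```

Penalizing the mismatch in such an identity over trajectories, instead of the pointwise residual, is one standard way PDE-based losses connect to trajectory-based losses. The precise losses and the sense of equivalence are those of Proposition 3.1.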

Figures (5)

  • Figure 1: We plot three evolutions of the process $X$ defined in \eqref{eq: SDE} and \eqref{eq: ODE} between a Gaussian prior density $p_{\mathrm{prior}}$ and a Gaussian mixture target density $p_{\mathrm{target}}$. They correspond to SDEs and ODEs learned with three different loss functions. The top panel displays a stochastic evolution obtained from the loss $\mathcal{L}_\mathrm{logFP}^\mathrm{anneal}$, together with histograms of the prior and the target. The second row shows two deterministic evolutions, one obtained with $\mathcal{L}_\mathrm{logCE}^\mathrm{anneal}$ and one with $\mathcal{L}_\mathrm{logCE}$. The stochastic evolution and the left deterministic evolution follow the same annealing strategy, whereas the general loss $\mathcal{L}_\mathrm{logCE}$ leads to a different density path. We refer to Section \ref{sec: learning the evolution} for the details of the different methods.
  • Figure 2: The ground-truth marginal in the first dimension and histograms of samples from our best-performing method, which uses the loss $\mathcal{L}_\mathrm{logCE}$, on the GMM (left) and many-well (right) examples.
  • Figure 3: Trajectories and marginals of the SDE-based methods we consider for the GMM example. We also show the corresponding ODE specified in \eqref{eq:sde_to_ode}, which can be used to evaluate the log-likelihoods (see Section \ref{sec:likelihood}). We explain the suboptimal performance of $\mathcal{L}^{\mathrm{anneal}}_{\mathrm{logFP}}$ in Appendix \ref{app: annealing challenges}.
  • Figure 4: Trajectories and marginals of the ODE-based methods we consider for the GMM example. We explain the suboptimal performance of $\mathcal{L}^{\mathrm{anneal}}_{\mathrm{logCE}}$ in Appendix \ref{app: annealing challenges}.
  • Figure 5: We display two evolutions from the Gaussian prior to the two-dimensional GMM target defined in \eqref{eq: GMM definition}. One follows the prescribed geometric annealing defined in \eqref{eq: geometric annealing}; the other is learned via the general loss $\mathcal{L}_\mathrm{logCE}$ defined in \eqref{eq: def L_logCE}.
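Figure 5 contrasts a prescribed geometric annealing with a learned path. The sketch below makes two assumptions:

- the common log-linear form $\log\tilde p_t = (1-t)\log p_{\mathrm{prior}} + t\log\rho_{\mathrm{target}}$;
- a hypothetical equal-weight 2-D GMM whose means and scale are illustrative, not the paper's \eqref{eq: GMM definition}.

It evaluates the annealed unnormalized log-density and its score $\nabla\log\tilde p_t$. These are the quantities a log-space PINN residual needs along a prescribed path. The unknown normalizer drops out of the gradient because it does not depend on $x$.

```python
import numpy as np

# Hypothetical 2-D target: equal-weight GMM with isotropic components.
MEANS = np.array([[3.0, 3.0], [-3.0, 3.0], [3.0, -3.0], [-3.0, -3.0]])
STD = 1.0


def log_prior(x):
    """Unnormalized standard Gaussian log-density; x has shape (n, 2)."""
    return -0.5 * np.sum(x**2, axis=-1)


def _component_logits(x):
    """Per-component unnormalized log-densities, shape (n, K)."""
    sq = np.sum((x[:, None, :] - MEANS[None, :, :]) ** 2, axis=-1)
    return -0.5 * sq / STD**2


def log_target(x):
    """Unnormalized GMM log-density via a numerically stable log-sum-exp."""
    a = _component_logits(x)
    a_max = a.max(axis=1)
    return a_max + np.log(np.exp(a - a_max[:, None]).sum(axis=1))


def score_target(x):
    """Gradient of log_target: responsibility-weighted (mu_k - x) / STD^2."""
    a = _component_logits(x)
    w = np.exp(a - a.max(axis=1, keepdims=True))
    w /= w.sum(axis=1, keepdims=True)  # responsibilities, shape (n, K)
    diff = MEANS[None, :, :] - x[:, None, :]  # shape (n, K, 2)
    return np.einsum("nk,nkd->nd", w, diff) / STD**2


def log_annealed(x, t):
    """Geometric annealing: log p~_t = (1 - t) log p_prior + t log rho_target."""
    return (1.0 - t) * log_prior(x) + t * log_target(x)


def score_annealed(x, t):
    """Score of the annealed path; the normalizer Z_t drops out of the gradient."""
    return (1.0 - t) * (-x) + t * score_target(x)


rng = np.random.default_rng(1)
pts = rng.normal(0.0, 3.0, size=(5, 2))
print(log_annealed(pts, 0.5), score_annealed(pts, 0.5))
```

Computing the GMM score from responsibilities avoids differentiating the log-sum-exp numerically. The linear interpolation in log-space keeps the annealed score a simple convex combination of the prior and target scores.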

Theorems & Definitions (5)

  • Remark 2.1: Connections to other methods
  • Proposition 3.1: Equivalence to trajectory-based methods
  • Remark 3.2: Numerical implications of PINN- and BSDE-based losses
  • Remark 3.3: Subtrajectory-based losses
  • Proof: Proof of Proposition 3.1