
Improved sampling via learned diffusions

Lorenz Richter, Julius Berner

TL;DR

The paper addresses sampling from unnormalized densities with diffusion-based methods and presents a unifying path-space framework that treats sampling as the time reversal of controlled diffusions. By introducing divergences between path-space measures, notably the log-variance loss, it overcomes limitations of reverse KL objectives such as mode collapse and high variance. The approach connects Schrödinger bridges, diffusion-based generative modeling, and methods based on time-reversing a reference process within a single formalism, yielding practical, differentiable losses that can be optimized with gradient-based methods. Empirical results on Gaussian mixture model (GMM), Funnel, and double-well (DW) benchmarks show that the log-variance loss improves stability, mode coverage, and sample quality across DIS, PIS, and related methods, including in high-dimensional settings (e.g., a DW problem with $d=1000$). Overall, the work offers a principled, scalable route to improved diffusion-based sampling for unnormalized targets and lays groundwork for problem-tailored divergences.

Abstract

Recently, a series of papers proposed deep learning-based approaches to sample from target distributions using controlled diffusion processes, being trained only on the unnormalized target densities without access to samples. Building on previous work, we identify these approaches as special cases of a generalized Schrödinger bridge problem, seeking a stochastic evolution between a given prior distribution and the specified target. We further generalize this framework by introducing a variational formulation based on divergences between path space measures of time-reversed diffusion processes. This abstract perspective leads to practical losses that can be optimized by gradient-based algorithms and includes previous objectives as special cases. At the same time, it allows us to consider divergences other than the reverse Kullback-Leibler divergence that is known to suffer from mode collapse. In particular, we propose the so-called log-variance loss, which exhibits favorable numerical properties and leads to significantly improved performance across all considered approaches.
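As a minimal illustration of the two objectives named in the abstract, the sketch below replaces path-space measures by one-dimensional densities: a normalized model $q_\mu = \mathcal{N}(\mu, 1)$ and an unnormalized target $\rho(x) = e^{-(x-2)^2/2}$, whose normalizer $Z = \sqrt{2\pi}$ the losses never need. This toy is not the authors' implementation, and all names are hypothetical. The reverse KL estimator must draw samples from the model itself, whereas the log-variance estimator, assumed here to be the sample variance of $\log(q_\mu/\rho)$, accepts samples from any reference distribution and vanishes at the solution.

```python
import math
import random

# Normalizer of N(., 1); numerically equal to log Z of the toy target, which the losses never use.
LOG_SQRT_2PI = 0.5 * math.log(2.0 * math.pi)


def log_rho(x):
    """Unnormalized toy target rho(x) = exp(-(x - 2)^2 / 2)."""
    return -0.5 * (x - 2.0) ** 2


def log_q(x, mu):
    """Normalized (tractable) model density q_mu = N(mu, 1)."""
    return -0.5 * (x - mu) ** 2 - LOG_SQRT_2PI


def log_ratio(x, mu):
    """log(q_mu(x) / rho(x)) = log(q_mu(x) / p_target(x)) - log Z."""
    return log_q(x, mu) - log_rho(x)


def kl_loss(mu, n=2000, seed=0):
    """Reverse-KL estimator; samples must come from q_mu itself (on-policy).
    Estimates KL(q_mu || p_target) - log Z."""
    rng = random.Random(seed)
    xs = [rng.gauss(mu, 1.0) for _ in range(n)]
    return sum(log_ratio(x, mu) for x in xs) / n


def log_variance_loss(mu, ref_samples):
    """Log-variance estimator: sample variance of the log-ratio under arbitrary reference samples."""
    vals = [log_ratio(x, mu) for x in ref_samples]
    mean = sum(vals) / len(vals)
    return sum((v - mean) ** 2 for v in vals) / (len(vals) - 1)


rng = random.Random(1)
ref = [rng.gauss(0.0, 2.0) for _ in range(1000)]  # off-policy reference samples from N(0, 4)
print(kl_loss(0.0) + LOG_SQRT_2PI, kl_loss(2.0) + LOG_SQRT_2PI)  # close to 2 and 0
print(log_variance_loss(0.0, ref), log_variance_loss(2.0, ref))  # close to 16 and 0
```

Because $\log Z$ only shifts the log-ratio by a constant, it drops out of the variance entirely and merely offsets the KL estimate, which is why both objectives can be trained on unnormalized targets.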

Paper Structure

This paper contains 34 sections, 6 theorems, 99 equations, 9 figures, 6 tables, and 1 algorithm.

Key Result

Lemma 2.1

The time-reversed process $\overleftarrow{Y}^v$, given by the SDE …, satisfies $p_{\overleftarrow{Y}^v} = \overleftarrow{p}_{Y^v}$; that is, the marginal density of $\overleftarrow{Y}^v$ at time $t$ equals the marginal density of $Y^v$ at time $T - t$.
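Since the SDE of Lemma 2.1 is not reproduced here, the following sketch instead checks the generic time-reversal statement on a process whose marginals are known in closed form: an Ornstein–Uhlenbeck process $dY_t = -Y_t\,dt + \sqrt{2}\,dW_t$ started from $\mathcal{N}(m_0, v_0)$. It assumes the standard reverse-time drift $-f + \sigma^2 \nabla \log p_{T-s}$ for a forward drift $f$ and constant diffusion coefficient $\sigma$, simulates the reversed SDE with Euler–Maruyama from $p_{Y_T}$, and compares the final samples with $p_{Y_0}$. The OU setting, parameter values, and function names are illustrative assumptions, not taken from the paper.

```python
import math
import random


def ou_marginal(t, m0, v0):
    """Mean and variance at time t of dY = -Y dt + sqrt(2) dW started from N(m0, v0)."""
    decay = math.exp(-t)
    return m0 * decay, v0 * decay ** 2 + 1.0 - decay ** 2


def simulate_reverse(m0=2.0, v0=0.25, T=1.0, n_steps=200, n_particles=2000, seed=0):
    """Euler-Maruyama simulation of the reversed OU process; returns sample mean and variance at s = T."""
    rng = random.Random(seed)
    dt = T / n_steps
    m_T, v_T = ou_marginal(T, m0, v0)
    # Start the reversed process from the forward marginal p_{Y_T}.
    xs = [rng.gauss(m_T, math.sqrt(v_T)) for _ in range(n_particles)]
    for k in range(n_steps):
        t = T - k * dt  # forward time matching reverse time s = k * dt
        m_t, v_t = ou_marginal(t, m0, v0)
        for i, x in enumerate(xs):
            score = -(x - m_t) / v_t  # grad log p_t(x) for the Gaussian marginal p_t
            drift = x + 2.0 * score  # -f(x) + sigma^2 * score with f(x) = -x, sigma^2 = 2
            xs[i] = x + drift * dt + math.sqrt(2.0 * dt) * rng.gauss(0.0, 1.0)
    mean = sum(xs) / n_particles
    var = sum((x - mean) ** 2 for x in xs) / (n_particles - 1)
    return mean, var


mean, var = simulate_reverse()
print(f"mean={mean:.3f} (target 2.0), var={var:.3f} (target 0.25)")
```

With the defaults, the empirical mean and variance of the reversed samples should land close to $m_0 = 2$ and $v_0 = 0.25$, up to Monte Carlo and discretization error.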

Figures (9)

  • Figure 1: Improved convergence of our proposed log-variance loss for a double well problem; see the numerical experiments section for further details.
  • Figure 2: KDE plots of (1) samples from the ground-truth distribution, (2 & 3) PIS with the KL divergence and the log-variance loss, and (4 & 5) DIS with the KL divergence and the log-variance loss for the GMM problem (from left to right). Unlike the reverse KL divergence, which only recovers the mode of $p_\mathrm{prior}=\mathcal{N}(0,\mathrm{I})$, the log-variance loss does not suffer from mode collapse.
  • Figure 3: Marginals of the first coordinate of samples from PIS and DIS (left and right) for the DW problem with $d=5$, $m=5$, $\delta=4$. Again, one observes the mode coverage of the log-variance loss as compared to the reverse KL divergence. Similar behavior can also be observed for the other marginals (see the figure showing all marginals) and in higher-dimensional settings (see the corresponding figure for an example in $d=1000$).
  • Figure 4: Contour plots of a Gaussian mixture model $p_\mathrm{target}$ with $40$ modes, analogous to the problem proposed by Midgley et al. (2023). We plot samples of the prior $p_{\mathrm{prior}}=\mathcal{N}(0,\mathrm{I})$ (left) and of the DIS method trained with the KL-based loss, the log-variance loss, and partial trajectory optimization (from left to right); see the appendix on subtrajectory training. For all methods, we use $T=2$ to guarantee $p_{Y_T^0} \approx p_{\mathrm{prior}}$. Using the setting from the subtrajectory table, subtrajectory training can recover all modes without gradient information from the target, whereas the other methods suffer from mode collapse despite making use of $\nabla \log \rho$. Midgley et al. (2023) report mode collapse on this benchmark for several state-of-the-art methods. We remark that LV-DIS (unlike KL-DIS) recovers all modes when the prior variance is slightly increased.
  • Figure 5: We compare the standard deviations of the loss and (average) gradient estimators using either the KL-based loss or the log-variance loss. Each standard deviation is computed over $40$ simulations of the loss without updating the parameters. We show results for the DIS method on the $5$-dimensional DW target. As predicted by our theory, the log-variance loss exhibits significantly smaller standard deviations for both the loss and its gradient.
  • ...and 4 more figures
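Figure 5 measures the spread of loss and gradient estimators over 40 repeated simulations with fixed parameters. The sketch below repeats that protocol on a hypothetical one-dimensional Gaussian toy (model $\mathcal{N}(\mu, 1)$, target $\mathcal{N}(2, 1)$) rather than the DIS method on the DW target, so it only illustrates the mechanism: at the optimum $\mu = 2$, the reparameterized reverse KL gradient still fluctuates from batch to batch, while the gradient of the log-variance loss, computed on detached reference samples, is exactly zero for every batch. The batch size, the reference distribution $\mathcal{N}(0, 4)$, and all function names are assumptions.

```python
import random
import statistics

TARGET_MEAN = 2.0  # toy target p_target = N(2, 1); model q_mu = N(mu, 1)


def kl_grad(mu, batch, rng):
    """Reparameterized gradient of mean(log q_mu(x) - log p_target(x)) with x = mu + eps.
    The log q_mu term has zero total derivative in mu, leaving mean(x - TARGET_MEAN)."""
    eps = [rng.gauss(0.0, 1.0) for _ in range(batch)]
    return sum(mu + e - TARGET_MEAN for e in eps) / batch


def lv_grad(mu, batch, rng, ref_std=2.0):
    """Gradient in mu of the sample variance of log q_mu(x) - log p_target(x), x ~ N(0, ref_std^2) detached.
    The log-ratio equals (mu - TARGET_MEAN) * x + const, so the loss is (mu - TARGET_MEAN)^2 * Var_hat(x)."""
    xs = [rng.gauss(0.0, ref_std) for _ in range(batch)]
    return 2.0 * (mu - TARGET_MEAN) * statistics.variance(xs)


def grad_std(grad_fn, mu, n_repeats=40, batch=256, seed=0):
    """Standard deviation of a gradient estimator over repeated batches with fixed parameters."""
    rng = random.Random(seed)
    return statistics.stdev([grad_fn(mu, batch, rng) for _ in range(n_repeats)])


for mu in (0.0, 2.0):
    print(mu, grad_std(kl_grad, mu), grad_std(lv_grad, mu))
```

Away from the optimum, the comparison in this toy can go either way (here the log-variance gradient is noisier at $\mu = 0$); the sketch only demonstrates the vanishing variance at the solution.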

Theorems & Definitions (17)

  • Lemma 2.1: Time-reversed SDEs
  • proof
  • Proposition 2.3: Likelihood of path measures
  • proof
  • Definition 2.4: Log-variance divergence
  • Proposition 2.5: Robustness at the solution
  • proof
  • Lemma 3.1: Likelihood w.r.t. reference process
  • proof
  • Proof of the proposition on log-likelihoods of path measures
  • ...and 7 more
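Proposition 2.3 and Lemma 3.1 concern likelihoods of path measures, and Definition 2.4 builds a divergence from them. The sketch below assumes the definition $D^{\mathbb{W}}_{\mathrm{LV}}(\mathbb{P}, \mathbb{Q}) = \mathrm{Var}_{\mathbb{W}}\big[\log \tfrac{d\mathbb{P}}{d\mathbb{Q}}\big]$ and approximates the log-likelihood ratio of two diffusions with the same diffusion coefficient by the ratio of their Euler–Maruyama transition densities. This discretization is a generic stand-in rather than the formula of Proposition 2.3, and the drifts, step count, and helper names are hypothetical.

```python
import math
import random


def em_log_ratio(path, drift_p, drift_q, sigma, dt):
    """Discretized log dP/dQ of one path for two Euler-Maruyama chains sharing sigma.
    Each step adds log N(y; x + b_P(x) dt, sigma^2 dt) - log N(y; x + b_Q(x) dt, sigma^2 dt);
    the Gaussian normalizing constants cancel."""
    var = sigma ** 2 * dt
    total = 0.0
    for x, y in zip(path[:-1], path[1:]):
        r_p = y - x - drift_p(x) * dt
        r_q = y - x - drift_q(x) * dt
        total += (r_q * r_q - r_p * r_p) / (2.0 * var)
    return total


def sample_path(drift, sigma, dt, n_steps, x0, rng):
    """Euler-Maruyama path of dX = drift(X) dt + sigma dW started at x0."""
    path = [x0]
    for _ in range(n_steps):
        x = path[-1]
        path.append(x + drift(x) * dt + sigma * math.sqrt(dt) * rng.gauss(0.0, 1.0))
    return path


def kl_and_log_variance(drift_p, drift_q, sigma=1.0, T=1.0, n_steps=50, n_paths=500, x0=1.0, seed=0):
    """Estimate E_P[log dP/dQ] (= KL(P || Q)) and Var_P[log dP/dQ] from paths of P (reference W = P)."""
    rng = random.Random(seed)
    dt = T / n_steps
    vals = [
        em_log_ratio(sample_path(drift_p, sigma, dt, n_steps, x0, rng), drift_p, drift_q, sigma, dt)
        for _ in range(n_paths)
    ]
    mean = sum(vals) / n_paths
    var = sum((v - mean) ** 2 for v in vals) / (n_paths - 1)
    return mean, var


def ou_drift(x):
    """Hypothetical drift of P (Ornstein-Uhlenbeck)."""
    return -x


def bm_drift(x):
    """Hypothetical drift of Q (Brownian motion)."""
    return 0.0


print(kl_and_log_variance(ou_drift, bm_drift))  # both entries positive
print(kl_and_log_variance(ou_drift, ou_drift))  # (0.0, 0.0): identical path measures
```

Taking $\mathbb{W} = \mathbb{P}$ as here, the mean of the log-ratios estimates $\mathrm{KL}(\mathbb{P}\,\|\,\mathbb{Q})$; the log-variance estimate is exactly zero when both drifts coincide, since every path then has log-ratio zero.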