
Learned Reference-based Diffusion Sampling for multi-modal distributions

Maxence Noble, Louis Grenioux, Marylou Gabrié, Alain Oliviero Durmus

TL;DR

This work tackles the challenge of sampling from multimodal target densities when only unnormalized densities are available. It introduces Learned Reference-based Diffusion Sampling (LRDS), a two-step approach that first learns a reference diffusion model from high-density region samples and then trains a diffusion-based sampler guided by this reference. LRDS comes in two practical flavors, GMM-LRDS and EBM-LRDS, to accommodate a wide range of target geometries, and outperforms existing diffusion-based samplers on challenging multimodal distributions. The framework connects to Schrödinger bridge and Doob transform perspectives and shows promise for extending diffusion sampling to non-Euclidean spaces with learned references.

Abstract

Over the past few years, several approaches utilizing score-based diffusion have been proposed to sample from probability distributions, that is, without access to exact samples and relying solely on evaluations of unnormalized densities. The resulting samplers approximate the time-reversal of a noising diffusion process that bridges the target distribution to an easy-to-sample base distribution. In practice, the performance of these methods depends heavily on key hyperparameters whose accurate tuning requires ground truth samples. Our work aims to highlight and address this fundamental issue, focusing in particular on multi-modal distributions, which pose significant challenges for existing sampling methods. Building on existing approaches, we introduce the Learned Reference-based Diffusion Sampler (LRDS), a methodology specifically designed to leverage prior knowledge on the location of the target modes in order to bypass the obstacle of hyperparameter tuning. LRDS proceeds in two steps: (i) learning a reference diffusion model, tailored for multimodality, on samples located in high-density regions of the space, and (ii) using this reference model to guide the training of a diffusion-based sampler. We experimentally demonstrate that LRDS exploits prior knowledge on the target distribution better than competing algorithms on a variety of challenging distributions.
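
The reference in GMM-LRDS is a Gaussian mixture. One reason such a reference is convenient is that a Gaussian mixture pushed through an Ornstein–Uhlenbeck (VP-type) noising process remains a Gaussian mixture, so the reference score has a closed form at every time $t$. This is stated here as an assumption about the setup, not as the paper's exact parametrization. The minimal sketch below assumes the forward SDE $\mathrm{d}X_t = -X_t\,\mathrm{d}t + \sqrt{2}\,\mathrm{d}B_t$, under which $X_t \mid X_0 \sim \mathcal{N}(e^{-t}X_0, (1-e^{-2t})I)$. The helpers `noised_gmm_params` and `gmm_score` are hypothetical names.

```python
import numpy as np


def noised_gmm_params(weights, means, covs, t):
    """Mixture parameters of X_t when X_0 ~ GMM and dX = -X dt + sqrt(2) dB (assumed VP/OU noising)."""
    a = np.exp(-t)                                   # mean contraction factor e^{-t}
    d = means.shape[1]
    means_t = a * means                              # component means shrink toward 0
    covs_t = a**2 * covs + (1.0 - a**2) * np.eye(d)  # covariances interpolate toward I
    return weights, means_t, covs_t                  # mixture weights are unchanged


def gmm_score(x, weights, means, covs):
    """Score grad_x log p(x) of a Gaussian mixture with full covariances, for a batch x of shape (n, d)."""
    x = np.atleast_2d(np.asarray(x, dtype=float))
    log_terms, grads = [], []
    for w, m, c in zip(weights, means, covs):
        prec = np.linalg.inv(c)
        diff = x - m                                 # (n, d)
        _, logdet = np.linalg.slogdet(c)
        quad = np.einsum("ni,ij,nj->n", diff, prec, diff)
        # component log-density up to the shared -d/2 log(2 pi) constant
        log_terms.append(np.log(w) - 0.5 * (logdet + quad))
        grads.append(-diff @ prec)                   # gradient of the component log-density
    log_terms = np.stack(log_terms, axis=1)          # (n, K)
    resp = np.exp(log_terms - log_terms.max(axis=1, keepdims=True))
    resp /= resp.sum(axis=1, keepdims=True)          # posterior responsibilities per component
    grads = np.stack(grads, axis=1)                  # (n, K, d)
    return np.einsum("nk,nkd->nd", resp, grads)      # responsibility-weighted component scores


if __name__ == "__main__":
    weights = np.array([2 / 3, 1 / 3])
    means = np.array([[-2.0, 0.0], [2.0, 0.0]])
    covs = np.stack([np.eye(2), np.eye(2)])
    for t in (0.0, 1.0, 5.0):
        params = noised_gmm_params(weights, means, covs, t)
        print(t, gmm_score([[0.5, -0.5]], *params))
```

At $t=0$ this returns the score of the fitted mixture itself. For large $t$ it approaches $-x$, the score of $\mathcal{N}(0, I)$, which plays the role of $\pi^{\text{base}}$ in this assumed setup.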

Paper Structure

This paper contains 113 sections, 24 theorems, 123 equations, 28 figures, 12 tables.

Key Result

Proposition 1

Assume that $\mathbb{P}^\star_T=\mathbb{P}^{\text{ref}}_T=\pi^{\text{base}}$ and there exists ${\theta^{\star}} \in \Theta$ such that $g^{\theta^{\star}}_t=g_t$. Then, the loss defined in (eq:obj-cont) attains its optimum at ${\theta^{\star}}$ and, setting $\varrho= \log (\gamma^{\mathrm{ref}}/\

Figures (28)

  • Figure 1: Illustration of the decisive role of the reference distribution. Here, we target a $16$-dimensional Gaussian mixture with two modes, with respective weights $w=2/3$ and $1-w=1/3$, and display the estimation error of $w$ for different methods. (Left): Results for LV-PIS and LV-DDS when varying their hyperparameter $\sigma$, which directly determines $\pi^{\text{ref}}$ (see the DDS/PIS comparison table, table:DDS-PIS-comparison). The green dotted line marks the optimal variance for a Gaussian approximation of $\pi$; related computations are given in the appendix (app:gaussian-fitting). (Right): Results for RDS in the PBM and VP settings when $\pi^{\text{ref}}$ is a Gaussian mixture with the same modes as $\pi$ but with $w$ replaced by $w_{\text{ref}}$. Details on the design of this experiment are given in the appendix (app:details_target).
  • Figure 2: Comparison between G-LRDS and GMM-LRDS. Here, the target distribution (top left) is the same $16$-dimensional Gaussian mixture as in Figure 1; see the appendix (app:details_target) for more details. For illustration purposes, projections onto the first two coordinates are shown. In each cell, 'Mode weight' refers to the effective weight of the left mode. Reference samples (bottom left) are obtained by running the MALA sampler initialized in both target modes: each color (orange/purple) depicts one MALA Markov chain, and neither chain mixes between the modes. Running G-LRDS with an ML-estimated Gaussian reference (top middle) leads to mode collapse (bottom middle). Conversely, GMM-LRDS with an EM-estimated Gaussian mixture reference (top right) recovers the target distribution and the true mode proportions (bottom right).
  • Figure 3: Comparison between GMM-LRDS and EBM-LRDS in a multi-modal setting. Here, we target the $2$-dimensional Rings distribution, which has 3 unbalanced modes represented by the rings. (Left): Target density (top) and exact samples (bottom). (Middle): $16$-component GMM reference distribution (top) and resulting GMM-LRDS samples (bottom). (Right): EBM reference distribution (top) and resulting EBM-LRDS samples (bottom).
  • Figure 4: Estimation of the relative weight of $\phi^4$ modes with increasing $h$, averaged over $16$ sampling runs.
  • Figure 5: Samples obtained for Rings distribution. Reasonable results could not be obtained with PDDS due to numerical issues.
  • ...and 23 more figures
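
Figures 1 and 4 report how well each method recovers the relative weight of the modes. One simple way to estimate such a weight from sampler output is to assign each sample to its nearest mode mean and count. This sketch assumes the two mode locations are known and well separated. The nearest-mean rule, the helper names, and the mode locations below are hypothetical and are not taken from the paper's appendix.

```python
import numpy as np


def estimate_mode_weight(samples, mean_a, mean_b):
    """Fraction of samples closer to mean_a than to mean_b (nearest-mean assignment, assumed estimator)."""
    dist_a = np.linalg.norm(samples - mean_a, axis=1)
    dist_b = np.linalg.norm(samples - mean_b, axis=1)
    return float(np.mean(dist_a < dist_b))


def sample_two_mode_mixture(rng, n, w, mean_a, mean_b, std=1.0):
    """Exact samples from w N(mean_a, std^2 I) + (1 - w) N(mean_b, std^2 I)."""
    pick_a = rng.random(n) < w                        # Bernoulli(w) component labels
    centers = np.where(pick_a[:, None], mean_a, mean_b)
    return centers + std * rng.standard_normal((n, mean_a.shape[0]))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, w = 16, 2 / 3
    mean_a = np.full(d, -1.0)                         # hypothetical mode locations
    mean_b = np.full(d, 1.0)
    x = sample_two_mode_mixture(rng, 10_000, w, mean_a, mean_b)
    w_hat = estimate_mode_weight(x, mean_a, mean_b)
    print(f"w_hat = {w_hat:.3f}, abs. error = {abs(w_hat - w):.3f}")
```

With exact samples, the error reflects only Monte Carlo noise, of order $\sqrt{w(1-w)/n}$. Markedly larger errors therefore indicate that a sampler misweights the modes.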
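
The reference samples in Figure 2 come from MALA chains initialized in each target mode. Below is a minimal NumPy sketch of MALA (Metropolis-adjusted Langevin algorithm) on a two-mode isotropic Gaussian mixture. The unit covariances, mode means, step size, and the helper names `log_mixture` and `mala_chain` are assumptions for illustration, not the paper's configuration.

```python
import numpy as np


def log_mixture(x, w, mean_a, mean_b):
    """Log-density (up to a constant) of w N(mean_a, I) + (1 - w) N(mean_b, I), and its gradient."""
    la = np.log(w) - 0.5 * np.sum((x - mean_a) ** 2)
    lb = np.log(1 - w) - 0.5 * np.sum((x - mean_b) ** 2)
    m = max(la, lb)
    logp = m + np.log(np.exp(la - m) + np.exp(lb - m))  # stable log-sum-exp
    ra = np.exp(la - logp)                               # posterior weight of mode a
    grad = -ra * (x - mean_a) - (1 - ra) * (x - mean_b)
    return logp, grad


def mala_chain(x0, log_target, step, n_steps, rng):
    """MALA started at x0; returns the chain (n_steps, d) and the acceptance rate."""
    x = np.array(x0, dtype=float)
    logp, grad = log_target(x)
    samples, accepted = [], 0
    for _ in range(n_steps):
        mean_fwd = x + step * grad                       # Langevin drift from x
        y = mean_fwd + np.sqrt(2 * step) * rng.standard_normal(x.shape)
        logp_y, grad_y = log_target(y)
        mean_bwd = y + step * grad_y                     # Langevin drift from y
        # log q(x | y) - log q(y | x) for the Gaussian proposal with variance 2 * step
        log_q_ratio = (np.sum((y - mean_fwd) ** 2) - np.sum((x - mean_bwd) ** 2)) / (4 * step)
        if np.log(rng.random()) < logp_y - logp + log_q_ratio:  # Metropolis-Hastings test
            x, logp, grad = y, logp_y, grad_y
            accepted += 1
        samples.append(x.copy())
    return np.array(samples), accepted / n_steps


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, w = 16, 2 / 3
    mean_a, mean_b = np.full(d, -1.0), np.full(d, 1.0)  # hypothetical mode locations

    def target(x):
        return log_mixture(x, w, mean_a, mean_b)

    # one chain per mode, mirroring the reference-sample panel of Figure 2
    for start in (mean_a, mean_b):
        chain, rate = mala_chain(start, target, step=0.1, n_steps=2000, rng=rng)
        print(f"acceptance = {rate:.2f}, mean of first coord = {chain[:, 0].mean():.2f}")
```

Each chain stays in the mode where it starts, so pooling equal-length chains yields roughly balanced modes whatever the true $w$. A reference fitted to such samples therefore carries little information about the mode weights. This is consistent with Figure 2, where GMM-LRDS still recovers the true proportions.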

Theorems & Definitions (41)

  • Proposition 1
  • Lemma 1
  • proof
  • Lemma 2
  • proof
  • Lemma 3
  • proof
  • Lemma 4
  • proof
  • Lemma 5
  • ...and 31 more