End-To-End Learning of Gaussian Mixture Priors for Diffusion Sampler

Denis Blessing, Xiaogang Jia, Gerhard Neumann

TL;DR

This work tackles diffusion-based sampling from unnormalized densities, addressing unknown target support, discretization errors, and mode collapse through end-to-end learned Gaussian mixture priors (GMPs). It develops a principled VI framework with forward and backward diffusion processes and an extended ELBO, enabling gradient-based learning of priors and time horizons. A key contribution is iterative model refinement (IMR), which progressively adds Gaussian components to the prior to improve exploration and multi-modal coverage, supported by a practical initialization heuristic. Empirical results across real-world Bayesian tasks and synthetic benchmarks show that GMPs consistently improve ELBO, log-normalizer estimation, and transport/mode-coverage metrics, often surpassing state-of-the-art methods, with IMR enabling robust multi-modal sampling in high dimensions. Overall, GMPs provide a flexible, scalable path to more accurate and expressive diffusion samplers for complex targets.

Abstract

Diffusion models optimized via variational inference (VI) have emerged as a promising tool for generating samples from unnormalized target densities. These models create samples by simulating a stochastic differential equation, starting from a simple, tractable prior, typically a Gaussian distribution. However, when the support of this prior differs greatly from that of the target distribution, diffusion models often struggle to explore effectively or suffer from large discretization errors. Moreover, learning the prior distribution can lead to mode collapse, exacerbated by the mode-seeking nature of the reverse Kullback-Leibler divergence commonly used in VI. To address these challenges, we propose end-to-end learnable Gaussian mixture priors (GMPs). GMPs offer improved control over exploration, adaptability to target support, and increased expressiveness to counteract mode collapse. We further leverage the structure of mixture models by proposing a strategy to iteratively refine the model by adding mixture components during training. Our experimental results demonstrate that GMPs yield significant performance improvements across a diverse range of real-world and synthetic benchmark problems, without requiring additional target evaluations.
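
To make the notion of a Gaussian mixture prior concrete, the following is a minimal NumPy sketch of a diagonal-covariance mixture with a log-density, a sampler, and a helper that appends one component, in the spirit of adding mixture components during training. The function names, the diagonal-covariance parameterization, and the `new_weight` rule are illustrative assumptions; they are not taken from the paper, and the paper's training objective and initialization heuristic are not reproduced here.

```python
import numpy as np


def gmm_log_prob(x, log_weights, means, log_stds):
    """Log-density of a diagonal-covariance Gaussian mixture.

    x: (n, d) points; log_weights: (K,) unnormalized log-weights;
    means, log_stds: (K, d) per-component parameters.
    """
    # Normalize the mixture weights in log space.
    log_w = log_weights - np.logaddexp.reduce(log_weights)
    # Standardized residuals for every (point, component) pair: (n, K, d).
    diff = (x[:, None, :] - means[None, :, :]) / np.exp(log_stds)[None, :, :]
    d = x.shape[-1]
    # Per-component Gaussian log-densities, shape (n, K).
    comp = (-0.5 * np.sum(diff**2, axis=-1)
            - np.sum(log_stds, axis=-1)
            - 0.5 * d * np.log(2.0 * np.pi))
    # Log-sum-exp over components gives the mixture log-density, shape (n,).
    return np.logaddexp.reduce(comp + log_w[None, :], axis=1)


def gmm_sample(rng, n, log_weights, means, log_stds):
    """Draw n samples; also return the index of the generating component."""
    w = np.exp(log_weights - np.logaddexp.reduce(log_weights))
    idx = rng.choice(len(w), size=n, p=w)
    eps = rng.standard_normal((n, means.shape[1]))
    return means[idx] + np.exp(log_stds[idx]) * eps, idx


def add_component(log_weights, means, log_stds, new_mean, new_log_std,
                  new_weight=0.1):
    """Append one component (hypothetical rule: give it weight `new_weight`
    and rescale the existing weights so that they still sum to one)."""
    w = np.exp(log_weights - np.logaddexp.reduce(log_weights))
    w = np.append((1.0 - new_weight) * w, new_weight)
    return (np.log(w),
            np.vstack([means, new_mean]),
            np.vstack([log_stds, new_log_std]))


# Usage: a two-component prior in two dimensions, then a third component.
rng = np.random.default_rng(0)
log_w = np.zeros(2)
means = np.array([[-2.0, 0.0], [2.0, 0.0]])
log_stds = np.zeros((2, 2))
x, _ = gmm_sample(rng, 5, log_w, means, log_stds)
print(gmm_log_prob(x, log_w, means, log_stds).shape)
log_w, means, log_stds = add_component(
    log_w, means, log_stds, new_mean=np.array([0.0, 3.0]),
    new_log_std=np.zeros(2))
print(means.shape)
```

Keeping the weights as unnormalized logits and the scales as log-standard-deviations leaves every parameter unconstrained, which is one common way to make such a prior amenable to gradient-based learning.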

Paper Structure

This paper contains 43 sections, 1 theorem, 39 equations, 12 figures, 11 tables, and 1 algorithm.

Key Result

Proposition 1

Let (eq:X_bwd) be an (uncontrolled) stochastic process as defined in (eq:X_fwd_bwd) with $v^{\gamma} = 0$, starting from $p_T=\pi$. For a time-independent drift, i.e., $f(x, t) = f(x)$, the stationary distribution $p^{\text{st}}(x)$, for which $\frac{\partial p_t(x_t)}{\partial t} = 0$ holds, is given with normalization constant $Z^{\text{st}}$.
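
As a generic illustration of what a stationary distribution means for a process with time-independent drift, the sketch below simulates an Ornstein-Uhlenbeck process $dX_t = -\theta X_t\,dt + \sigma\,dW_t$ with Euler-Maruyama and compares the empirical variance after a long horizon with the known stationary variance $\sigma^2 / (2\theta)$. The choice of drift, the parameter values, and the function name `simulate_ou` are assumptions made for this example; they do not come from the paper and do not reproduce the process of the proposition.

```python
import numpy as np


def simulate_ou(rng, n_particles=20000, n_steps=1000, dt=0.01,
                theta=1.0, sigma=1.0, x0=3.0):
    """Euler-Maruyama simulation of dX = -theta * X dt + sigma dW.

    All particles start at x0, far from the stationary mean 0, so that
    convergence to the stationary law is visible.
    """
    x = np.full(n_particles, x0)
    for _ in range(n_steps):
        noise = rng.standard_normal(n_particles)
        x = x - theta * x * dt + sigma * np.sqrt(dt) * noise
    return x


x_T = simulate_ou(np.random.default_rng(0))
print("empirical mean:", x_T.mean())
print("empirical variance:", x_T.var())
print("stationary variance sigma^2 / (2 theta):", 1.0**2 / (2.0 * 1.0))
```

Because the drift does not depend on time, the marginal law stops changing once it reaches this Gaussian, which is the condition $\frac{\partial p_t}{\partial t} = 0$ in the statement above. The small remaining gap between the empirical and the analytical variance comes from the finite step size and the finite number of particles.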

Figures (12)

  • Figure 1: Illustration of challenges (C1-C3) associated with diffusion-based sampling methods and how learned Gaussian mixture priors address them (bottom right). Here, $\pi$ denotes the target distribution.
  • Figure 2: Diffusion-Based Sampling: The goal is to align two parameterized Markov processes $\vec{p}^{\ \theta}$ and $\overleftarrow{p}^{\ \gamma}$. The former starts at the prior $p_0$ and runs forward in time while the latter starts at the target $\pi$ and runs backward.
  • Figure 3: Left side: Results for the Funnel target, averaged across four seeds. Evaluation criteria include the evidence lower bound (ELBO), importance-weighted errors for estimating the log-normalizing constant $\Delta \log \mathcal{Z}$, effective sample size (ESS), and Sinkhorn distance $\mathcal{W}^{\ \gamma}_{2}$. The best overall results are highlighted in bold, with category-specific best results underlined. Arrows ($\uparrow$, $\downarrow$) indicate whether higher or lower values are preferable, respectively. Blue and green shading indicate that the method uses learned Gaussian (GP) and Gaussian mixture priors (GMP), respectively. Red shading indicates competing state-of-the-art methods. Note that ESS cannot be computed due to the use of resampling schemes. Right side: Visualization of the first two dimensions of the Funnel target. Colored ellipses and circles denote standard deviations and means of the Gaussian components, respectively. Red dots illustrate samples of the model.
  • Figure 4: Left side: Results for the Fashion target, averaged across four seeds. Evaluation criteria include the evidence lower bound (ELBO), importance-weighted errors for estimating the log-normalizing constant $\Delta \log \mathcal{Z}$, and Sinkhorn distance $\mathcal{W}^{\ \gamma}_{2}$. The best overall results are highlighted in bold, with category-specific best results underlined. Arrows ($\uparrow$, $\downarrow$) indicate whether higher or lower values are preferable, respectively. Blue and green shading indicate that the method uses learned Gaussian (GP) and Gaussian mixture priors (GMP), respectively. Orange shading indicates that the method uses iterative model refinement (IMR). Red shading indicates competing state-of-the-art methods. Right side: Visualization of the $d = 28 \times 28 = 784$ dimensional Fashion samples. The top left corner visualizes samples from the target distribution. Colored frames indicate samples from different components of the Gaussian mixture.
  • Figure 5: Effective sample size (ESS) of DIS-GMP for various real-world benchmark problems, averaged across four seeds. Here, $N$ denotes the number of discretization steps and $K$ the number of components in the Gaussian mixture.
  • ...and 7 more figures

Theorems & Definitions (1)

  • Proposition 1