End-To-End Learning of Gaussian Mixture Priors for Diffusion Sampler
Denis Blessing, Xiaogang Jia, Gerhard Neumann
TL;DR
This work addresses three problems that arise when diffusion models are used to sample from unnormalized densities: unknown target support, discretization errors, and mode collapse. Its remedy is end-to-end learned Gaussian mixture priors (GMPs). The paper develops a principled variational inference (VI) framework with forward and backward diffusion processes and an extended ELBO, which enables gradient-based learning of the prior and the time horizon. A key contribution is iterative model refinement (IMR), which progressively adds Gaussian components to the prior to improve exploration and multi-modal coverage, supported by a practical initialization heuristic. Empirical results on real-world Bayesian tasks and synthetic benchmarks show that GMPs consistently improve the ELBO, log-normalizer estimates, and transport and mode-coverage metrics, often surpassing state-of-the-art methods, and that IMR enables robust multi-modal sampling in high dimensions. Overall, GMPs provide a flexible, scalable path to more accurate and expressive diffusion samplers for complex targets.
Abstract
Diffusion models optimized via variational inference (VI) have emerged as a promising tool for generating samples from unnormalized target densities. These models create samples by simulating a stochastic differential equation, starting from a simple, tractable prior, typically a Gaussian distribution. However, when the support of this prior differs greatly from that of the target distribution, diffusion models often struggle to explore effectively or suffer from large discretization errors. Moreover, learning the prior distribution can lead to mode collapse, which is exacerbated by the mode-seeking nature of the reverse Kullback-Leibler divergence commonly used in VI. To address these challenges, we propose end-to-end learnable Gaussian mixture priors (GMPs). GMPs offer improved control over exploration, adaptability to the target support, and increased expressiveness to counteract mode collapse. We further leverage the structure of mixture models by proposing a strategy to iteratively refine the model by adding mixture components during training. Our experimental results demonstrate significant performance improvements across a diverse range of real-world and synthetic benchmark problems when using GMPs, without requiring additional target evaluations.
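To make the ingredients named in the abstract concrete, the following is a minimal NumPy sketch of a Gaussian mixture prior, of adding a component to it, and of simulating a stochastic differential equation from prior samples. It is an illustration under stated assumptions, not the authors' implementation: the function names, the diagonal covariance, the fixed weight given to a new component, the placeholder drift argument, and the Euler-Maruyama discretization are all hypothetical choices made here. In the paper the prior parameters and the drift are learned end-to-end by gradient descent, which this sketch does not attempt.

```python
import numpy as np

def gmp_sample(rng, weights, means, stds, n):
    """Draw n samples from a diagonal Gaussian mixture.

    weights: (K,), means: (K, d), stds: (K, d).
    """
    idx = rng.choice(len(weights), size=n, p=weights)  # pick a component per sample
    noise = rng.standard_normal((n, means.shape[1]))
    return means[idx] + stds[idx] * noise

def gmp_log_prob(x, weights, means, stds):
    """Log-density of the mixture at points x of shape (n, d)."""
    diff = (x[:, None, :] - means[None, :, :]) / stds[None, :, :]  # (n, K, d)
    log_comp = -0.5 * np.sum(
        diff**2 + np.log(2 * np.pi) + 2 * np.log(stds)[None], axis=-1
    )  # (n, K) per-component Gaussian log-densities
    a = np.log(weights)[None, :] + log_comp
    m = a.max(axis=1, keepdims=True)  # numerically stable log-sum-exp
    return (m + np.log(np.exp(a - m).sum(axis=1, keepdims=True)))[:, 0]

def add_component(weights, means, stds, new_mean, new_std, new_weight=0.1):
    """Append one component and renormalize the weights.

    Assumption: the new component receives a fixed weight; the paper's
    initialization heuristic is not reproduced here.
    """
    weights = np.append((1.0 - new_weight) * weights, new_weight)
    means = np.vstack([means, new_mean])
    stds = np.vstack([stds, new_std])
    return weights, means, stds

def euler_maruyama(rng, x0, drift, sigma, T=1.0, steps=100):
    """Simulate dx = drift(x, t) dt + sigma dW from x0 with a fixed step size.

    drift is a placeholder callable; in a diffusion sampler it would be learned.
    """
    dt = T / steps
    x = x0.copy()
    for k in range(steps):
        t = k * dt
        x = x + drift(x, t) * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)
    return x
```

A typical use would draw starting points with gmp_sample, pass them to euler_maruyama together with a drift function, and call add_component between training stages to widen the prior's coverage. The larger the distance the simulated path has to travel from the prior to the target, the more the fixed step size matters, which is one way to read the abstract's point about discretization errors.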
