Proximal Diffusion Neural Sampler
Wei Guo, Jaemoo Choi, Yuchen Zhu, Molei Tao, Yongxin Chen
TL;DR
PDNS addresses the challenge of sampling from multimodal unnormalized targets by reframing diffusion-based neural samplers as stochastic optimal control on path measures and introducing proximal iterations to stabilize learning. The framework unifies continuous and discrete samplers and instantiates each proximal step with a proximal weighted denoising cross-entropy (WDCE) objective, a practical choice that mitigates mode collapse while preserving mode coverage. Theoretically, the proximal updates geometrically interpolate between a reference distribution and the target and converge to the optimum; adaptive schedulers balance convergence speed against exploration. Experiments on continuous benchmarks (synthetic energies and particle systems) and discrete targets (Ising and Potts models, max-cut) show improved stability, mode coverage, and sampling quality over strong baselines. Overall, PDNS offers a robust, scalable approach to diffusion-based neural sampling with practical benefits for high-dimensional, multimodal distributions.
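The geometric interpolation claim can be illustrated with a short derivation. This is a sketch under an assumption, not the paper's exact statement: each proximal step is taken to minimize the reverse KL divergence to the target path measure P* plus a KL proximity term to the previous iterate, with step size η_k. Minimizing over P and unrolling the recursion from a reference P_0 gives:

```latex
\mathbb{P}_{k+1}
  = \arg\min_{\mathbb{P}} \; \mathrm{KL}(\mathbb{P}\,\|\,\mathbb{P}^{\ast})
    + \frac{1}{\eta_k}\,\mathrm{KL}(\mathbb{P}\,\|\,\mathbb{P}_k)
  \;\;\Longrightarrow\;\;
\mathbb{P}_{k+1} \propto (\mathbb{P}^{\ast})^{\frac{\eta_k}{1+\eta_k}}\,
                         \mathbb{P}_k^{\frac{1}{1+\eta_k}},
\qquad
\mathbb{P}_K \propto \mathbb{P}_0^{\lambda_K}\,(\mathbb{P}^{\ast})^{1-\lambda_K},
\quad
\lambda_K = \prod_{k=0}^{K-1} \frac{1}{1+\eta_k}.
```

For a constant step size η > 0, λ_K = (1+η)^{-K} tends to 0, so the iterates approach P*; larger steps move toward the target faster, while smaller steps keep each iterate closer to the broad reference, which is the trade-off a step-size scheduler has to manage.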
Abstract
The task of learning a diffusion-based neural sampler for drawing samples from an unnormalized target distribution can be viewed as a stochastic optimal control problem on path measures. However, training such samplers can be challenging when the target distribution is multimodal with significant barriers separating the modes, which can lead to mode collapse. We propose a framework named Proximal Diffusion Neural Sampler (PDNS) that addresses these challenges by solving the stochastic optimal control problem with the proximal point method on the space of path measures. PDNS decomposes the learning process into a series of simpler subproblems whose solutions form a progressively refined path toward the desired distribution, and this staged procedure promotes thorough exploration across modes. For a practical and efficient realization, we instantiate each proximal step with a proximal weighted denoising cross-entropy (WDCE) objective. We demonstrate the effectiveness and robustness of PDNS through extensive experiments on both continuous and discrete sampling tasks, including challenging scenarios in molecular dynamics and statistical physics.
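To make the staged procedure concrete, here is a minimal sketch on a toy finite state space rather than on path measures, with no neural network involved. It assumes each proximal step minimizes KL(p ‖ π) + (1/η) KL(p ‖ p_k), whose exact solution on a finite grid is the geometric mixture p ∝ π^(η/(1+η)) · p_k^(1/(1+η)); PDNS instead approximates each step with a learned sampler trained via proximal WDCE, which this sketch does not model. The helper names `kl` and `proximal_step`, the bimodal toy target, and the constant step size are all hypothetical choices made for illustration.

```python
import numpy as np


def kl(p, q):
    """KL(p || q) for strictly positive probability vectors on the same finite grid."""
    return float(np.sum(p * (np.log(p) - np.log(q))))


def proximal_step(p_k, log_target, eta):
    """Exact minimizer of KL(p || pi) + (1/eta) * KL(p || p_k) over the simplex,
    where pi is proportional to exp(log_target).

    Closed form: p is proportional to pi**(eta / (1 + eta)) * p_k**(1 / (1 + eta)).
    """
    logits = (eta * log_target + np.log(p_k)) / (1.0 + eta)
    logits -= logits.max()  # shift before exponentiating to avoid overflow/underflow
    p = np.exp(logits)
    return p / p.sum()


# Hypothetical toy target: two narrow, well-separated modes on a 1D grid.
x = np.linspace(-6.0, 6.0, 241)
log_target = np.logaddexp(-0.5 * ((x + 4.0) / 0.5) ** 2,
                          -0.5 * ((x - 4.0) / 0.5) ** 2)
target = np.exp(log_target - log_target.max())
target /= target.sum()

etas = [0.5] * 20                   # assumed constant step sizes, not the paper's scheduler
p = np.full(x.size, 1.0 / x.size)  # reference distribution: uniform over the grid
history = [kl(p, target)]
for eta in etas:
    p = proximal_step(p, log_target, eta)
    history.append(kl(p, target))

# Unrolled form: p_K is proportional to p_0**lam * pi**(1 - lam), lam = prod 1/(1 + eta_k).
# With a uniform p_0 this is simply a tempered target pi**(1 - lam).
lam = float(np.prod(1.0 / (1.0 + np.array(etas))))
closed_form = np.exp((1.0 - lam) * (log_target - log_target.max()))
closed_form /= closed_form.sum()

print(f"KL(p_0 || pi) = {history[0]:.3f}, KL(p_K || pi) = {history[-1]:.2e}")
```

Because every exact iterate in this toy is a tempered copy of the target, the KL divergence to the target decreases monotonically and both modes keep their mass throughout. The sketch only illustrates the idealized path the proximal iterations follow; it says nothing about how well a neural sampler approximates each step.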
