Proximal Diffusion Neural Sampler

Wei Guo, Jaemoo Choi, Yuchen Zhu, Molei Tao, Yongxin Chen

TL;DR

PDNS addresses the challenge of sampling from multimodal unnormalized targets by reframing diffusion-based neural samplers as stochastic optimal control on path measures and introducing proximal iterations to stabilize learning. The framework unifies continuous and discrete samplers, with proximal WDCE as a practical objective to mitigate mode collapse while preserving mode coverage. Theoretical results show that proximal updates geometrically interpolate between a reference and the target, ensuring convergence to the optimum; adaptive schedulers balance convergence speed and exploration. Empirical results across continuous benchmarks (synthetic energies and particle systems) and discrete targets (Ising/Potts, max-cut) demonstrate improved stability, coverage, and sampling quality over strong baselines. Overall, PDNS offers a robust, scalable approach to diffusion-based neural sampling with practical benefits for high-dimensional, multimodal distributions.

Abstract

The task of learning a diffusion-based neural sampler for drawing samples from an unnormalized target distribution can be viewed as a stochastic optimal control problem on path measures. However, training neural samplers can be challenging when the target distribution is multimodal with significant barriers separating the modes, potentially leading to mode collapse. We propose a framework named Proximal Diffusion Neural Sampler (PDNS) that addresses these challenges by tackling the stochastic optimal control problem via the proximal point method on the space of path measures. PDNS decomposes the learning process into a series of simpler subproblems that create a path gradually approaching the desired distribution. This staged procedure traces a progressively refined path to the desired distribution and promotes thorough exploration across modes. For a practical and efficient realization, we instantiate each proximal step with a proximal weighted denoising cross-entropy (WDCE) objective. We demonstrate the effectiveness and robustness of PDNS through extensive experiments on both continuous and discrete sampling tasks, including challenging scenarios in molecular dynamics and statistical physics.

Paper Structure

This paper contains 100 sections, 1 theorem, 72 equations, 12 figures, 5 tables, 3 algorithms.

Key Result

Proposition 3.1

1. The optimal solution of eq:prox_sub_problem is ...
2. Assume that the subproblems are solved to optimality for all $k\ge1$, and let $\mathbb{P}^{\theta_0}\gets\mathbb{P}^\mathrm{ref}$. Denote by $\mathbb{P}^k$ the corresponding path measure $\mathbb{P}^{\theta_k^*}$, which satisfies $\mathbb{P}^k\propto(\mathbb{P}^{k-1})^{\frac{1}{\eta_k+1}}\cdots$ This implies that $\mathbb{P}^k$ converges to ...
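
The recursion in Proposition 3.1 can be illustrated on a toy finite state space. The sketch below is not the paper's algorithm: it assumes the factor truncated in the statement above is the target raised to the power $\gamma_k = \frac{\eta_k}{1+\eta_k}$ (the same $\gamma_k$ defined in the figure captions), so that each idealized step is a geometric interpolation between the previous iterate and the target. Plain probability vectors over 50 grid points stand in for path measures, and the bimodal target, the uniform reference, the constant step size `eta=0.5`, and all function names are hypothetical choices made for illustration.

```python
import numpy as np

def normalize(log_w):
    """Turn unnormalized log-weights into log-probabilities."""
    return log_w - np.logaddexp.reduce(log_w)

def proximal_step(log_p_prev, log_p_target, eta):
    """One idealized proximal update, assumed to be
    p_k ∝ p_{k-1}^{1/(1+eta)} * p_target^{eta/(1+eta)} (see lead-in)."""
    gamma = eta / (1.0 + eta)
    return normalize((1.0 - gamma) * log_p_prev + gamma * log_p_target)

def kl(log_p, log_q):
    """KL(p || q) for two log-probability vectors on the same grid."""
    return float(np.sum(np.exp(log_p) * (log_p - log_q)))

# Hypothetical bimodal target with two well-separated modes at x = -2 and x = 2.
x = np.linspace(-4.0, 4.0, 50)
log_target = normalize(np.logaddexp(-(x - 2.0) ** 2 / 0.1, -(x + 2.0) ** 2 / 0.1))

# Uniform reference, playing the role of P^ref.
log_ref = normalize(np.zeros_like(x))

# Run 20 stages with a constant step size and track the distance to the target.
log_p = log_ref
kls = []
for k in range(1, 21):
    log_p = proximal_step(log_p, log_target, eta=0.5)
    kls.append(kl(log_p, log_target))

print("KL to target after stages 1, 5, 20:", kls[0], kls[4], kls[-1])
```

Under this assumption, the weight left on the reference after $K$ stages with constant $\gamma$ is $(1-\gamma)^K$, so the iterates approach the target geometrically while every intermediate distribution keeps mass on both modes.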

Figures (12)

  • Figure 1: An example of mode collapse in a WDCE-based method, where the target distribution is an Ising model at low temperature. (a) The magnetization (i.e., average spin) of the generated samples during training, plotted for two independent runs (Run 1 and Run 2). The ground-truth value is $0$. (b, c) Visualization of the generated samples at the final training step for Run 1 and Run 2. (d) Visualization of the ground-truth samples.
  • Figure 2: Average 2-point correlations in both vertical and horizontal directions of samples from learned Ising and Potts models at different inverse temperatures.
  • Figure 3: Ablation studies with fixed $\gamma_k$ on the MoS benchmark. We fix $\gamma_k = \frac{\eta_k}{1+\eta_k}$ to a constant for all stages $k$ and visualize the first four stages ($k=1,2,3,4$; left to right). Larger $\gamma_k$, i.e., weaker proximal regularization, leads to rapid mode collapse, whereas smaller $\gamma_k$ preserves multimodal coverage. These results are consistent with the analysis in sec:pdns_limit_wce.
  • Figure 4: Ablation on proximal step size $\eta_k$ and the choice of the scheduler on MoS. We evaluate Sinkhorn ($\downarrow$) and MMD ($\downarrow$) across training epochs for multiple choices of proximal step size $\eta_k$ and scheduling policy. $\gamma$ denotes $\gamma_k := \frac{\eta_k}{1+\eta_k}$ over stage $k$. The legend entry "$\gamma=\mathrm{const}$" denotes runs with a fixed $\gamma_k$ for all stages $k$. Note that $0< \gamma_k \leq 1$; larger $\gamma_k$ weakens the proximal effect and approaches the non-proximal WDCE variant, whereas smaller $\gamma_k$ imposes stronger regularization.
  • Figure 5: Ablation on proximal step size $\eta_k$ and the choice of the scheduler on LJ-13. We monitor $\gamma$ and energy 2-Wasserstein distance ($E(\cdot) \mathcal{W}_2 (\downarrow)$) across training epochs for multiple choices of proximal step size $\eta_k$ and scheduling policy. $\gamma$ denotes $\gamma_k := \frac{\eta_k}{1+\eta_k}$ over stage $k$. The legend entry "$\gamma=\mathrm{const}$" denotes runs with a fixed $\gamma_k$ for all stages $k$. Note that $0< \gamma_k \leq 1$; larger $\gamma_k$ weakens the proximal effect and approaches the non-proximal WDCE variant, whereas smaller $\gamma_k$ imposes stronger regularization.
  • ...and 7 more figures
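
The captions of Figures 3 to 5 parameterize each stage by $\gamma_k = \frac{\eta_k}{1+\eta_k}$. As a rough illustration of the trade-off they describe, the sketch below converts between $\eta_k$ and $\gamma_k$ and computes how much weight remains on the reference after a schedule of stages. It assumes that each stage is an exact geometric interpolation with exponent $\gamma_k$ on the target, in which case the residual reference weight is $\prod_k (1-\gamma_k)$. The helper names and the example schedules are hypothetical and are not taken from the paper.

```python
import math

def gamma_from_eta(eta):
    """gamma_k = eta_k / (1 + eta_k), as defined in the figure captions."""
    return eta / (1.0 + eta)

def eta_from_gamma(gamma):
    """Inverse map, valid for 0 <= gamma < 1 (gamma = 1 corresponds to eta -> infinity)."""
    return gamma / (1.0 - gamma)

def reference_weight(gammas):
    """Exponent left on the reference after all stages, assuming each stage is an
    exact geometric interpolation with weight gamma_k on the target."""
    return math.prod(1.0 - g for g in gammas)

# Hypothetical constant schedules over four stages.
strong_prox = [0.5] * 4   # smaller gamma: stronger regularization, slower approach
weak_prox = [0.9] * 4     # larger gamma: weaker regularization, faster approach

print("reference weight, gamma = 0.5:", reference_weight(strong_prox))
print("reference weight, gamma = 0.9:", reference_weight(weak_prox))
```

A smaller $\gamma_k$ leaves more weight on the previous iterate at each stage, which matches the captions' description of stronger regularization, at the cost of needing more stages to reach the target.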

Theorems & Definitions (4)

  • Proposition 3.1
  • proof
  • proof
  • Remark