Beyond Scores: Proximal Diffusion Models

Zhenghan Fang, Mateo Díaz, Sam Buchanan, Jeremias Sulam

TL;DR

Proximal Diffusion Models (ProxDM) are a principled alternative to score-based diffusion: they sample from data distributions using a backward discretization with learned proximal operators. The method replaces MMSE denoisers with MAP proximal updates learned via proximal matching, enabling faster sampling with fewer steps. The authors prove convergence guarantees in KL divergence: fully backward ProxDM needs $\widetilde{O}(d/\sqrt{\varepsilon})$ steps, while a hybrid variant needs $\widetilde{O}(d/\varepsilon)$ steps. Empirically, ProxDM delivers significant speedups on MNIST, CIFAR-10, and CelebA-HQ while matching or surpassing score-based baselines, with the hybrid variant performing best in practice. Limitations include the reliance on approximate proximal operators and on regularity assumptions; future work includes extending the approach to other SDEs/ODEs and refining training schedules.
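The contrast between MMSE denoisers and MAP proximal updates can be seen on a toy problem. For a fixed noisy observation, the estimator that minimizes squared error is the posterior mean, whereas the proximal operator of $-\ln p$ (with step size equal to the noise variance) returns the MAP estimate, i.e. a posterior mode. The sketch below is not the paper's training code. It assumes a hypothetical bimodal posterior and an unnormalized loss of the form $1-\exp(-\|x-y\|^2/\gamma^2)$, which we take to be the shape of the proximal matching loss up to a positive constant, and finds the best constant prediction under each loss by grid search.

```python
import numpy as np

def proximal_matching_loss(pred, target, gamma):
    """Unnormalized proximal-matching-style loss: 1 - exp(-|pred - target|^2 / gamma^2).

    Assumption: the normalization constant of the original loss is dropped;
    a positive scale on the exp term does not move the minimizer.
    """
    return 1.0 - np.exp(-((pred - target) ** 2) / gamma**2)

def best_constant(samples, loss_fn, grid):
    """Grid point minimizing the empirical expected loss over `samples`."""
    risks = [loss_fn(c, samples).mean() for c in grid]
    return grid[int(np.argmin(risks))]

# Hypothetical bimodal posterior of a clean value given one noisy observation:
# 70% of the mass near +2, 30% near -2.
rng = np.random.default_rng(0)
n = 20_000
from_major_mode = rng.random(n) < 0.7
samples = np.where(from_major_mode,
                   rng.normal(2.0, 0.3, n),
                   rng.normal(-2.0, 0.3, n))

grid = np.linspace(-4.0, 4.0, 801)
mse_est = best_constant(samples, lambda c, y: (c - y) ** 2, grid)
pm_est = best_constant(samples, lambda c, y: proximal_matching_loss(c, y, gamma=0.5), grid)

print(f"squared-error minimizer (posterior mean): {mse_est:.2f}")
print(f"proximal-matching minimizer (near the mode): {pm_est:.2f}")
```

The squared-error estimate lands between the two modes, where the posterior has almost no mass, while the proximal-matching estimate sits on the dominant mode. In our reading, shrinking $\gamma$ is what pushes the learned map toward the mode rather than the mean.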

Abstract

Diffusion models have quickly become some of the most popular and powerful generative models for high-dimensional data. The key insight that enabled their development was the realization that access to the score -- the gradient of the log-density at different noise levels -- allows for sampling from data distributions by solving a reverse-time stochastic differential equation (SDE) via forward discretization, and that popular denoisers allow for unbiased estimators of this score. In this paper, we demonstrate that an alternative, backward discretization of these SDEs, using proximal maps in place of the score, leads to theoretical and practical benefits. We leverage recent results in proximal matching to learn proximal operators of the log-density and, with them, develop Proximal Diffusion Models (ProxDM). Theoretically, we prove that $\widetilde{O}(d/\sqrt{\varepsilon})$ steps suffice for the resulting discretization to generate an $\varepsilon$-accurate distribution w.r.t. the KL divergence. Empirically, we show that two variants of ProxDM achieve significantly faster convergence within just a few sampling steps compared to conventional score-matching methods.
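To make the forward-versus-backward distinction concrete, here is a minimal sketch on a toy problem where everything is closed-form. The assumptions are ours, not the paper's: 1-D Gaussian data, an Ornstein-Uhlenbeck forward process, and our own reading of the splits, where "hybrid" treats the linear drift explicitly and the score term implicitly and "backward" treats both implicitly. An implicit score step $y' = v + \lambda \nabla \ln p(y')$ is exactly $y' = \operatorname{prox}_{\lambda(-\ln p)}(v)$, which is why proximal maps replace the score. The paper's algorithms may differ in time indexing and step-size conventions.

```python
import numpy as np

# Toy assumptions (not from the paper): 1-D data p_0 = N(0, s0_sq) and the OU
# forward process dX = -X dt + sqrt(2) dB, so each p_t is Gaussian and both its
# score and the proximal operator of -ln p_t have closed forms.

def marginal_var(t, s0_sq):
    """Variance of p_t for the OU process started at N(0, s0_sq)."""
    return s0_sq * np.exp(-2.0 * t) + (1.0 - np.exp(-2.0 * t))

def score(x, t, s0_sq):
    """Score grad ln p_t(x) = -x / var_t."""
    return -x / marginal_var(t, s0_sq)

def prox_neg_log_p(v, lam, t, s0_sq):
    """prox_{lam * (-ln p_t)}(v), i.e. the x solving x - lam * score(x, t) = v."""
    var_t = marginal_var(t, s0_sq)
    return v * var_t / (var_t + lam)

def sample(y, T, n_steps, s0_sq, rng, scheme):
    """Discretize the reverse SDE dY = (Y + 2 grad ln p_{T-s}(Y)) ds + sqrt(2) dW."""
    h = T / n_steps
    if scheme == "backward" and h >= 1.0:
        raise ValueError("the fully implicit step needs h < 1")
    for k in range(n_steps):
        t_now, t_next = T - k * h, T - (k + 1) * h  # forward-process times
        noise = np.sqrt(2.0 * h) * rng.standard_normal(y.shape)
        if scheme == "euler":       # explicit in both drift terms (score at t_now)
            y = y + h * (y + 2.0 * score(y, t_now, s0_sq)) + noise
        elif scheme == "hybrid":    # explicit linear drift, implicit score term
            y = prox_neg_log_p((1.0 + h) * y + noise, 2.0 * h, t_next, s0_sq)
        elif scheme == "backward":  # implicit in both drift terms
            y = prox_neg_log_p((y + noise) / (1.0 - h), 2.0 * h / (1.0 - h), t_next, s0_sq)
        else:
            raise ValueError(f"unknown scheme: {scheme}")
    return y

rng = np.random.default_rng(0)
s0_sq, T, n_steps = 0.01, 4.0, 5    # data std 0.1, only 5 sampling steps
y_T = rng.standard_normal(100_000)  # p_T is close to N(0, 1) for T = 4
for scheme in ("euler", "hybrid", "backward"):
    y_0 = sample(y_T, T, n_steps, s0_sq, rng, scheme)
    print(f"{scheme:>8}: sample std {y_0.std():.3f} (data std 0.100)")
```

With $h = 0.8$, the last Euler-Maruyama step injects noise of variance $2h = 1.6$ that nothing removes, so its output is far wider than the data. Both implicit updates instead scale their input by $\mathrm{var}_t/(\mathrm{var}_t+\lambda)$ at the final time, which damps that noise; in this toy they end much closer to the data scale, though they still under-disperse at 5 steps.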

Paper Structure

This paper contains 57 sections, 24 theorems, 149 equations, 12 figures, 2 algorithms.

Key Result

Theorem 1

Suppose that $\mathbb{E}_{p_0}\|X\|^2$ is bounded above by $M_2<\infty$. Further, assume that for all $t \geq 0$ the potential $\ln p_t$ is three times differentiable with $L$-Lipschitz gradient and $H$-Lipschitz Hessian, and that the moment $\mathbb{E}_{p_t}\|\nabla \ln p_t(X)\|^2$ is bounded above by …
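For intuition about the gap between the two rates quoted in the TL;DR, ignore the constants and logarithmic factors hidden in $\widetilde{O}$. The hybrid bound then exceeds the fully backward bound by a factor of $1/\sqrt{\varepsilon}$, independent of the dimension $d$. This compares worst-case upper bounds, not measured step counts; empirically the hybrid variant performs best.

$$
\frac{d/\varepsilon}{d/\sqrt{\varepsilon}} = \frac{1}{\sqrt{\varepsilon}},
\qquad \varepsilon = 10^{-2} \;\Rightarrow\; 10,
\qquad \varepsilon = 10^{-4} \;\Rightarrow\; 100.
$$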

Figures (12)

  • Figure 1: Sampling from a distribution supported on a discrete set of points in 2D using exact scores and proximal operators. Left: the target distribution and samples generated by various samplers using 5 sampling steps; right: Wasserstein-2 distance between samples and target across varying numbers of steps (NFE, number of function evaluations). Standard Euler-Maruyama without denoising fails when using only 5 steps. The score-based sampler with a final denoising step, as is common (song2020score), requires tuning $\epsilon$, the step size of that last denoising step. Both PDA variants perform well without additional hyperparameters.
  • Figure 2: Left: MNIST samples generated by the score SDE sampler, score ODE sampler, and our hybrid ProxDM. Right: the resulting FID score as a function of the number of sampling steps.
  • Figure 3: FID vs. number of sampling steps (NFE) on CIFAR-10. The dashed lines correspond to models trained specifically for 5, 10, and 20 steps, as opposed to the full-range models shown as solid lines.
  • Figure 4: CIFAR-10 samples from score SDE, score ODE, and hybrid ProxDM samplers.
  • Figure 5: CelebA-HQ ($256\times256$) samples from score SDE, score ODE, and hybrid ProxDM samplers using 20 sampling steps.
  • ...and 7 more figures
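Figure 1 relies on the exact score of a distribution supported on a discrete point set. Assuming an Ornstein-Uhlenbeck forward process (our assumption; the caption does not name the SDE), the noised marginal is a Gaussian mixture centered at the shrunken points $e^{-t}\mu_i$ with variance $1-e^{-2t}$, and its score is a responsibility-weighted pull toward those centers. The function name `exact_score` is hypothetical; this is a sketch, not the authors' code.

```python
import numpy as np

def exact_score(x, points, weights, t):
    """Score of p_t when p_0 is a weighted point set (hypothetical helper).

    Assumption: OU forward process dX = -X dt + sqrt(2) dB, so
    p_t = sum_i w_i N(e^{-t} mu_i, (1 - e^{-2t}) I).
    x: (n, d) query points; points: (m, d); weights: (m,) summing to 1; t > 0.
    """
    alpha = np.exp(-t)
    var = 1.0 - np.exp(-2.0 * t)
    diff = alpha * points[None, :, :] - x[:, None, :]    # (n, m, d)
    log_resp = np.log(weights)[None, :] - (diff**2).sum(-1) / (2.0 * var)
    log_resp -= log_resp.max(axis=1, keepdims=True)      # stabilize the softmax
    resp = np.exp(log_resp)
    resp /= resp.sum(axis=1, keepdims=True)              # posterior weight of each point
    return (resp[:, :, None] * diff).sum(axis=1) / var   # sum_i r_i (alpha mu_i - x) / var

# Example: four equally weighted points on the corners of a square.
corners = np.array([[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]])
w = np.full(4, 0.25)
print(exact_score(np.array([[0.5, 0.4]]), corners, w, t=0.3))
```

Subtracting the row-wise maximum before exponentiating keeps the responsibilities finite at small $t$, where the mixture components become very narrow.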

Theorems & Definitions (36)

  • Theorem 1: Informal
  • Corollary 1
  • Lemma 1: Interpolating Process for Backward Algorithm
  • Lemma 2: Interpolating Process for Hybrid Algorithm
  • Theorem 2: Convergence guarantee
  • Proof
  • Proposition 1
  • Lemma 3
  • Lemma 4: Lemma 9 in chen2022improved
  • Lemma 5: Time derivative of KL between SDEs
  • ...and 26 more