
Faster Diffusion Sampling with Randomized Midpoints: Sequential and Parallel

Shivam Gupta, Linda Cai, Sitan Chen

TL;DR

This work advances diffusion-based sampling by introducing a randomized midpoint discretization of the probability flow ODE, which improves the dimension dependence of sampling and makes it provably parallelizable. Under standard smoothness and score-estimation assumptions, the sequential sampler reaches small total variation error with an iteration complexity whose dimension dependence is $\widetilde O(d^{5/12})$, improving on the $\widetilde O(\sqrt{d})$ dependence of prior work. The paper also presents a parallel algorithm that runs in $\widetilde O(\log^2 d)$ parallel rounds, built on a fixed-point (collocation) approach and parallelized corrector steps; these are the first provable guarantees for parallel sampling with diffusion models. As a byproduct, the same $\widetilde O(d^{5/12})$ dimension dependence is obtained for log-concave sampling in total variation distance, so the results reach beyond diffusion models.
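To make the size of the improvement concrete: ignoring polylogarithmic factors and every parameter other than the dimension (a deliberate simplification), the ratio between the prior $\sqrt{d}$ dependence and the new $d^{5/12}$ dependence is $d^{1/12}$. The short Python snippet below only evaluates these powers; the helper name `dimension_factor` is hypothetical, and the snippet says nothing about the constants hidden inside $\widetilde O(\cdot)$.

```python
def dimension_factor(d: int, exponent: float) -> float:
    """Dimension-dependent part d**exponent of an iteration bound.

    Polylogarithmic factors and every non-dimension parameter are ignored,
    so this compares growth rates only, not actual iteration counts.
    """
    return d ** exponent


for d in (10**3, 10**6, 10**9):
    prior = dimension_factor(d, 1 / 2)   # sqrt(d): prior work
    new = dimension_factor(d, 5 / 12)    # d^(5/12): randomized midpoint sampler
    # The ratio prior / new equals d^(1/12), the improvement factor.
    print(f"d={d:>10}: sqrt(d)={prior:9.1f}  d^(5/12)={new:7.1f}  ratio={prior / new:4.2f}")
```

For $d = 10^6$ the ratio is about 3.16, so the gain is modest in absolute terms but keeps growing with the dimension.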

Abstract

Sampling algorithms play an important role in controlling the quality and runtime of diffusion model inference. In recent years, a number of works~\cite{chen2023sampling,chen2023ode,benton2023error,lee2022convergence} have proposed schemes for diffusion sampling with provable guarantees; these works show that for essentially any data distribution, one can approximately sample in polynomial time given a sufficiently accurate estimate of its score functions at different noise levels. In this work, we propose a new scheme inspired by Shen and Lee's randomized midpoint method for log-concave sampling~\cite{ShenL19}. We prove that this approach achieves the best known dimension dependence for sampling from arbitrary smooth distributions in total variation distance ($\widetilde O(d^{5/12})$ compared to $\widetilde O(\sqrt{d})$ from prior work). We also show that our algorithm can be parallelized to run in only $\widetilde O(\log^2 d)$ parallel rounds, constituting the first provable guarantees for parallel sampling with diffusion models. As a byproduct of our methods, for the well-studied problem of log-concave sampling in total variation distance, we give an algorithm and simple analysis achieving dimension dependence $\widetilde O(d^{5/12})$ compared to $\widetilde O(\sqrt{d})$ from prior work.
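Shen and Lee's randomized midpoint idea, carried over to an ODE, replaces the deterministic midpoint of each step with a uniformly random time inside the step. The Python sketch below applies this to a generic scalar ODE $dx/dt = f(t, x)$. The function name, the toy drift $f(t, x) = -x$, and the step count are illustrative assumptions; in the paper the drift comes from the probability flow ODE with learned score estimates, and the full algorithm has further components (such as corrector steps) that are not reproduced here.

```python
import math
import random


def randomized_midpoint_ode(f, x0, t0, t1, n_steps, rng):
    """Integrate dx/dt = f(t, x) from t0 to t1 with a randomized-midpoint rule.

    Illustrative sketch: on each step of size h we draw alpha ~ Uniform[0, 1),
    take an Euler step to the random time t + alpha * h, and use the drift there
    as a single-sample estimate of the average drift over the whole step.
    """
    h = (t1 - t0) / n_steps
    x, t = x0, t0
    for _ in range(n_steps):
        alpha = rng.random()                     # random position inside the step
        x_mid = x + alpha * h * f(t, x)          # Euler predictor to t + alpha * h
        x = x + h * f(t + alpha * h, x_mid)      # full step using the midpoint drift
        t += h
    return x


# Toy drift f(t, x) = -x, whose exact flow from x(0) = 1 gives x(1) = e^{-1}.
rng = random.Random(0)
approx = randomized_midpoint_ode(lambda t, x: -x, 1.0, 0.0, 1.0, 100, rng)
print(approx, math.exp(-1))
```

Averaged over $\alpha$, one step of this rule on $f(t,x) = -x$ multiplies $x$ by $1 - h + h^2/2$, matching $e^{-h}$ to second order; the randomness of $\alpha$ is what the analysis exploits to beat deterministic discretizations.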

Paper Structure

This paper contains 39 sections, 35 theorems, 137 equations, 11 algorithms.

Key Result

Theorem 1.1

Suppose that the data distribution $q$ has bounded second moment, its score functions $\nabla \ln q_t$ along the forward process are $L$-Lipschitz, and we are given score estimates which are $L$-Lipschitz and $\widetilde{O}\big(\frac{\varepsilon}{d^{1/12}\sqrt{L}}\big)$ … (Here $\widetilde{O}(\cdot)$ hides polylogarithmic factors.)
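For intuition about the assumption that the scores $\nabla \ln q_t$ are $L$-Lipschitz, consider a Gaussian data distribution. Assume, purely for illustration (the excerpted theorem does not fix this), that the forward process is the Ornstein–Uhlenbeck process $dX_t = -X_t\,dt + \sqrt{2}\,dB_t$ and that $q = \mathcal N(0, \sigma^2 I)$. Then $q_t = \mathcal N(0, s_t I)$ with $s_t = \sigma^2 e^{-2t} + 1 - e^{-2t}$, the score $\nabla \ln q_t(x) = -x/s_t$ is $1/s_t$-Lipschitz, and $L = \max_t 1/s_t$. The Python helper names below are hypothetical.

```python
import math


def ou_variance(sigma2: float, t: float) -> float:
    """Per-coordinate variance s_t of q_t when q = N(0, sigma2 * I) evolves under
    the assumed forward OU process dX = -X dt + sqrt(2) dB."""
    decay = math.exp(-2.0 * t)
    return sigma2 * decay + (1.0 - decay)


def gaussian_score(sigma2: float, t: float, x: list[float]) -> list[float]:
    """Score of q_t: grad ln q_t(x) = -x / s_t, a linear map with Lipschitz constant 1 / s_t."""
    s = ou_variance(sigma2, t)
    return [-xi / s for xi in x]


def max_score_lipschitz(sigma2: float, times: list[float]) -> float:
    """Largest Lipschitz constant 1 / s_t of the score over the given times."""
    return max(1.0 / ou_variance(sigma2, t) for t in times)


times = [k / 100 for k in range(501)]    # t in [0, 5]
print(max_score_lipschitz(0.25, times))  # 4.0, attained at t = 0
print(max_score_lipschitz(4.0, times))   # about 0.99986, attained at t = 5
```

For $\sigma^2 < 1$ the worst case is $t = 0$, giving $L = 1/\sigma^2$, while for $\sigma^2 > 1$ the constant stays below 1; sharply concentrated data therefore forces a large $L$.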

Theorems & Definitions (55)

  • Theorem 1.1: Informal; formal statement in thm:main_sequential_formal
  • Theorem 1.2: Informal; formal statement in thm:main_parallel_formal
  • Theorem 1.3: Informal; formal statement in thm:log-concave
  • Lemma 3.1: Informal; formal statement in lem:sequential_predictor_variance
  • Lemma 3.2: Informal; formal statement in lem:predictor_sequential_helper
  • Lemma 3.3: Informal; formal statement in claim:parallel_contraction
  • Theorem 3.4: Informal; formal statement in thm:parallel_corrector
  • Lemma A.1: Naive ODE Coupling
  • Proof
  • Lemma A.2
  • ...and 45 more
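The parallel guarantees (Lemma 3.3 and Theorem 3.4) rest on a contraction argument for fixed-point iterations. As a generic illustration only, the Python sketch below runs Picard iteration on a time grid, a standard fixed-point scheme that we substitute for the paper's parallel algorithm: every round recomputes all grid values from the previous round, so the drift evaluations inside a round are independent and could run in parallel. The function name and parameters are hypothetical.

```python
def picard_collocation(f, x0, t0, t1, n_points, n_rounds):
    """Fixed-point (Picard) iteration for dx/dt = f(t, x), x(t0) = x0, on a uniform grid.

    Each round sets x_new[i] = x0 + h * sum_{j < i} f(t_j, x_old[j]).
    The n_points drift evaluations in a round are mutually independent, so a
    parallel implementation could compute them simultaneously; here they run
    in an ordinary list comprehension.
    """
    h = (t1 - t0) / n_points
    ts = [t0 + i * h for i in range(n_points + 1)]
    xs = [x0] * (n_points + 1)                # initial guess: the constant path
    for _ in range(n_rounds):
        drifts = [f(ts[j], xs[j]) for j in range(n_points)]  # parallelizable batch
        new_xs, acc = [x0], x0
        for drift in drifts:                  # prefix sum (also parallelizable as a scan)
            acc += h * drift
            new_xs.append(acc)
        xs = new_xs
    return ts, xs


ts, xs = picard_collocation(lambda t, x: -x, 1.0, 0.0, 1.0, n_points=100, n_rounds=20)
print(xs[-1], 0.99 ** 100)
```

With the 1-Lipschitz drift $f(t, x) = -x$ on $[0, 1]$ and 100 grid points, 20 rounds agree with the 100 sequential Euler steps $0.99^{100}$ to within $10^{-9}$: the error after $k$ rounds is at most of order $(LT)^k/k!$, so the number of rounds, not the number of grid points, is what a parallel implementation pays for.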