Self-Refining Diffusion Samplers: Enabling Parallelization via Parareal Iterations

Nikil Roashan Selvam, Amil Merchant, Stefano Ermon

TL;DR

Self-Refining Diffusion Samplers (SRDS) are a Parareal-inspired, parallel-in-time framework that iteratively refines an estimate of the diffusion trajectory to produce high-quality samples with reduced latency. By coupling a fast 1-step coarse solver with parallelizable $\sqrt{N}$-step fine solves and predictor-corrector updates, SRDS is guaranteed to converge to the standard $N$-step solution while enabling batched inference and pipeline parallelism. Empirical results across pixel- and latent-diffusion models show substantial wall-clock speedups (up to 1.7x on a 25-step StableDiffusion-v2 benchmark and up to 4.3x on longer trajectories) with preserved sample quality, at the expense of additional parallel compute and $\mathcal{O}(\sqrt{N})$ memory. This modular approach offers a practical path to real-time diffusion-based applications and can integrate with a range of solvers and future multigrid strategies.

Abstract

In diffusion models, samples are generated through an iterative refinement process, requiring hundreds of sequential model evaluations. Several recent methods have introduced approximations (fewer discretization steps or distillation) that gain speed at the cost of sample quality. In contrast, we introduce Self-Refining Diffusion Samplers (SRDS) that retain sample quality and can improve latency at the cost of additional parallel compute. We take inspiration from the Parareal algorithm, a popular numerical method for parallel-in-time integration of differential equations. In SRDS, a quick but rough estimate of a sample is first created and then iteratively refined in parallel through Parareal iterations. SRDS is not only guaranteed to accurately solve the ODE and converge to the serial solution but also benefits from parallelization across the diffusion trajectory, enabling batched inference and pipelining. As we demonstrate for pre-trained diffusion models, the early convergence of this refinement procedure drastically reduces the number of steps required to produce a sample, speeding up generation, for instance, by up to 1.7x on a 25-step StableDiffusion-v2 benchmark and up to 4.3x on longer trajectories.
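The coarse-then-refine loop described in the abstract can be illustrated with a minimal Parareal sketch on a toy scalar ODE. This is not the authors' implementation. The drift `f`, the Euler solvers, and all function names are assumptions chosen for illustration; in SRDS the ODE would be the diffusion sampling ODE and each solver step a model evaluation. The fine solves in each iteration are independent of one another, which is the part that could be batched or run in parallel.

```python
import math

def f(x, t):
    # Toy drift standing in for the learned denoiser (an assumption, not from the paper).
    return -x + math.sin(t)

def fine_solve(x, block, block_size, h):
    # "Fine" solver: block_size small Euler steps across one block.
    for k in range(block * block_size, (block + 1) * block_size):
        x = x + h * f(x, k * h)
    return x

def coarse_solve(x, block, block_size, h):
    # "Coarse" solver: a single large Euler step across the same block.
    big_h = block_size * h
    return x + big_h * f(x, block * big_h)

def serial_solve(x0, n_steps, h):
    # Reference: the ordinary sequential n_steps-step solve.
    x = x0
    for k in range(n_steps):
        x = x + h * f(x, k * h)
    return x

def parareal(x0, n_steps, t_end=1.0, tol=0.0):
    # Split n_steps into sqrt(n_steps) blocks of sqrt(n_steps) fine steps each.
    block_size = math.isqrt(n_steps)
    assert block_size * block_size == n_steps, "sketch assumes a perfect square"
    n_blocks = block_size
    h = t_end / n_steps

    # Initial rough trajectory from a sequential sweep of coarse solves.
    x = [x0]
    prev_coarse = []
    for i in range(n_blocks):
        prev_coarse.append(coarse_solve(x[i], i, block_size, h))
        x.append(prev_coarse[i])
    history = [x[-1]]

    for _ in range(n_blocks):
        # Fine solves depend only on the previous iterate, so they are independent.
        fine = [fine_solve(x[i], i, block_size, h) for i in range(n_blocks)]
        # Sequential predictor-corrector sweep using cheap coarse solves.
        new_x = [x0]
        for i in range(n_blocks):
            cur = coarse_solve(new_x[i], i, block_size, h)
            new_x.append(fine[i] + (cur - prev_coarse[i]))
            prev_coarse[i] = cur
        delta = abs(new_x[-1] - x[-1])
        x = new_x
        history.append(x[-1])
        if delta < tol:  # tol=0.0 disables early stopping
            break
    return x[-1], history

if __name__ == "__main__":
    x_par, hist = parareal(1.0, 16)
    x_ser = serial_solve(1.0, 16, 1.0 / 16)
    print("serial:", x_ser)
    for p, value in enumerate(hist):
        print("iteration", p, "final-state error", abs(value - x_ser))
```

Running all `n_blocks` iterations reproduces the serial result; passing a positive `tol` stops as soon as the final state changes by less than `tol` between iterations, which is where any latency gain from early convergence would come from.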

Paper Structure

This paper contains 28 sections, 3 theorems, 9 equations, 8 figures, 8 tables, 1 algorithm.

Key Result

Proposition 1

The sample output by SRDS converges to the output of the $N$-step sequential solver in at most $\sqrt{N}$ refinement iterations.
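
One way to see where the $\sqrt{N}$ bound comes from is the standard Parareal predictor-corrector update. The notation here is assumed for illustration rather than taken from the paper: $x_i^{p}$ is the estimate at the $i$-th block boundary after $p$ refinement iterations, $\mathcal{F}$ is the $\sqrt{N}$-step fine solve over one block, and $\mathcal{G}$ is the 1-step coarse solve over the same block.

$$x_{i+1}^{p+1} = \mathcal{F}(x_i^{p}) + \mathcal{G}(x_i^{p+1}) - \mathcal{G}(x_i^{p}), \qquad x_0^{p} = x_0 \ \text{for all } p.$$

Suppose boundaries $0, \dots, p$ already equal the sequential solution after iteration $p$. Sweeping through iteration $p+1$ in order, each $i \le p$ has $x_i^{p+1} = x_i^{p}$, so the two coarse terms cancel and $x_{i+1}^{p+1} = \mathcal{F}(x_i^{p})$, which is the sequential solution at boundary $i+1$. Boundary $0$ is exact from the start, so each iteration makes at least one more boundary exact, and with $\sqrt{N}$ blocks the final boundary is exact after at most $\sqrt{N}$ iterations.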

Figures (8)

  • Figure 1: A visualization of the iterative refinement provided by the SRDS algorithm on a sample from StableDiffusion with the prompt 'a beautiful castle, matte painting.' The initial coarse solve (left), which uses only a few steps, provides a rough estimate of the sample, which is then iteratively refined through iterations of our algorithm. Early convergence is observed, as the 3rd output already nearly matches the final output, a key feature that enables efficient generation.
  • Figure 2: First iteration of the Parareal algorithm to solve an example ODE. The black curve represents the desired solution from the fine solver. The magenta dots indicate the running solution after one iteration of predictor-corrector updates. Figure inspiration from pentland2022stochastic.
  • Figure 3: Computation graph for diffusion sampling. For SRDS, the red arrows correspond to fine solves, and each block ($[0,\sqrt{N}]$, $[\sqrt{N}, 2\sqrt{N}]$, and so on) can perform its fine solves independently and in parallel. The blue arrows correspond to 1-step coarse solves.
  • Figure 4: Illustration of the pipelined version of the SRDS algorithm on $N=16$ denoising steps, which results in a direct 2x speedup compared to the vanilla version.
  • Figure 5: Convergence of the SRDS algorithm for trajectories of length 25 (left) and 100 (right), showing that early termination of the algorithm can yield equivalent sample quality. In particular, longer trajectories with increased parallelism appear to converge faster.
  • ...and 3 more figures

Theorems & Definitions (7)

  • Proposition 1
  • Proposition 2
  • Proposition 3
  • proof
  • proof
  • proof
  • proof