Fast ODE-based Sampling for Diffusion Models in Around 5 Steps

Zhenyu Zhou, Defang Chen, Can Wang, Chun Chen

TL;DR

The paper addresses slow diffusion sampling by treating it as solving a probability flow ODE (PF-ODE). It proposes AMED-Solver, a single-step solver that learns the mean direction over each sampling step to reduce discretization error, together with AMED-Plugin, which improves existing solvers at minimal overhead. The key geometric observation is that each sampling trajectory lies almost entirely in a two-dimensional subspace, which makes a learned mean direction viable and preserves sample quality at very low NFE (around 5). Experiments on CIFAR-10, ImageNet 64$\times$64, LSUN Bedroom, and Stable Diffusion show consistent improvements: the authors report state-of-the-art results among solver-based methods and substantial gains when AMED is used as a plugin. The work offers a practical, lightweight route to fast, high-quality diffusion sampling and highlights the potential of geometry-informed solvers in generative modeling.

Abstract

Sampling from diffusion models can be treated as solving the corresponding ordinary differential equations (ODEs), with the aim of obtaining an accurate solution with as small a number of function evaluations (NFE) as possible. Recently, various fast samplers utilizing higher-order ODE solvers have emerged and achieved better performance than the initial first-order one. However, these numerical methods inherently result in certain approximation errors, which significantly degrade sample quality with extremely small NFE (e.g., around 5). In contrast, based on the geometric observation that each sampling trajectory almost lies in a two-dimensional subspace embedded in the ambient space, we propose Approximate MEan-Direction Solver (AMED-Solver) that eliminates truncation errors by directly learning the mean direction for fast diffusion sampling. Besides, our method can be easily used as a plugin to further improve existing ODE-based samplers. Extensive experiments on image synthesis with the resolution ranging from 32 to 512 demonstrate the effectiveness of our method. With only 5 NFE, we achieve 6.61 FID on CIFAR-10, 10.74 FID on ImageNet 64$\times$64, and 13.20 FID on LSUN Bedroom. Our code is available at https://github.com/zju-pi/diff-sampler.
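
To make the cost of discretization at small NFE concrete, here is a minimal sketch that is not taken from the paper. It integrates a toy PF-ODE with a first-order Euler solver and compares the result with the closed-form solution. Everything in it is an assumption made for illustration: the data distribution is a 1-D Gaussian with standard deviation `SIGMA_DATA`, the noising process is $x_t = x_0 + t\,\epsilon$ (so the PF-ODE becomes $\mathrm{d}x/\mathrm{d}t = t x / (\sigma_{\text{data}}^2 + t^2)$), the time schedule is an EDM-style polynomial schedule with `rho=7`, and all function names are hypothetical.

```python
import math

SIGMA_DATA = 0.5  # assumed std of the toy 1-D Gaussian data distribution


def drift(x, t):
    """Right-hand side of the toy PF-ODE: dx/dt = t * x / (sigma_data^2 + t^2)."""
    return t * x / (SIGMA_DATA ** 2 + t ** 2)


def exact_solution(x_T, T, t):
    """Closed-form solution of the toy ODE started from x_T at time T."""
    return x_T * math.sqrt((SIGMA_DATA ** 2 + t ** 2) / (SIGMA_DATA ** 2 + T ** 2))


def time_schedule(n_steps, t_min=0.002, t_max=80.0, rho=7.0):
    """Assumed polynomial schedule running from t_max down to t_min."""
    a, b = t_max ** (1 / rho), t_min ** (1 / rho)
    return [(a + i / n_steps * (b - a)) ** rho for i in range(n_steps + 1)]


def euler_sample(x_T, n_steps):
    """First-order Euler sampler; each step costs one function evaluation."""
    ts = time_schedule(n_steps)
    x = x_T
    for t_cur, t_next in zip(ts[:-1], ts[1:]):
        x = x + (t_next - t_cur) * drift(x, t_cur)
    return x


def relative_error(n_steps, x_T=80.0):
    """Relative gap between the Euler result and the exact solution at t_min."""
    ts = time_schedule(n_steps)
    truth = exact_solution(x_T, ts[0], ts[-1])
    return abs(euler_sample(x_T, n_steps) - truth) / abs(truth)


if __name__ == "__main__":
    for n in (5, 10, 50):
        print(f"NFE={n:3d}  relative error={relative_error(n):.4f}")
```

In this toy setting the Euler error is large at 5 steps and shrinks as the step count grows, which is the regime the abstract describes. The sketch does not implement AMED-Solver itself; it only illustrates the approximation error that the learned mean direction is meant to remove.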

Paper Structure

This paper contains 26 sections, 2 theorems, 27 equations, 15 figures, 13 tables, 2 algorithms.

Key Result

Lemma 1

Under assumption assump:assump1 and assump:assump4, let $g(\tau) = f(\tau) / \sqrt{d}$, then $\mathbf{z}_s$ concentrates at a thin shell with radius

Figures (15)

  • Figure 1: Images synthesized by Stable-Diffusion [rombach2022ldm] with the default classifier-free guidance scale of 7.5 and the text prompt "A Corgi on the grass surrounded by a cluster of colorful balloons". Our method improves the sample quality of DPM-Solver++(2M) [lu2022dpmpp].
  • Figure 2: Comparison of various ODE solvers. Red dots depict the actual sampling steps of the different solvers. (a) The DDIM solver [song2021ddim] applies Euler discretization to PF-ODEs. In every sampling step, it follows the gradient direction to give the solution for the next time step. (b) Multi-step solvers [liu2022pseudo, zhang2023deis, lu2022dpmpp, zhao2023unipc] require the current gradient and several recorded history gradients, and follow a combination of these gradients to give the solution. (c) In the generalized DPM-Solver-2 [lu2022dpm], a hyper-parameter $r$ controls the location of the intermediate time step. $r=0.5$ recovers the default DPM-Solver-2 and $r=1$ recovers Heun's second method [karras2022edm]. The gradient for the sampling step is given by a combination of the gradients at the intermediate and current time steps (see the comparison table). (d) Our proposed AMED-Solver seeks the intermediate time step and the scaling factor that give a nearly optimal gradient pointing toward the ground truth solution. The gradient used for the sampling step is adaptively learned rather than heuristically assigned as in DPM-Solver-2.
  • Figure 3: The sample quality degradation of multi-step and single-step ODE solvers. The quality of images generated by single-step solvers, especially higher-order ones including DPM-Solver-2 [lu2022dpm] and EDM [karras2022edm], rapidly decreases as NFE decreases, while our proposed AMED-Solver largely mitigates such degradation. Examples are from FFHQ 64$\times$64 [karras2019style] and ImageNet 64$\times$64 [russakovsky2015ImageNet].
  • Figure 4: We perform PCA on each sampling trajectory $\{\mathbf{x}_{t}\}_{t=\epsilon}^{T}$. (a) These trajectories are projected onto the 2D subspace spanned by the top 2 principal components to get $\{\tilde{\mathbf{x}}_{t}\}_{t=\epsilon}^T$, and the relative projection error is calculated as $\left \| \mathbf{x}_{t} - \tilde{\mathbf{x}}_{t} \right \|_2 / \left \| \mathbf{x}_{t} \right \|_2$. (b) We progressively increase the number of principal components and calculate the cumulative percent variance as $\mathrm{Var}(\{\tilde{\mathbf{x}}_{t}\}_{t=\epsilon}^T) / \mathrm{Var}(\{\mathbf{x}_{t}\}_{t=\epsilon}^T)$. The results are obtained by averaging over 1k sampling trajectories generated with the EDM solver [karras2022edm] at 80 NFE.
  • Figure 5: Effectiveness of searching the intermediate time steps. Given a time schedule $\Gamma = \{t_1=\epsilon, \cdots, t_N=T\}$ where $\epsilon=0.002, T=80, N=6$, we first generate a ground truth trajectory. For each ODE solver, we generate a baseline trajectory by performing evaluations at $s_n=\sqrt{t_nt_{n+1}}$ (i.e., $r_n=0.5$), and a searched trajectory via a greedy grid search over $r_n$, which gives $s_n=t_n^{r_n} t_{n+1}^{1-{r_n}}$.
  • ...and 10 more figures
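
The two quantities defined in the Figure 4 caption can be computed as in the following minimal sketch. It is an illustration under stated assumptions, not the authors' code: the trajectory is synthetic (a planar curve embedded in a 256-dimensional space plus small Gaussian noise), the PCA is mean-centered (the caption does not say whether centering is applied), and the function names `pca_trajectory_stats` and `toy_trajectory` are hypothetical.

```python
import numpy as np


def pca_trajectory_stats(traj, k=2):
    """Return per-step relative projection error and cumulative variance ratio.

    traj: array of shape (num_steps, dim), one row per point x_t.
    """
    mean = traj.mean(axis=0, keepdims=True)
    centered = traj - mean
    # Rows of vt are the principal directions of the centered trajectory.
    _, singular_values, vt = np.linalg.svd(centered, full_matrices=False)
    basis = vt[:k]                                   # (k, dim)
    projected = centered @ basis.T @ basis + mean    # reconstruction from top-k PCs
    rel_err = (np.linalg.norm(traj - projected, axis=1)
               / np.linalg.norm(traj, axis=1))
    # Variance captured by the first 1, 2, ... principal components.
    cum_var = np.cumsum(singular_values ** 2) / np.sum(singular_values ** 2)
    return rel_err, cum_var


def toy_trajectory(num_steps=50, dim=256, noise=1e-3, seed=0):
    """Synthetic stand-in for a sampling trajectory (assumed, not real data)."""
    rng = np.random.default_rng(seed)
    plane, _ = np.linalg.qr(rng.standard_normal((dim, 2)))  # orthonormal 2-D basis
    t = np.linspace(0.0, 1.0, num_steps)
    coords = np.stack([80.0 * (1.0 - t) + 1.0, 5.0 * np.sin(np.pi * t)], axis=1)
    return coords @ plane.T + noise * rng.standard_normal((num_steps, dim))


if __name__ == "__main__":
    rel_err, cum_var = pca_trajectory_stats(toy_trajectory(), k=2)
    print("max relative projection error:", rel_err.max())
    print("variance captured by top 2 PCs:", cum_var[1])
```

To reproduce the figure's setting, the synthetic array would be replaced by real trajectories from a sampler, flattened to shape `(num_steps, dim)`, with the statistics averaged over many trajectories.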

Theorems & Definitions (4)

  • Lemma 1
  • proof
  • Proposition 1
  • proof