Learning to Discretize Denoising Diffusion ODEs

Vinh Tong, Hoang Trung-Dung, Anji Liu, Guy Van den Broeck, Mathias Niepert

TL;DR

This paper tackles the high computational cost of sampling from pre-trained diffusion probabilistic models by introducing LD3, a lightweight framework that learns optimal time discretization for diffusion ODE solvers. LD3 directly minimizes global truncation error via a teacher-student scheme, enhanced with a soft teacher-forcing surrogate and a memory-efficient training strategy, enabling effective discretization with minimal training overhead. Empirically, LD3 improves sample quality across multiple datasets and solvers, especially at low NFEs, and achieves competitive or superior FID scores with modest training time on a single GPU. The approach is solver-aware and model-agnostic, offering a practical route to faster, high-quality diffusion-based generation without retraining large denoising networks.

Abstract

Diffusion Probabilistic Models (DPMs) are generative models showing competitive performance in various domains, including image synthesis and 3D point cloud generation. Sampling from pre-trained DPMs involves multiple neural function evaluations (NFEs) to transform Gaussian noise samples into images, resulting in higher computational costs compared to single-step generative models such as GANs or VAEs. Therefore, reducing the number of NFEs while preserving generation quality is crucial. To address this, we propose LD3, a lightweight framework designed to learn the optimal time discretization for sampling. LD3 can be combined with various samplers and consistently improves generation quality without having to retrain resource-intensive neural networks. We demonstrate analytically and empirically that LD3 improves sampling efficiency with much less computational overhead. We evaluate our method with extensive experiments on 7 pre-trained models, covering unconditional and conditional sampling in both pixel-space and latent-space DPMs. We achieve FIDs of 2.38 (10 NFE) and 2.27 (10 NFE) on unconditional CIFAR10 and AFHQv2, respectively, in 5-10 minutes of training. LD3 offers an efficient approach to sampling from pre-trained diffusion models. Code is available at https://github.com/vinhsuhi/LD3.
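
As a rough illustration of what learning a time discretization means, the following minimal sketch fits the step positions of a 5-step Euler student to a teacher on a toy one-dimensional probability-flow ODE. It is not the authors' implementation: the Gaussian toy data, the closed-form teacher (standing in for a many-step solver), the softmax parameterization in `timesteps`, the squared-error loss, and the finite-difference gradient descent are all assumptions chosen to keep the example dependency-free.

```python
import math
import random

S2 = 1.0                    # variance of the toy Gaussian data (assumption)
T_MAX, T_MIN = 10.0, 0.01   # start and end noise levels (assumption)

def drift(x, t):
    # Probability-flow ODE for x_t = x_0 + t * noise with x_0 ~ N(0, S2):
    # dx/dt = t * x / (S2 + t^2).
    return t * x / (S2 + t * t)

def teacher(x_T):
    # Closed-form ODE solution, standing in for a many-step teacher solver.
    return x_T * math.sqrt((S2 + T_MIN ** 2) / (S2 + T_MAX ** 2))

def timesteps(xi):
    # Softmax weights -> cumulative fractions -> decreasing grid from T_MAX to T_MIN.
    m = max(xi)
    w = [math.exp(v - m) for v in xi]
    total = sum(w)
    grid, frac = [T_MAX], 0.0
    for wi in w:
        frac += wi / total
        grid.append(T_MAX - frac * (T_MAX - T_MIN))
    grid[-1] = T_MIN  # guard against floating-point drift at the endpoint
    return grid

def student(x_T, xi):
    # Explicit Euler solver on the learned grid (one drift evaluation per step).
    ts = timesteps(xi)
    x = x_T
    for t, t_next in zip(ts[:-1], ts[1:]):
        x = x + (t_next - t) * drift(x, t)
    return x

def loss(xi, noises):
    # Mean squared gap between student and teacher outputs (global error proxy).
    return sum((student(x, xi) - teacher(x)) ** 2 for x in noises) / len(noises)

def grad(xi, noises, eps=1e-5):
    # Central finite differences; a real implementation would use autodiff.
    g = []
    for i in range(len(xi)):
        hi = xi[:i] + [xi[i] + eps] + xi[i + 1:]
        lo = xi[:i] + [xi[i] - eps] + xi[i + 1:]
        g.append((loss(hi, noises) - loss(lo, noises)) / (2 * eps))
    return g

def train(n_steps=5, iters=200, seed=0):
    rng = random.Random(seed)
    noises = [rng.gauss(0.0, T_MAX) for _ in range(32)]
    xi = [0.0] * n_steps            # all-zero parameters give the uniform grid
    uniform = best = loss(xi, noises)
    lr = 1.0
    for _ in range(iters):
        g = grad(xi, noises)
        cand = [v - lr * gv for v, gv in zip(xi, g)]
        c = loss(cand, noises)
        if c < best:                # accept only improving steps
            xi, best = cand, c
        else:                       # otherwise shrink the step size
            lr *= 0.5
    return xi, uniform, best
```

Calling `train()` returns the learned parameters together with the loss on the uniform grid and the loss on the learned grid; `timesteps(xi)` converts the parameters back into a decreasing time grid. Only the grid parameters are trained, which mirrors the claim that no denoising network needs retraining.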

Paper Structure

This paper contains 48 sections, 2 theorems, 26 equations, 23 figures, 22 tables, 1 algorithm.

Key Result

Theorem 1

Let $\Psi_*$ and ${\Psi}_{\bm{\xi}}$ be a teacher and student ODE solver each with noise distribution $\mathcal{N}(\mathbf{0}, \sigma_T^2\mathbf{I}) \in \mathbb{R}^d$, and with, respectively, distributions $q$ and $p_{\bm{\xi}}$. Assume both $\Psi_*$ and ${\Psi}_{\bm{\xi}}$ are invertible. Let $r>0$ where $C(\Psi_{\bm{\xi}^*}(\mathbf{x})) = \log|\det J_{\Psi_{\bm{\xi}^*}}(\Psi_{\bm{\xi}^*}^{-1}(\m

Figures (23)

  • Figure 1: Motivation and elaboration of LD3. (a) Directly optimizing the global truncation error loss $\mathcal{L}_{\mathrm{hard}}$, i.e., minimizing the discrepancy between the teacher and student outputs, improves sample quality. (b) The surrogate objective $\mathcal{L}_{\mathrm{soft}}$, which allows discrepancies in the initial condition (i.e., $\mathbf{x}_T$) between the teacher solver and the student solver, is easier to optimize. (c) By optimizing the surrogate objective, LD3 learns better discretization strategies.
  • Figure 2: $\mathcal{L}_{\mathrm{soft}} (\bm{\xi})$ drops significantly as we increase $r$.
  • Figure 3: Side-by-side comparison of selected images generated with Stable Diffusion [iPNDM]. Left: NFE=6, Right: NFE=5.
  • Figure 4: Side-by-side comparison of random images generated by different pre-trained models across four datasets: AFHQv2, ImageNet, FFHQ, and LSUN-Bedroom. We compare LD3 with DMN (Xue et al., 2024), GITS (Chen et al., 2024), and Time Uniform discretization. For each dataset, the samples in each column are generated using the same initial noise, solver, and NFE. More side-by-side comparisons are provided in the appendix.
  • Figure 5: (a) Training time for LD3 and AYS (Sabour et al., 2024) at various NFEs; (b) the effect of data size on model performance (FID evaluated with 5K samples, DPM-Solver++(3M), NFE=4).
  • ...and 18 more figures
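
The relaxation described in the captions of Figure 1(b) and Figure 2 can be sketched on a toy one-dimensional problem. This is a minimal sketch under stated assumptions, not the paper's objective as implemented: it assumes the soft loss lets the student start from any point within distance $r \cdot \sigma_T$ of the teacher's initial noise, and it uses a toy linear ODE, a squared-error distance, and a closed-form inner minimization (valid only because this toy student is linear in its input). With $r = 0$ the soft loss reduces to the hard loss, and it can only decrease as $r$ grows.

```python
import math

S2, T_MAX, T_MIN = 1.0, 10.0, 0.01  # toy values (assumptions)

def teacher(x_T):
    # Closed-form solution of dx/dt = t * x / (S2 + t^2), standing in for a teacher solver.
    return x_T * math.sqrt((S2 + T_MIN ** 2) / (S2 + T_MAX ** 2))

def student(x_T, n_steps=5):
    # Explicit Euler on a uniform grid from T_MAX down to T_MIN.
    h = (T_MIN - T_MAX) / n_steps
    x, t = x_T, T_MAX
    for _ in range(n_steps):
        x += h * t * x / (S2 + t * t)
        t += h
    return x

def hard_loss(noises):
    # Student and teacher share exactly the same initial noise.
    return sum((student(x) - teacher(x)) ** 2 for x in noises) / len(noises)

def soft_loss(noises, r):
    # The student may start anywhere in [x - r*T_MAX, x + r*T_MAX].
    gain = student(1.0)          # the toy student is linear: student(x) = gain * x
    radius = r * T_MAX
    total = 0.0
    for x in noises:
        target = teacher(x)
        # Best starting point: unconstrained minimizer projected onto the interval.
        x_best = min(max(target / gain, x - radius), x + radius)
        total += (student(x_best) - target) ** 2
    return total / len(noises)

noises = [-15.0, -8.0, -2.0, 3.0, 9.0, 14.0]   # fixed toy initial conditions
radii = (0.0, 0.5, 1.0, 2.0)
losses = [soft_loss(noises, r) for r in radii]
for r, value in zip(radii, losses):
    print(f"r={r}: soft loss = {value:.6f}")
```

In a realistic setting the inner minimization has no closed form, so the perturbed starting point would have to be optimized numerically alongside the discretization parameters.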

Theorems & Definitions (3)

  • Theorem 1
  • Theorem \ref{the:kl_bound}
  • proof