Distillation of Discrete Diffusion through Dimensional Correlations

Satoshi Hayakawa, Yuhta Takida, Masaaki Imaizumi, Hiromi Wakaki, Yuki Mitsufuji

TL;DR

The paper tackles slow sampling in discrete diffusion by introducing Di4C, a framework that distills many-step, dimensionally independent denoisers into few-step models using a mixture denoiser to capture dimensional correlations. It establishes theoretical guarantees showing that multi-step product models can approximate data with error decreasing as 1/N, and that the proposed distillation and consistency losses bound the teacher-student distance, enabling effective compression when the student is expressive enough. Empirically, Di4C delivers 2x to 4x sampling speedups across CIFAR-10 pixel space, ImageNet masked diffusion, and OpenWebText language models while preserving or improving sample quality and diversity, with minimal latency overhead. The work provides a principled, scalable path to accelerating discrete diffusion for both vision and language domains, bridging theory and cross-domain applications in diffusion-based generation.

Abstract

Diffusion models have demonstrated exceptional performance in various fields of generative modeling, but suffer from slow sampling speed due to their iterative nature. While this issue is being addressed in continuous domains, discrete diffusion models face unique challenges, particularly in capturing dependencies between elements (e.g., pixel relationships in images, sequential dependencies in language), mainly due to the computational cost of processing high-dimensional joint distributions. In this paper, (i) we propose "mixture" models for discrete diffusion that are capable of treating dimensional correlations while remaining scalable, and (ii) we provide a set of loss functions for distilling the iterations of existing models. Two primary theoretical insights underpin our approach. First, conventional models with element-wise independence can approximate the data distribution well, but essentially require many sampling steps. Second, our loss functions enable the mixture models to distill such many-step conventional models into just a few steps by learning the dimensional correlations. Our experimental results show the effectiveness of the proposed method in distilling pretrained discrete diffusion models across image and language domains. The code used in the paper is available at https://github.com/sony/di4c.

Paper Structure

This paper contains 83 sections, 11 theorems, 118 equations, 8 figures, 8 tables.

Key Result

Proposition 1

For any probability distribution $p$ over $\mathcal{S}^D$, there exist a probability distribution $\pi$ and a family of product distributions $p(\bm{x};\lambda) = \prod_{d=1}^D p^d(x^d;\lambda)$ indexed by $\lambda$ satisfying $p(\bm{x}) = \mathbb{E}_{\lambda\sim\pi}\!\left[p(\bm{x};\lambda)\right]$.
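
The proposition can be checked numerically on a small example. The sketch below is not the paper's model: it uses the simplest construction that makes the statement true, where $\lambda$ ranges over every state, $\pi = p$, and each component is a product of one-hot marginals. The function names (`product_of_marginals`, `delta_mixture`, `total_variation`) and the toy joint distribution are assumptions made for illustration.

```python
import numpy as np

def product_of_marginals(p):
    """Product model: outer product of the two marginals of a 2-D joint p."""
    return np.outer(p.sum(axis=1), p.sum(axis=0))

def delta_mixture(p):
    """Mixture of product distributions with one component per state.

    lambda ranges over all states (i, j), pi(lambda) = p[i, j], and each
    component is the product of two one-hot (delta) marginals.
    """
    s1, s2 = p.shape
    mix = np.zeros_like(p)
    for i in range(s1):
        for j in range(s2):
            component = np.outer(np.eye(s1)[i], np.eye(s2)[j])  # product distribution
            mix += p[i, j] * component
    return mix

def total_variation(p, q):
    """Total variation distance between two distributions on the same grid."""
    return 0.5 * np.abs(p - q).sum()

# Hypothetical correlated joint: x and y always agree.
p = np.array([[0.5, 0.0],
              [0.0, 0.5]])

tv_product = total_variation(p, product_of_marginals(p))
tv_mixture = total_variation(p, delta_mixture(p))
print(tv_product, tv_mixture)  # 0.5 0.0
```

This construction needs one component per state, so it only demonstrates existence; it says nothing about how a scalable mixture is parameterized in practice.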

Figures (8)

  • Figure 1: Illustration of dimensional correlations. (Left) The distribution $p(x, y)$ is a two-dimensional categorical distribution; $p^1(x)$ and $p^2(y)$ are its marginals. (Center) The conventional denoiser in discrete diffusion uses a product model, which is simply the product of the marginal distributions. It fails to approximate the ground-truth distribution. (Right) Our mixture model is given by the expectation of the product model $p(x,y;\lambda)=p^1(x;\lambda)p^2(y;\lambda)$ over a random $\lambda$. In the figure, $\lambda$ takes the values $\alpha, \beta, \gamma$ with equal probability, and the model reconstructs $p(x, y)$.
  • Figure 2: Illustration of how our loss functions work. Through $\mathcal{L}_\mathrm{distil}$ and $\mathcal{L}_\mathrm{consis}$, we distill multiple teacher denoising steps into fewer steps of student denoiser.
  • Figure 3: FID-IS curves of $4/8$-step teacher and $4$-step Di4C models on ImageNet $256\times256$ when varying CFG coefficients. Arrows connect experimental results (dots) at each CFG coefficient in ascending order.
  • Figure 4: Comparison of SDTT checkpoints [deschenaux2024beyond] and their Di4C distillations.
  • Figure 5: Comparison of generated samples in CIFAR-10 experiment.
  • ...and 3 more figures

Theorems & Definitions (19)

  • Proposition 1
  • Theorem 1: $N$-step analytical sampling approximates data, informal
  • Proof (sketch)
  • Theorem 2: Di4C student approximates $N$-step teacher
  • Proposition 2: Pinsker's inequality
  • Proposition 3 [cover2006elements]
  • Proposition 4
  • Proposition 5
  • Theorem 3
  • Proposition 6
  • ...and 9 more
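
Proposition 2 in the list above is Pinsker's inequality. In its standard form it bounds the total variation distance by the KL divergence, $d_{\mathrm{TV}}(p, q) \le \sqrt{D_{\mathrm{KL}}(p \,\|\, q)/2}$; the exact statement used in the paper may differ in notation. The following sketch checks that standard form on randomly drawn categorical distributions. The helper names and the Dirichlet sampling setup are assumptions made for illustration.

```python
import numpy as np

def kl_divergence(p, q):
    """KL(p || q) in nats for categorical distributions with full support."""
    mask = p > 0
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def tv_distance(p, q):
    """Total variation distance between two categorical distributions."""
    return 0.5 * float(np.abs(p - q).sum())

rng = np.random.default_rng(0)
worst_gap = np.inf
for _ in range(1000):
    # Draw two random distributions over 5 categories.
    p = rng.dirichlet(np.ones(5))
    q = rng.dirichlet(np.ones(5))
    # Pinsker: sqrt(KL / 2) - TV should never be negative.
    gap = np.sqrt(kl_divergence(p, q) / 2.0) - tv_distance(p, q)
    worst_gap = min(worst_gap, gap)

print(worst_gap >= 0.0)
```

An inequality of this kind is what lets a loss that controls a KL divergence also control a total variation distance.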