Continuously Augmented Discrete Diffusion model for Categorical Generative Modeling

Huangjie Zheng, Shansan Gong, Ruixiang Zhang, Tianrong Chen, Jiatao Gu, Mingyuan Zhou, Navdeep Jaitly, Yizhe Zhang

TL;DR

Standard discrete diffusion models lose information when tokens are replaced by an absorbing [MASK] state. CADD addresses this by augmenting the discrete diffusion with a paired continuous latent that retains semantic cues and guides denoising at masked positions, enabling a controllable trade-off between mode coverage and mode seeking. The approach yields consistent improvements in text, image, and code generation over strong discrete baselines, with scalable sampling and no architectural changes to existing backbones. Practically, CADD enhances generation quality while maintaining training efficiency, making it a versatile augmentation for categorical diffusion tasks.

Abstract

Standard discrete diffusion models treat all unobserved states identically by mapping them to an absorbing [MASK] token. This creates an 'information void' where semantic information that could be inferred from unmasked tokens is lost between denoising steps. We introduce Continuously Augmented Discrete Diffusion (CADD), a framework that augments the discrete state space with a paired diffusion in a continuous latent space. This yields graded, gradually corrupted states in which masked tokens are represented by noisy yet informative latent vectors rather than collapsed 'information voids'. At each reverse step, CADD may leverage the continuous latent as a semantic hint to guide discrete denoising. The design is clean and compatible with existing discrete diffusion training. At sampling time, the strength and choice of estimator for the continuous latent vector enable a controlled trade-off between mode-coverage (generating diverse outputs) and mode-seeking (generating contextually precise outputs) behaviors. Empirically, we demonstrate that CADD improves generative quality over mask-based diffusion across text generation, image synthesis, and code modeling, with consistent gains on both qualitative and quantitative metrics against strong discrete baselines.
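To make the "noisy yet informative latent" idea concrete, here is a minimal sketch of sampling one corrupted pair $({\bm{x}}_t, {\bm{z}}_t)$ from clean tokens. It is an illustration under stated assumptions, not the authors' implementation: the function name `cadd_forward_sample`, the embedding table `W` (standing in for ${\bm{w}}_\theta$), independent per-token masking with probability $1-\alpha_t$, and the Gaussian form $\sqrt{\bar\gamma_t}\,{\bm{z}}_0+\sqrt{1-\bar\gamma_t}\,\epsilon$ at masked positions are all hypothetical choices made for the example.

```python
import numpy as np

def cadd_forward_sample(x0, W, alpha_bar_t, gamma_bar_t, mask_id, rng):
    """Sample a hypothetical CADD-style noisy pair (x_t, z_t) from clean tokens x0.

    Illustrative assumptions, not the authors' code:
      * each token is independently replaced by mask_id with probability 1 - alpha_bar_t;
      * the clean latent is an embedding lookup z_0 = W[x0] (standing in for w_theta);
      * masked positions carry sqrt(gamma_bar_t) * z_0 + sqrt(1 - gamma_bar_t) * eps;
      * unmasked positions keep their clean embedding.
    """
    x0 = np.asarray(x0)
    z0 = W[x0]                                      # (L, d) clean latents
    masked = rng.random(x0.shape) >= alpha_bar_t    # True with probability 1 - alpha_bar_t
    x_t = np.where(masked, mask_id, x0)             # discrete state: [MASK] or the clean token
    eps = rng.standard_normal(z0.shape)
    z_noisy = np.sqrt(gamma_bar_t) * z0 + np.sqrt(1.0 - gamma_bar_t) * eps
    z_t = np.where(masked[:, None], z_noisy, z0)    # masked slots keep a noisy but informative latent
    return x_t, z_t, masked

# Usage with made-up sizes: 10 real tokens (ids 0-9) and mask_id = 10.
rng = np.random.default_rng(0)
W = rng.standard_normal((10, 4))
x0 = np.array([1, 5, 7, 2, 3])
x_t, z_t, masked = cadd_forward_sample(x0, W, alpha_bar_t=0.6, gamma_bar_t=0.8,
                                       mask_id=10, rng=rng)
```

In this sketch a masked position still carries a $\sqrt{\bar\gamma_t}$-scaled copy of its clean embedding, which is the kind of semantic hint the abstract describes; as $\bar\gamma_t \to 0$ the latent becomes pure noise and the sketch degenerates to plain masking.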

Paper Structure

This paper contains 42 sections, 6 theorems, 46 equations, 11 figures, 8 tables, 2 algorithms.

Key Result

Proposition 1

The marginal at timestep $t$ can be factorized. Given $\alpha_t:=\prod_{s=1}^t(1-\beta_s)$, $\overline{\bm{Q}}_t:=\prod_{s=1}^t{\bm{Q}}_s=\alpha_t{\bm{I}}+(1-\alpha_t)\,\mathbf 1\,\bm m^\top$, and $\bar{\gamma}_t:=\prod_{s=1}^t\gamma_s$, with ${\bm{z}}_0^i={\bm{w}}_\theta({\bm{x}}_0^i)$, the two terms factorized above represent the discrete …
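Proposition 1 relies on the closed form of the cumulative absorbing transition $\overline{\bm{Q}}_t$. Assuming each single step is $\bm{Q}_s=(1-\beta_s)\bm{I}+\beta_s\mathbf 1\bm m^\top$ in a row-stochastic convention, with $\bm m$ the one-hot vector of the [MASK] state (the single-step form is not shown in this summary), a quick numerical check that the product collapses to $\alpha_t\bm I+(1-\alpha_t)\mathbf 1\bm m^\top$ might look like this. The schedule values are arbitrary.

```python
import numpy as np

def absorbing_Q(beta, K, mask_idx):
    """Single-step absorbing transition Q_s = (1 - beta) I + beta * 1 m^T (row-stochastic)."""
    m = np.zeros(K)
    m[mask_idx] = 1.0                        # one-hot of the [MASK] state
    return (1.0 - beta) * np.eye(K) + beta * np.outer(np.ones(K), m)

def cumulative_Q(betas, K, mask_idx):
    """Q_bar_t = Q_1 Q_2 ... Q_t, computed by explicit matrix products."""
    Q_bar = np.eye(K)
    for beta in betas:
        Q_bar = Q_bar @ absorbing_Q(beta, K, mask_idx)
    return Q_bar

def closed_form_Q(betas, K, mask_idx):
    """Closed form alpha_t I + (1 - alpha_t) 1 m^T with alpha_t = prod_s (1 - beta_s)."""
    alpha_t = np.prod(1.0 - np.asarray(betas))
    m = np.zeros(K)
    m[mask_idx] = 1.0
    return alpha_t * np.eye(K) + (1.0 - alpha_t) * np.outer(np.ones(K), m)

betas = [0.1, 0.2, 0.05, 0.3]   # arbitrary example schedule
K, mask_idx = 5, 4              # 4 real tokens plus [MASK] at index 4
print(np.allclose(cumulative_Q(betas, K, mask_idx), closed_form_Q(betas, K, mask_idx)))  # True
```

The identity holds because $\bm m^\top\mathbf 1=1$ makes $\mathbf 1\bm m^\top$ idempotent, so every product of such matrices stays in the two-parameter family $a\bm I+(1-a)\mathbf 1\bm m^\top$ with the coefficients $a$ multiplying.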

Figures (11)

  • Figure 1: (Best viewed in color) Comparison of diffusion models across modeling spaces. Masked diffusion uses [MASK] as noise and follows a single mask-to-token path, jumping from an absorbing state to token predictions. Continuous (Gaussian) diffusion evolves in the full embedding space, but intermediate latents often do not decode to valid tokens until the final step because the search space is large. CADD combines the stability of masked diffusion with the flexibility of continuous diffusion: discrete tokens anchor a context-consistent subspace, while the paired continuous latent allows smooth transitions among plausible token candidates, improving decoding at masked positions.
  • Figure 2: (Best viewed in color) Example of how the signal-to-noise ratio (SNR) of a single token changes during the forward process of vanilla masked diffusion vs. CADD. After the second token is first masked, CADD gradually corrupts its information through Gaussian diffusion in the latent space, resulting in a smooth decay.
  • Figure 3: (Best viewed in color) Illustration of the CADD model, which combines the discrete and continuous features of the data. During training, the clean token at each masked position is embedded with the embedding matrix and used to form the noisy embedding via the continuous forward process. During sampling, the model can predict a diverse distribution over possible tokens by sampling multiple ${\bm{z}}_{t}$; the predicted tokens are then passed back through the embedding matrix to form $\hat{{\bm{z}}}_{0, \theta}$ for the next iteration.
  • Figure 4: Unconditional text generation evaluation of models trained on OpenWebText (OWT). All methods are evaluated with 128, 256, 512, 1024, and 4096 sampling steps. MAUVE (left panel, higher is better) and generative perplexity (right panel, measured with GPT2-Large, lower is better) are reported.
  • Figure 5: Analogous to Figure 4. We compare the checkpoint fine-tuned with the CADD objective against CADD and the MDLM checkpoint used for initialization.
  • ...and 6 more figures
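Figure 2 contrasts an abrupt SNR drop under masking with a smooth decay under CADD. The sketch below reproduces that qualitative contrast for one token under loudly stated assumptions: the token is observed exactly before masking, vanilla masking leaves SNR $=0$, and after masking the CADD-style latent follows ${\bm{z}}_t=\sqrt{\bar\gamma_t}\,{\bm{z}}_0+\sqrt{1-\bar\gamma_t}\,\epsilon$ with $\bar\gamma_t=\prod_{s\le t}\gamma_s$ as in Proposition 1, giving SNR $=\bar\gamma_t/(1-\bar\gamma_t)$. The function name `token_snr` and the constant $\gamma_s=0.9$ are hypothetical; the figure may instead start the Gaussian corruption at the masking step, which would change the numbers but not the shape of the comparison.

```python
import numpy as np

def token_snr(gammas, mask_step):
    """Per-step SNR of one token first masked at step `mask_step` (1-indexed).

    Illustrative assumptions:
      * before masking, the token is observed exactly (SNR = inf) in both models;
      * vanilla masked diffusion: after masking nothing about the token remains (SNR = 0);
      * CADD-style latent: after masking, z_t = sqrt(g_t) z_0 + sqrt(1 - g_t) eps with
        g_t = prod_{s<=t} gamma_s, so SNR = g_t / (1 - g_t). Requires every gamma_s < 1.
    """
    gamma_bar = np.cumprod(gammas)
    vanilla = np.full(len(gammas), np.inf)
    cadd = np.full(len(gammas), np.inf)
    after = np.arange(1, len(gammas) + 1) >= mask_step   # steps at which the token is masked
    vanilla[after] = 0.0
    cadd[after] = gamma_bar[after] / (1.0 - gamma_bar[after])
    return vanilla, cadd

vanilla, cadd = token_snr(gammas=[0.9] * 8, mask_step=3)
print(vanilla)  # inf for the first two steps, then zeros: an abrupt drop
print(cadd)     # inf for the first two steps, then a smoothly decreasing positive SNR
```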

Theorems & Definitions (12)

  • Proposition 1: Timestep-$t$ joint marginal factorization
  • Proposition 2: Factorization of the true posterior
  • Lemma 1
  • Proposition 3: ELBO decomposition
  • Proof of Proposition 3
  • Lemma 2: Continuous marginal conditioned on $({\bm{x}}_t,{\bm{x}}_0)$
  • Lemma 3: Conditional independence between ${\bm{z}}_t$ and ${\bm{x}}_{t-1}$ given $({\bm{x}}_t, {\bm{x}}_0)$
  • Proof of Lemma 2 and Lemma 3
  • Proof of Proposition 1
  • Proof of Proposition 2
  • ...and 2 more
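The summary lists only the titles of the posterior results. As a hedged illustration of what the continuous factor of such a posterior typically looks like, assume the latent follows a variance-preserving Gaussian chain ${\bm{z}}_t=\sqrt{\gamma_t}\,{\bm{z}}_{t-1}+\sqrt{1-\gamma_t}\,\epsilon$, consistent with $\bar\gamma_t=\prod_{s}\gamma_s$ in Proposition 1. Under that assumption $q({\bm{z}}_{t-1}\mid{\bm{z}}_t,{\bm{z}}_0)$ is the standard DDPM Gaussian posterior. The sketch computes it in closed form and cross-checks it by conditioning the joint Gaussian of $({\bm{z}}_{t-1},{\bm{z}}_t)$; how the paper's statement involves the masking status of ${\bm{x}}_t$ is not reproduced here, and both function names are hypothetical.

```python
import numpy as np

def gaussian_posterior(z_t, z_0, gamma_t, gamma_bar_prev):
    """Closed-form mean/variance of q(z_{t-1} | z_t, z_0) for a VP chain (DDPM form, assumed)."""
    gamma_bar_t = gamma_t * gamma_bar_prev
    coef_z0 = np.sqrt(gamma_bar_prev) * (1.0 - gamma_t) / (1.0 - gamma_bar_t)
    coef_zt = np.sqrt(gamma_t) * (1.0 - gamma_bar_prev) / (1.0 - gamma_bar_t)
    mean = coef_z0 * z_0 + coef_zt * z_t
    var = (1.0 - gamma_bar_prev) * (1.0 - gamma_t) / (1.0 - gamma_bar_t)
    return mean, var

def gaussian_posterior_by_conditioning(z_t, z_0, gamma_t, gamma_bar_prev):
    """Same quantity from Gaussian conditioning of the joint (z_{t-1}, z_t) given z_0."""
    gamma_bar_t = gamma_t * gamma_bar_prev
    var_prev = 1.0 - gamma_bar_prev          # Var(z_{t-1} | z_0), per dimension
    var_t = 1.0 - gamma_bar_t                # Var(z_t | z_0), per dimension
    cov = np.sqrt(gamma_t) * var_prev        # Cov(z_{t-1}, z_t | z_0), per dimension
    mean = np.sqrt(gamma_bar_prev) * z_0 + cov / var_t * (z_t - np.sqrt(gamma_bar_t) * z_0)
    var = var_prev - cov**2 / var_t
    return mean, var

rng = np.random.default_rng(0)
z_t, z_0 = rng.standard_normal(4), rng.standard_normal(4)
m1, v1 = gaussian_posterior(z_t, z_0, gamma_t=0.9, gamma_bar_prev=0.7)
m2, v2 = gaussian_posterior_by_conditioning(z_t, z_0, gamma_t=0.9, gamma_bar_prev=0.7)
print(np.allclose(m1, m2), np.isclose(v1, v2))  # True True
```

At the first step ($\bar\gamma_0=1$) the closed form returns mean ${\bm{z}}_0$ and variance $0$, as expected for a chain that starts at the clean latent.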