
Generalized Discrete Diffusion with Self-Correction

Linxuan Wang, Ziyi Wang, Yikun Bai, Wei Deng, Guang Lin, Qifan Song

TL;DR

This work proposes a Self-Correcting Discrete Diffusion model to reformulate pretrained self-correction with explicit state transitions and learn directly in discrete time, which enables more efficient parallel decoding while preserving generation quality.

Abstract

Self-correction is an effective technique for maintaining parallel sampling in discrete diffusion models with minimal performance degradation. Prior work has explored self-correction at inference time or during post-training; however, such approaches often suffer from limited generalization and may impair reasoning performance. GIDD pioneers pretraining-based self-correction via a multi-step BERT-style uniform-absorbing objective. However, GIDD relies on a continuous interpolation-based pipeline with opaque interactions between uniform transitions and absorbing masks, which complicates hyperparameter tuning and hinders practical performance. In this work, we propose a Self-Correcting Discrete Diffusion (SCDD) model to reformulate pretrained self-correction with explicit state transitions and learn directly in discrete time. Our framework also simplifies the training noise schedule, eliminates a redundant remasking step, and relies exclusively on uniform transitions to learn self-correction. Experiments at the GPT-2 scale demonstrate that our method enables more efficient parallel decoding while preserving generation quality.

Paper Structure

This paper contains 70 sections, 7 theorems, 117 equations, 3 figures, and 6 tables.

Key Result

Proposition 3.1

When $\rho_t$ and $\gamma_t$ are monotonically decreasing, the following forward Markov transition kernel induces the marginal distribution (eq:marginal): where $t, s \in \{t_{-1}, t_0, \dots, t_T\}$ are two adjacent time points satisfying $t > s$, $\rho_{t|s} := \rho_t / \rho_s$, and $\gamma_{t|s} := \gamma_t / \gamma_s$.
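The conditional quantities in Proposition 3.1 depend only on the schedule values at adjacent grid points. The snippet below is a minimal sketch, assuming a hypothetical schedule given as a list of values on $\{t_{-1}, t_0, \dots, t_T\}$ with a positive starting value. It computes the adjacent ratios $\rho_{t|s} = \rho_t / \rho_s$, and the same code applies to $\gamma$. Because the schedule is monotonically decreasing (and non-negative), every ratio lies in $[0, 1]$, which is a prerequisite for using it as a probability in a transition kernel. The ratios also telescope: their product over the grid equals $\rho_{t_T} / \rho_{t_{-1}}$. The kernel itself is not reproduced here. The helper `adjacent_ratios` and the linear schedule are illustrative only and do not come from the paper.

```python
import math
from fractions import Fraction


def adjacent_ratios(schedule):
    """Return [v_t / v_s] for each adjacent grid pair (s, t) with t > s.

    `schedule` holds the values v_{t_-1}, v_{t_0}, ..., v_{t_T} of a
    noise-schedule function (rho or gamma) on the discrete time grid.
    """
    ratios = []
    for v_s, v_t in zip(schedule[:-1], schedule[1:]):
        if v_s <= 0:
            raise ValueError("v_s must be positive for v_t / v_s to be defined")
        if not 0 <= v_t <= v_s:
            # Monotone decrease is what keeps each ratio inside [0, 1].
            raise ValueError("schedule must be non-negative and monotonically decreasing")
        ratios.append(v_t / v_s)
    return ratios


# Hypothetical linear schedule on a 4-point grid (not the paper's schedule).
rho = [Fraction(1), Fraction(3, 4), Fraction(1, 2), Fraction(1, 4)]
ratios = adjacent_ratios(rho)
print(ratios)  # [Fraction(3, 4), Fraction(2, 3), Fraction(1, 2)]

# Telescoping: composing the per-step ratios recovers rho_{t_T} / rho_{t_-1}.
assert math.prod(ratios) == rho[-1] / rho[0]
```

Exact `Fraction` arithmetic is used so the telescoping identity can be checked with `==` rather than a floating-point tolerance.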

Figures (3)

  • Figure 1: Example of self-correction between two consecutive denoising steps (127→128). Generated by SCDD ($p_u=0.2$, trained on OpenWebText) with 128 total denoising steps. Inappropriate tokens are corrected directly, without remasking.
  • Figure 2: Correction Rate per Step versus the total number of denoising steps, for different maximum uniform noise ratios. Reported values are averaged over 128 independently generated sequences.
  • Figure 3: Cumulative correction ratio versus denoising progress, measured by the current step count. At step $s$, it is defined as the number of corrections that have occurred up to step $s$ divided by the total number of corrections over the entire generation process. The total number of denoising steps is 512. Reported values are averaged over 128 independently generated sequences.
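The correction metrics in Figures 2 and 3 can be sketched from a single decoding trajectory. This is a minimal sketch under stated assumptions. A hypothetical `MASK` id marks undecoded positions. A correction is counted when a position that was already decoded at the previous step holds a different non-mask token at the current step, which matches the direct replacement without remasking shown in Figure 1. "Correction Rate per Step" is assumed here to mean total corrections divided by the number of denoising steps. The paper's exact definitions may differ. The cumulative ratio follows the Figure 3 caption.

```python
MASK = -1  # hypothetical id for the absorbing [MASK] token


def corrections_per_step(trajectory, mask_id=MASK):
    """For each transition step -> step+1, count positions that were already
    decoded (non-mask) and were replaced by a different non-mask token."""
    counts = []
    for prev, curr in zip(trajectory[:-1], trajectory[1:]):
        counts.append(sum(
            1 for a, b in zip(prev, curr)
            if a != mask_id and b != mask_id and a != b
        ))
    return counts


def correction_rate_per_step(counts):
    """Assumed definition: average number of corrections per denoising step."""
    return sum(counts) / len(counts)


def cumulative_correction_ratio(counts):
    """Fraction of all corrections in the run that occurred up to each step."""
    total = sum(counts)
    running, ratios = 0, []
    for c in counts:
        running += c
        ratios.append(running / total if total else 0.0)
    return ratios


# Toy 4-token trajectory over 3 denoising steps (token ids are arbitrary).
trajectory = [
    [MASK, MASK, MASK, MASK],
    [5, MASK, 7, MASK],
    [5, 9, 8, MASK],   # position 2: 7 -> 8 is a correction
    [6, 9, 8, 2],      # position 0: 5 -> 6 is a correction
]
counts = corrections_per_step(trajectory)
print(counts)                               # [0, 1, 1]
print(correction_rate_per_step(counts))     # 0.6666666666666666
print(cumulative_correction_ratio(counts))  # [0.0, 0.5, 1.0]
```

To mirror the averaging described in the captions, one would compute these quantities per sequence and then take the element-wise mean over the 128 generated sequences.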

Theorems & Definitions (18)

  • Proposition 3.1
  • proof
  • Lemma 3.2
  • proof
  • Remark 3.3
  • proof
  • Lemma B.1
  • proof
  • Proposition B.2
  • proof
  • ...and 8 more