Rethinking Token Prediction: Tree-Structured Diffusion Language Model

Zihao Wu, Haoming Yang, Juncheng Dong, Vahid Tarokh

Abstract

Discrete diffusion language models have emerged as a competitive alternative to auto-regressive language models, but training them efficiently under limited parameter and memory budgets remains challenging. Modern architectures are predominantly based on a full-vocabulary token prediction layer, which accounts for a substantial fraction of model parameters (e.g., more than 20% in small-scale DiT-style designs) and often dominates peak GPU memory usage. This leads to inefficient use of both parameters and memory under constrained training resources. To address this issue, we revisit the necessity of explicit full-vocabulary prediction, and instead exploit the inherent structure among tokens to build a tree-structured diffusion language model. Specifically, we model the diffusion process with intermediate latent states corresponding to a token's ancestor nodes in a pre-constructed vocabulary tree. This tree-structured factorization exponentially reduces the classification dimensionality, makes the prediction head negligible in size, and enables reallocation of parameters to deepen the attention blocks. Empirically, under the same parameter budget, our method reduces peak GPU memory usage by half while matching the perplexity performance of state-of-the-art discrete diffusion language models.
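
As a rough, illustrative calculation (the vocabulary size, hidden width, and branching factor below are hypothetical, not values reported in the paper), the following Python sketch compares the parameter count of a full-vocabulary prediction layer with per-level K-way child-prediction heads over a vocabulary tree:

    V, d = 32768, 768   # hypothetical vocabulary size and hidden width
    K = 32              # hypothetical branching factor of the vocabulary tree

    depth = 0           # tree height: smallest h with K**h >= V
    while K ** depth < V:
        depth += 1      # here depth == 3, since 32**3 == 32768

    full_head_params = d * V          # standard full-vocabulary prediction layer
    tree_head_params = depth * d * K  # one K-way child-prediction head per level

    print(f"full-vocabulary head: {full_head_params:,} parameters")  # 25,165,824
    print(f"tree-structured head: {tree_head_params:,} parameters")  # 73,728

Under these hypothetical settings the prediction head shrinks by more than two orders of magnitude, which is what allows the saved parameters to be reallocated to deeper attention blocks.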

Paper Structure

This paper contains 30 sections, 7 theorems, 46 equations, 6 figures, 4 tables, and 2 algorithms.

Key Result

Lemma 3.1

The forward process $q_{t|0}(z_t|x)$ on $t\in [t_h,t_{h+1}]$ is equivalent to first mapping $x$ to its ancestor node at height $h$ and then diffusing within that level.
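
A schematic way to write this factorization (our paraphrase of the statement above, not the paper's displayed equation; $\mathrm{anc}_h(x)$ denotes the height-$h$ ancestor of $x$ and $q^{(h)}$ the within-level forward kernel):

$$q_{t|0}(z_t \mid x) \;=\; q^{(h)}_{t \mid t_h}\bigl(z_t \,\big|\, \mathrm{anc}_h(x)\bigr), \qquad t \in [t_h, t_{h+1}].$$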

Figures (6)

  • Figure 1: Generation process (left): Our algorithm features in-level child prediction rather than standard token prediction, delivering substantial efficiency gains while achieving improved performance (see the sketch after this list). Token tree (right): Child prediction is enabled by a principled token tree, in which each state evolves strictly between adjacent levels.
  • Figure 2: Memory efficiency of TDLM. (Left) Throughput and peak memory comparison (peak memory is denoted by circle size) across diffusion language models at different scales. (Right) Validation negative ELBO and average FLOP/s of small-scale models during training using the full capacity of four 24 GB GPUs.
  • Figure 3: Left: Validation negative ELBO for various branching factors $K$ and cluster size ratios, plotted over training iterations. Solid lines indicate varying $K$ with a fixed cluster size ratio, while dashed lines indicate fixed $K$ with varying cluster size. Middle: Each bar indicates the cumulative validation negative ELBO obtained from a different tree construction, and each stack indicates the contribution of a particular tree height (the top stack corresponds to the root level). Right: Raw validation negative ELBO of each height level under different tree constructions.
  • Figure 4: Smoothed validation negative ELBO of different weight schedules. The weight schedules are plotted in the subfigures (Left: exponential; Right: linear).
  • Figure 5: Smoothed validation ELBO for joint modeling with different branching factors $K$ and neighborhood lengths $L$.
  • ...and 1 more figure
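
To make the in-level child prediction from Figure 1 concrete, below is a minimal, hypothetical sketch of root-to-leaf decoding for a single position (the names decode_token, encode, level_heads, and tree_children are our own illustration, not the paper's API): each level contributes one K-way child choice instead of a single |V|-way softmax, and the latent state at that position is updated between adjacent levels.

    import torch

    def decode_token(encode, level_heads, tree_children, seq, pos):
        """Illustrative root-to-leaf decoding of one position.

        encode:        callable mapping the latent node-id sequence to hidden
                       states of shape (seq_len, d) -- stands in for the backbone
        level_heads:   one K-way linear head per tree level (child prediction)
        tree_children: tree_children[node][k] is the id of the k-th child of node;
                       nodes at the deepest level correspond to token ids
        seq, pos:      current latent sequence (LongTensor) and the position to decode
        """
        node = 0  # start at the root of the vocabulary tree
        for head in level_heads:
            hidden = encode(seq)               # re-encode under the current latents
            logits = head(hidden[pos])         # K child logits, not |V| token logits
            child = int(torch.argmax(logits))  # greedy in-level child choice
            node = tree_children[node][child]  # descend one level in the tree
            seq = seq.clone()
            seq[pos] = node                    # state evolves between adjacent levels
        return node                            # a leaf node identifies the token

In an actual diffusion sampler the child choice would be drawn from the reverse kernel rather than taken greedily; the sketch only illustrates how the tree indexing replaces full-vocabulary prediction.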

Theorems & Definitions (11)

  • Lemma 3.1
  • Proposition 3.2
  • Proof of Proposition 3.2
  • Lemma 3.3: Adapted from Proposition H.4 of von Rütte et al. (2025)
  • Theorem 3.4: Closed Form In-Level CT-ELBO of TDLM
  • Proof of Theorem 3.4
  • Corollary 3.5: Closed-form cross-level ELBO of TDLM
  • Theorem A.1: Closed Form In-Level CT-ELBO of TDLM
  • Proof of Theorem A.1
  • Proposition A.2
  • ...and 1 more