MDM-Prime-v2: Binary Encoding and Index Shuffling Enable Compute-optimal Scaling of Diffusion Language Models

Chen-Hao Chao, Wei-Fang Sun, Junwei Qua, Chun-Yi Lee, Rahul G. Krishnan

Abstract

Masked diffusion models (MDM) exhibit superior generalization when learned using a Partial masking scheme (Prime). This approach converts tokens into sub-tokens and models the diffusion process at the sub-token level. We identify two limitations of the MDM-Prime framework. First, we lack tools to guide the hyperparameter choice of the token granularity in the subtokenizer. Second, we find that the function form of the subtokenizer significantly degrades likelihood estimation when paired with commonly used Byte-Pair-Encoding (BPE) tokenizers. To address these limitations, we study the tightness of the variational bound in MDM-Prime and develop MDM-Prime-v2, a masked diffusion language model which incorporates Binary Encoding and Index Shuffling. Our scaling analysis reveals that MDM-Prime-v2 is 21.8$\times$ more compute-efficient than autoregressive models (ARM). In compute-optimal comparisons, MDM-Prime-v2 achieves 7.77 perplexity on OpenWebText, outperforming ARM (12.99), MDM (18.94), and MDM-Prime (13.41). When extending the model size to 1.1B parameters, our model further demonstrates superior zero-shot accuracy on various commonsense reasoning tasks.

Paper Structure

This paper contains 35 sections, 13 theorems, 47 equations, 25 figures, and 11 tables.

Key Result

Proposition 3.1

Let $p$ and $p_\ell$ denote the MDM and MDM-Prime models. Let $\ell_1,\ell_2$ be token granularities satisfying $1 < \ell_1 < \ell_2$ and $\frac{\ell_2}{\ell_1}\in\mathbb{N}$. The following inequalities hold:

Figures (25)

  • Figure 1: Overview of MDM, MDM-Prime, and the MDM-Prime-v2 enhancements. $\ell$ is the token granularity in base-$b$ encoding and $V$ is the vocabulary size. $y_0^{i,j}$ is a sub-token with positional indices $i$ and $j$. MDM-Prime-v2 incorporates two techniques: Technique 1 sets $\ell=\lceil\log_2 V\rceil$ for $\bm{f}_\ell$, which performs binary encoding. Technique 2 employs an index shuffling operation to encode high-entropy sub-tokens. The gray cell intensity in the lower section represents the sub-token probability when $y_0^{i,j}=0$ or $y_0^{i,j}=1$.
  • Figure 2: (a) Token probability and (b) cumulative distribution function (CDF) over token indices of the GPT-2 tokenizer (Radford et al., 2019), evaluated on the C4 dataset (Raffel et al., 2023). The left and right subplots of (a) and (b) show setups without and with index shuffling, respectively. The gray dashed line in (b) represents the CDF of a uniform token distribution.
  • Figure 3: Illustrative examples of the index shuffling operation proposed in Technique 2. In this plot, $V=8$, $\ell=3$, and $b=2$. Color intensity represents probability values, where darker blue denotes higher probability. (a) Marginal distributions: Marginalization is performed by summing entries corresponding to sub-token indices 0 (gray) and 1 (white), respectively. Sub-tokens encoded from standard token indices (i.e., 'w/o Shuff.') exhibit low entropy, while the shuffling operation results in sub-tokens with higher entropy. A detailed numerical example is provided in Fig. \ref{fig:apx:shuffle_example_numerical} in the Appendix. (b) Conditional (predictive) distribution: The distribution $q_\text{data\_y}(\bm{y}_0|\bm{y}_t)$ exhibits higher certainty under index shuffling, which results in improved likelihood estimation.
  • Figure 4: The NLL of MDM ($\ell=1$) and MDM-Prime ($\ell>1$) under three setups: 'w/o Shuff.', 'w/ Shuff. (25%)', and 'w/ Shuff.'. Here $\lceil\log_2 V\rceil=16$ is the maximum value of $\ell$. The red dashed line represents the NLL of the compute-optimal ARM. All models are trained using $10^{19}$ FLOPs. The experiments are conducted on C4.
  • Figure 5: The (a) loss envelopes, (b) isoFLOP curves, and (c) isoloss curves of ARM, MDM, and MDM-Prime-v2. In subplots (a), the number of parameters ($N$) ranges from 14M (purple) to 3.4B (yellow). In subplots (b), the compute budget ($C$) ranges from $3 \times 10^{18}$ FLOPs (blue) to $3 \times 10^{20}$ FLOPs (yellow). In subplots (c), the curves represent loss contours, the solid blue line denotes the efficient frontier, and the red dashed line represents the $2.89 \times 10^{20}$ FLOPs setup adopted in Section \ref{sec:experiment:owt}. Triangular markers represent the configuration used by Sahoo et al. (2024), while circular markers denote the compute-optimal setup.
  • ...and 20 more figures
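
The captions of Figures 1 and 3 describe binary encoding with $\ell=\lceil\log_2 V\rceil$ and an index shuffling step that changes the entropy of the sub-token marginals. The sketch below illustrates both ideas on the toy setting of Figure 3 ($V=8$, $\ell=3$, $b=2$). It is not the authors' implementation: the helper names `encode` and `bit_entropies`, the Zipf-like token distribution, and the use of a uniformly random permutation as the shuffle are all assumptions made for illustration.

```python
import numpy as np

def encode(token_ids, ell, base=2):
    """Map each token index to `ell` base-`base` digits, most significant first."""
    token_ids = np.asarray(token_ids)
    powers = base ** np.arange(ell - 1, -1, -1)
    return (token_ids[..., None] // powers) % base

def bit_entropies(probs, codes):
    """Entropy in bits of each binary sub-token's marginal distribution."""
    p1 = (probs[:, None] * codes).sum(axis=0)      # P(sub-token j == 1)
    p = np.clip(np.stack([1.0 - p1, p1]), 1e-12, 1.0)
    return -(p * np.log2(p)).sum(axis=0)

V = 8
# (V - 1).bit_length() equals ceil(log2(V)) for V >= 2 and avoids float rounding.
ell = (V - 1).bit_length()

# Hypothetical Zipf-like token distribution: low indices are the most frequent.
probs = 1.0 / np.arange(1, V + 1)
probs /= probs.sum()

# Without shuffling: token i is encoded directly from its index i.
codes = encode(np.arange(V), ell)

# With shuffling (assumed here to be a fixed random permutation of indices):
# token i is encoded from perm[i] instead of i.
rng = np.random.default_rng(0)
perm = rng.permutation(V)
shuffled_codes = encode(perm, ell)

print("l =", ell)
print("entropy per sub-token, w/o shuffle:", bit_entropies(probs, codes))
print("entropy per sub-token, w/  shuffle:", bit_entropies(probs, shuffled_codes))
```

Setting `V = 50257` (the GPT-2 vocabulary size) gives `ell = 16`, which matches the maximum $\ell$ quoted in the Figure 4 caption. Because the permutation is a bijection on token indices, the mapping stays invertible; only the assignment of tokens to bit patterns changes.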

Theorems & Definitions (24)

  • Proposition 3.1
  • Proposition 3.2
  • Proposition 3.3
  • Proposition 3.4
  • Theorem 1.1
  • proof
  • Lemma 1.2
  • proof
  • Lemma 1.3
  • proof
  • ...and 14 more