Self-Speculative Masked Diffusions

Andrew Campbell, Valentin De Bortoli, Jiaxin Shi, Arnaud Doucet

TL;DR

Self-Speculative Masked Diffusions (SSMD) address the high compute cost of discrete-data diffusion by replacing factorized token predictions with a non-factorized, model-wide distribution accessed through a hybrid non-causal/causal transformer. The method uses speculative sampling to draft tokens with a lightweight head and validate them with a stronger target, yielding a tractable log-likelihood bound and an $O(D^2)$ computation for the joint distribution. Empirically, SSMD achieves roughly a 2x reduction in function evaluations on text and protein data while preserving sample quality, demonstrated on Text8, OpenWebText, and UniRef50. This approach enables faster, scalable discrete-data generation and can potentially be paired with other inference speedups to extend to larger models.

Abstract

We present self-speculative masked diffusions, a new class of masked diffusion generative models for discrete data that require significantly fewer function evaluations to generate samples. Standard masked diffusion models predict factorized logits over currently masked positions. A number of masked positions are then sampled; however, the factorization approximation means that sampling too many positions at once leads to poor sample quality. As a result, many simulation steps, and therefore many neural network function evaluations, are required to generate high-quality data. We reduce the computational burden by generating non-factorized predictions over masked positions. This is achieved by modifying the final transformer attention mask from non-causal to causal, enabling draft token generation and parallel validation via a novel, model-integrated speculative sampling mechanism. This results in a non-factorized predictive distribution over masked positions in a single forward pass. We apply our method to GPT2-scale text modelling and protein sequence generation, finding that we can achieve a ~2x reduction in the required number of network forward passes relative to standard masked diffusion models.
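
The draft-then-validate idea in the abstract builds on speculative sampling. The sketch below is a generic single-token accept/reject step in plain Python, not the paper's algorithm: the names `speculative_step` and `residual_distribution` and the toy distributions are assumptions made for illustration. A token is proposed from the draft distribution, accepted with probability min(1, target/draft), and otherwise resampled from the normalised positive part of (target - draft), so the returned token follows the target distribution.

```python
import random


def residual_distribution(target, draft):
    """Normalised max(0, target - draft), used to resample after a rejection."""
    residual = [max(0.0, p - q) for p, q in zip(target, draft)]
    total = sum(residual)
    # If target == draft the residual is all zeros; a rejection cannot occur in
    # that case, so falling back to the target is only a safeguard.
    return [r / total for r in residual] if total > 0 else list(target)


def speculative_step(target, draft, rng):
    """Return (token, accepted), where token is distributed as `target`.

    `target` and `draft` are probability lists over the same toy vocabulary.
    """
    vocab = list(range(len(draft)))
    # Propose a token from the cheap draft distribution.
    token = rng.choices(vocab, weights=draft)[0]
    # Accept with probability min(1, target / draft) at the proposed token.
    accept_prob = min(1.0, target[token] / draft[token])
    if rng.random() < accept_prob:
        return token, True
    # On rejection, resample from the residual distribution.
    weights = residual_distribution(target, draft)
    return rng.choices(vocab, weights=weights)[0], False


# Toy example (hypothetical numbers): the draft over-weights token 2.
rng = random.Random(0)
target = [0.6, 0.3, 0.1]
draft = [0.3, 0.3, 0.4]
token, accepted = speculative_step(target, draft, rng)
```

For these toy distributions the expected acceptance rate is the sum of the element-wise minima, 0.3 + 0.3 + 0.1 = 0.7. The closer the draft is to the target, the more drafted tokens survive validation, which is where a reduction in forward passes would come from.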

Paper Structure

This paper contains 27 sections, 3 theorems, 54 equations, 7 figures, 3 tables, 3 algorithms.

Key Result

Proposition 3.1

Consider the sampling scheme defined by Algorithm alg:mask_speculative_sampling. For a given ordering $\sigma$, let $A^{\sigma(d)}$ denote the event that the token in position $\sigma(d)$ was accepted and let $R^{\sigma(d)}$ denote the event it was rejected and resampled. The distribution of samples $p_{\theta, \phi}(\bm{x}^{\sigma(d+1:D)}, A^{\sigma(d+1:D)} | \bm{x}^{\sigma(1:d)}, R^{\sigma(d)})$

Figures (7)

  • Figure 1: Hybrid non-causal/causal transformer during training on the sentence "Speculation is like hazarding a guess", which is partially corrupted in the ordering given by $\sigma(1:6)=[6,5,2,4,3,1]$. We always mask the final elements of the ordering, here the final $3$ elements, $[4,3,1]$. At inference time, the causal transformer uses draft tokens rather than real tokens.
  • Figure 2: Causal ($\overset{\rightarrow}{p}_{\theta, \phi}$) and non-causal ($\overset{\leftrightarrow}{p}_\theta$) training losses on the text8 dataset.
  • Figure 3: Spelling accuracy versus number of function evaluations (NFE) on the Text8 dataset for our speculative approach and mask diffusion.
  • Figure 4: pLDDTs versus NFE for mask diffusion and our speculative method. The mean pLDDT is computed over 512 samples with the standard error of the mean represented by the shading.
  • Figure 5: The attention mechanism based on the example described in Figure 1. From left to right: any-to-any attention masks, standard left-to-right attention masks for AR models, and a causal attention mask applied to the ordering from Figure 1. Rows correspond to the token making the query (attending). Columns correspond to the token providing the key (being attended to).
  • ...and 2 more figures
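
The Figure 5 caption describes a causal attention mask applied to an ordering rather than to left-to-right positions. A minimal sketch of how such a mask could be built is given below, using the ordering $\sigma(1:6)=[6,5,2,4,3,1]$ from Figure 1. The function name `permuted_causal_mask` is hypothetical, and the choice to let each position attend to itself is an assumption; the paper's exact convention (for example, any shift between inputs and outputs) is not stated in this summary. As in the caption, rows are queries and columns are keys.

```python
def permuted_causal_mask(sigma):
    """Build a causal mask over an ordering `sigma` of 1-indexed positions.

    mask[q][k] is True when position q + 1 may attend to position k + 1,
    i.e. when k + 1 appears no later than q + 1 in the ordering.
    Assumption: a position may attend to itself.
    """
    size = len(sigma)
    # rank[pos] is the place of `pos` in the ordering (0 = first).
    rank = {pos: i for i, pos in enumerate(sigma)}
    return [
        [rank[k + 1] <= rank[q + 1] for k in range(size)]
        for q in range(size)
    ]


def any_to_any_mask(size):
    """Non-causal mask: every position attends to every position."""
    return [[True] * size for _ in range(size)]


sigma = [6, 5, 2, 4, 3, 1]            # ordering from Figure 1
mask = permuted_causal_mask(sigma)
left_to_right = permuted_causal_mask([1, 2, 3, 4, 5, 6])  # standard AR mask

# Print the permuted mask with 1 = may attend, 0 = blocked.
for row in mask:
    print(" ".join("1" if allowed else "0" for allowed in row))
```

With the identity ordering the same function gives the usual lower-triangular mask of an autoregressive model, so the left-to-right case is a special case of the ordered one.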

Theorems & Definitions (5)

  • Proposition 3.1
  • Lemma C.1
  • proof
  • Proposition C.2
  • proof