Self-Speculative Masked Diffusions
Andrew Campbell, Valentin De Bortoli, Jiaxin Shi, Arnaud Doucet
TL;DR
Self-Speculative Masked Diffusions (SSMD) address the high compute cost of discrete-data diffusion by replacing factorized token predictions with a non-factorized joint distribution over masked positions, produced by a hybrid non-causal/causal transformer. The method uses speculative sampling to draft tokens with a lightweight head and validate them with a stronger target, yielding a tractable log-likelihood bound and an $O(D^2)$ computation for the joint distribution. Empirically, SSMD achieves roughly a 2x reduction in function evaluations on text and protein data while preserving sample quality, as demonstrated on Text8, OpenWebText, and UniRef50. The approach enables faster discrete-data generation and could potentially be combined with other inference speedups when scaling to larger models.
Abstract
We present self-speculative masked diffusions, a new class of masked diffusion generative models for discrete data that require significantly fewer function evaluations to generate samples. Standard masked diffusion models predict factorized logits over currently masked positions. A number of masked positions are then sampled; however, because of the factorization approximation, sampling too many positions at once leads to poor sample quality. As a result, many simulation steps, and therefore many neural network function evaluations, are required to generate high-quality data. We reduce the computational burden by generating non-factorized predictions over masked positions. This is achieved by modifying the final transformer attention mask from non-causal to causal, enabling draft token generation and parallel validation via a novel, model-integrated speculative sampling mechanism. This results in a non-factorized predictive distribution over masked positions in a single forward pass. We apply our method to GPT2-scale text modelling and protein sequence generation, finding that we can achieve a ~2x reduction in the required number of network forward passes relative to standard masked diffusion models.
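To make the draft-and-validate idea concrete, the following is a minimal sketch of the generic speculative sampling accept/reject rule for a single token. It is not the paper's model-integrated mechanism: the function name `speculative_step` and the toy distributions are assumptions for illustration, with `draft_probs` standing in for a cheap draft distribution and `target_probs` for the distribution used to validate it.

```python
import random


def speculative_step(draft_probs, target_probs, rng):
    """One accept/reject step of generic speculative sampling (illustrative).

    draft_probs and target_probs are probability lists over the same
    vocabulary. Returns (token, accepted).
    """
    vocab = range(len(draft_probs))
    # Draw a draft token from the cheap draft distribution q.
    token = rng.choices(vocab, weights=draft_probs, k=1)[0]
    # Accept with probability min(1, p(token) / q(token)).
    accept_prob = min(1.0, target_probs[token] / draft_probs[token])
    if rng.random() < accept_prob:
        return token, True
    # On rejection, resample from the residual max(p - q, 0);
    # rng.choices normalizes the weights internally.
    residual = [max(p - q, 0.0) for p, q in zip(target_probs, draft_probs)]
    token = rng.choices(vocab, weights=residual, k=1)[0]
    return token, False


def empirical_check(draft_probs, target_probs, n, seed=0):
    """Return (token frequencies, acceptance rate) over n independent steps."""
    rng = random.Random(seed)
    counts = [0] * len(target_probs)
    accepted = 0
    for _ in range(n):
        token, ok = speculative_step(draft_probs, target_probs, rng)
        counts[token] += 1
        accepted += ok
    return [c / n for c in counts], accepted / n
```

Under this rule the returned token is distributed according to `target_probs` regardless of the draft, and the acceptance rate equals the sum over tokens of min(p, q). A draft that matches the target closely is therefore accepted more often, which is where the savings in forward passes come from.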
