Multi-Token Prediction via Self-Distillation
John Kirchenbauer, Abhimanyu Hans, Brian Bartoldson, Micah Goldblum, Ashwinee Panda, Tom Goldstein
TL;DR
The paper tackles slow autoregressive decoding by turning a pretrained language model (LM) into a fast multi-token predictor through online, on-policy distillation guided by a frozen teacher. By predicting blocks of tokens and scoring them with a strong next-token-prediction (NTP) critic, the approach learns coherent joint token sequences without modifying the base inference code. Empirical results on GSM8K and related benchmarks show 2–5x speedups with modest accuracy loss, and ablations identify the design choices (hard teacher, randomized k, causal masking) that matter most for performance. The method offers a practical, training-driven complement to speculative decoding, with room for further optimization.
Abstract
Existing techniques for accelerating language model inference, such as speculative decoding, require training auxiliary speculator models and building and deploying complex inference pipelines. We consider a new approach for converting a pretrained autoregressive language model from a slow single next-token prediction model into a fast standalone multi-token prediction model using a simple online distillation objective. The final model retains the exact same implementation as the pretrained initial checkpoint and is deployable without the addition of any auxiliary verifier or other specialized inference code. On GSM8K, our method produces models that can decode more than $3\times$ faster on average with a $<5\%$ drop in accuracy relative to single-token decoding performance.
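To make the speedup figure concrete, the following is a minimal back-of-the-envelope sketch, not the paper's implementation. It assumes a hypothetical model that emits a fixed block of `k` tokens per forward pass and that every forward pass costs the same regardless of `k`; the function names and the example values are illustrative only.

```python
import math


def forward_passes(num_tokens: int, k: int) -> int:
    """Forward passes needed to emit num_tokens when each pass yields k tokens.

    Assumption (not from the paper): k is fixed for the whole generation.
    """
    if num_tokens < 0 or k < 1:
        raise ValueError("num_tokens must be >= 0 and k must be >= 1")
    return math.ceil(num_tokens / k)


def ideal_speedup(num_tokens: int, k: int) -> float:
    """Ratio of single-token passes to k-token passes.

    Assumption (not from the paper): every forward pass has equal cost,
    so the ratio of pass counts is an upper bound on wall-clock speedup.
    """
    if num_tokens < 1:
        raise ValueError("num_tokens must be >= 1")
    return forward_passes(num_tokens, 1) / forward_passes(num_tokens, k)


if __name__ == "__main__":
    # Hypothetical 256-token response decoded 4 tokens at a time.
    print(forward_passes(256, 4))
    print(ideal_speedup(256, 4))
```

Under these assumptions the bound is at most `k`, and it falls slightly below `k` when the response length is not a multiple of `k`, because the final block is only partly used. Measured speedups depend on details this sketch ignores, such as per-pass overhead and any variation in block size.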
