
Multi-Token Prediction via Self-Distillation

John Kirchenbauer, Abhimanyu Hans, Brian Bartoldson, Micah Goldblum, Ashwinee Panda, Tom Goldstein

TL;DR

The paper tackles slow autoregressive decoding by turning a pretrained language model (LM) into a fast multi-token prediction (MTP) model through online, on-policy distillation guided by a frozen teacher. The model predicts blocks of tokens, and a strong next-token prediction (NTP) critic scores them. This lets the approach learn coherent joint token sequences without modifying the base model's inference code. Empirical results on GSM8K and related benchmarks show 2–5× speedups with modest accuracy loss. Ablations identify the design choices that matter most: a hard teacher, randomized $k$, and causal masking. The method offers a practical, training-only complement to speculative decoding, with potential for broader deployment and further work on decoding acceleration.

Abstract

Existing techniques for accelerating language model inference, such as speculative decoding, require training auxiliary speculator models and building and deploying complex inference pipelines. We consider a new approach for converting a pretrained autoregressive language model from a slow single-next-token prediction model into a fast standalone multi-token prediction model using a simple online distillation objective. The final model retains exactly the same implementation as the pretrained initial checkpoint and is deployable without any auxiliary verifier or other specialized inference code. On GSM8K, our method produces models that can decode more than $3\times$ faster on average with a $<5\%$ drop in accuracy relative to single-token decoding.

Paper Structure

This paper contains 46 sections, 8 equations, 11 figures, and 5 tables.

Figures (11)

  • Figure 1: Example response to a GSM8K test question from our Qwen3-4B-Instruct-2507-based multi-token prediction model. Decoding uses a confidence-adaptive strategy with a threshold of 90%. Each colored block is the chunk of tokens produced in a single forward pass and is annotated with its size in tokens. Chunk sizes range from 1 to 7 in this example, and the average chunk size over the entire generation is 3.04.
  • Figure 2: Visual depiction of how a piece of training text is tokenized and masked. This includes replicating the $k$ ground truth tokens for each MTP region in which prediction is performed, along with the required position embedding adjustments. In this example, the sequence length is 18, the $k$ value is 3, and the number of MTP regions is also 3. The targets row (TgtID) shows ground truth tokens from the dataset. Under our proposed online training objective, however, the targets at predicted positions (Pred) come from the teacher model's feedback, not from the ground truth data. The masking style and input replication shown materialize many different MTP problems within a single sequence in parallel, which increases training efficiency.
  • Figure 3: Visualization of an attention mask with rolling offsets; offset 0 on the left and -2 on the right for a fixed $k=3$. Randomized offsets enable supervision on problems with many different prefix lengths during the same training run.
  • Figure 4: Visualization of an attention mask with variable $k$ masking showing $k=3$ and $k=5$ with a fixed offset of 0. Randomized $k$ values enable supervision on MTP problems of many different sizes during the same training run.
  • Figure 5: Performance of (Left) our L3.1-8B-Magpie-based MTP LM and (Right) our Qwen3-4B-Inst-2507 MTP LM on the GSM8K benchmark after $\sim$100k training steps. The tradeoff is visualized by plotting the effective $k$ value, or "Acceleration Factor", against benchmark accuracy. More detailed plots of accuracy and acceleration as a function of training steps for both models are provided in separate figures in the paper. We observe that the adaptive decoding strategies achieve Pareto-optimal tradeoffs between generation speed and response quality for both models.
  • ...and 6 more figures
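Figure 1 describes confidence-adaptive decoding with a 90% threshold, where each forward pass emits a chunk of 1 to 7 tokens. Figure 5 plots an "Acceleration Factor". The sketch below illustrates one plausible acceptance rule, and it rests on our own assumptions rather than the paper's exact procedure. Each pass proposes a block of tokens with per-position confidences. We keep the longest prefix whose confidences clear the threshold, and we always keep the first (ordinary next-token) prediction. We then presumably read the mean chunk size as the acceleration factor, that is, tokens emitted per forward pass. `accept_chunk`, `acceleration_factor`, and the per-token confidence definition are hypothetical. A cumulative-product criterion would be an equally reasonable alternative.

```python
from typing import Sequence


def accept_chunk(confidences: Sequence[float], threshold: float = 0.9) -> int:
    """Number of proposed tokens to keep from one forward pass.

    Hypothetical rule: confidences[i] is the model's probability for the token
    it proposes at block position i. Keep the longest prefix whose confidences
    are all >= threshold, but always keep the first token (an ordinary
    next-token prediction) so decoding never stalls.
    """
    if not confidences:
        raise ValueError("need at least one proposed token")
    accepted = 1  # first token is always emitted
    for c in confidences[1:]:
        if c < threshold:
            break  # stop at the first low-confidence position
        accepted += 1
    return accepted


def acceleration_factor(chunk_sizes: Sequence[int]) -> float:
    """Mean tokens emitted per forward pass; plain NTP decoding gives 1.0."""
    return sum(chunk_sizes) / len(chunk_sizes)


# Toy per-pass confidences (made-up numbers, not results from the paper).
passes = [
    [0.99, 0.95, 0.97, 0.50],
    [0.80, 0.60, 0.99],
    [0.97, 0.93, 0.91, 0.92, 0.40],
]
sizes = [accept_chunk(p) for p in passes]
print(sizes, round(acceleration_factor(sizes), 2))
```

Under this rule, raising the threshold shrinks chunks toward single-token decoding. Lowering it lengthens chunks at the risk of accepting worse tokens. This is the speed/quality tradeoff that Figure 5 visualizes.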
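Figures 2–4 describe a training layout in which the tokens of each MTP region are replicated after the sequence, given adjusted position embeddings, and masked so that many MTP problems are trained in parallel. The sketch below builds such an attention mask and position ids under our own assumptions about the layout. The appended copies follow the original sequence in order. Each copy attends to the clean prefix before its region, plus earlier positions (and itself) inside its own block, which gives causal masking within the block. Each copy also reuses the position id of the token it replicates. `build_mtp_mask` and `region_starts` are hypothetical names. The sketch does not model which token ids fill the copies or how the teacher supplies targets.

```python
from typing import List, Sequence, Tuple


def build_mtp_mask(
    seq_len: int, region_starts: Sequence[int], k: int
) -> Tuple[List[List[bool]], List[int]]:
    """Hypothetical layout: original tokens, then one appended copy of each
    k-token MTP region. mask[q][kv] is True where query q may attend to kv."""
    for s in region_starts:
        if s < 1 or s + k > seq_len:
            raise ValueError(f"region starting at {s} does not fit")
    total = seq_len + len(region_starts) * k
    mask = [[False] * total for _ in range(total)]
    pos = list(range(seq_len))  # original tokens keep positions 0..seq_len-1

    # Original sequence: ordinary causal attention; it never sees the copies.
    for q in range(seq_len):
        for kv in range(q + 1):
            mask[q][kv] = True

    for r, s in enumerate(region_starts):
        base = seq_len + r * k  # index of the first appended token of region r
        for i in range(k):
            q = base + i
            pos.append(s + i)  # reuse the position of the replicated token
            for kv in range(s):  # clean prefix preceding the region
                mask[q][kv] = True
            for j in range(i + 1):  # causal attention within the block
                mask[q][base + j] = True
    return mask, pos


mask, pos = build_mtp_mask(seq_len=6, region_starts=[2], k=2)
for row in mask:
    print("".join("x" if m else "." for m in row))
print(pos)
```

In this sketch, the randomized $k$ of Figure 4 and the rolling offsets of Figure 3 would correspond to varying `k` and shifting `region_starts` from one training batch to the next. The exact sampling scheme is not specified here.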