On multi-token prediction for efficient LLM inference

Somesh Mehra, Javier Alonso Garcia, Lukas Mauch

TL;DR

The paper investigates accelerating LLM inference through multi-token prediction (MTP) in models pretrained for next-token prediction (NTP). It demonstrates that MTP is achievable in NTP models by marginalizing over intermediate token probabilities, formalized as $p(\mathcal{X}_{t:t+K}|\mathcal{X}_{\le t}; \theta)$; performance improves with model scale and contextual determinism. It also shows that attaching MTP heads to a frozen backbone is hampered by early hidden-layer specialization, though joint training and strategies like weighted hidden states (WHS) can mitigate the gap without fully closing it. The findings suggest that marginalization-based MTP remains a strong baseline, with MTP pretraining or advanced adaptation methods offering avenues for stronger, more practical acceleration of inference.

Abstract

We systematically investigate multi-token prediction (MTP) capabilities within LLMs pre-trained for next-token prediction (NTP). We first show that such models inherently possess MTP capabilities via numerical marginalization over intermediate token probabilities, though performance is data-dependent and improves with model scale. Furthermore, we explore the challenges of integrating MTP heads into frozen LLMs and find that their hidden layers are strongly specialized for NTP, making adaptation non-trivial. Finally, we show that while joint training of MTP heads with the backbone improves performance, it cannot fully overcome this barrier, prompting further research in this direction. Our findings provide a deeper understanding of MTP applied to pretrained LLMs, informing strategies for accelerating inference through parallel token prediction.

Paper Structure

This paper contains 12 sections, 6 equations, 5 figures, 2 tables.

Figures (5)

  • Figure 1: Top-5 accuracy of MTP using marginalization, for open-ended generation and translation across model families and sizes. MTP capabilities grow with model size and are data dependent.
  • Figure 2: KL divergence between intermediate and final token probabilities. For large models, intermediate layers reach a representation close to the final output relatively early, indicating strong specialization to NTP.
  • Figure 3: Entropy of token probabilities over an example translated sequence for various Pythia models. Larger models tend to be more confident in NTP across the majority of the sequence.
  • Figure 4: Average number of considered $x_{t+1}$ tokens, i.e. tokens in the top 0.99 (solid) of the predicted probability distribution, for each model size and task during the marginalization analysis. This number consistently decreases with model size, and comparison with the number of tokens in the top 0.9 (dashed) of the distribution further indicates that many of the considered tokens have very low probability.
  • Figure 5: a) Overview of the MTP model architecture used for experimentation. The model consists of an LLM backbone, $N$ independent heads that take the final hidden states from the backbone as input, and a shared unembedding applied to each head. Thus, for a given input sequence, we predict $N$ token probabilities with a single forward pass through the backbone. b) When using weighted hidden states, each head takes as input the weighted sum of all intermediate hidden states from the backbone rather than only its final output. Each head learns its own weight vector, which is normalized before taking the weighted sum. c) Intermediate token probabilities are calculated by taking the hidden states from each layer in the backbone and applying the shared unembedding to obtain token probabilities.
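
The weighted hidden states input described in Figure 5 b) can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' implementation: the caption only says each head's weight vector is "normalized", so the softmax used here is an assumption, and `softmax` and `weighted_hidden_state` are hypothetical helper names operating on plain Python lists rather than real model tensors.

```python
import math
from typing import List, Sequence

def softmax(xs: Sequence[float]) -> List[float]:
    """Numerically stable softmax over a list of raw weights."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def weighted_hidden_state(
    hidden_states: Sequence[Sequence[float]],
    raw_weights: Sequence[float],
) -> List[float]:
    """Mix per-layer hidden states (one vector per layer) into one head input.

    Assumption: the head's learned weights are normalized with a softmax.
    """
    if len(hidden_states) != len(raw_weights):
        raise ValueError("need exactly one weight per layer")
    weights = softmax(raw_weights)
    dim = len(hidden_states[0])
    return [
        sum(w * layer[i] for w, layer in zip(weights, hidden_states))
        for i in range(dim)
    ]

# Three layers, hidden size 2, for a single token position.
layers = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0]]

uniform_mix = weighted_hidden_state(layers, [0.0, 0.0, 0.0])      # plain average of layers
last_layer_mix = weighted_hidden_state(layers, [-50.0, -50.0, 50.0])  # close to the final layer
```

Equal raw weights reduce to an average over layers, while a weight vector concentrated on the last layer recovers the standard setup in which each head sees only the final hidden state. In a trained model, each head would hold its own `raw_weights` as learnable parameters.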