On multi-token prediction for efficient LLM inference
Somesh Mehra, Javier Alonso Garcia, Lukas Mauch
TL;DR
The paper investigates accelerating LLM inference through multi-token prediction (MTP) in models pretrained for next-token prediction (NTP). It shows that NTP models already support MTP by marginalizing over the probabilities of intermediate tokens to obtain $p(\mathcal{X}_{t+1:t+K} \mid \mathcal{X}_{\le t}; \theta)$. Performance improves with model scale and with how deterministic the context is. It also shows that attaching MTP heads to a frozen backbone is hampered by the early specialization of its hidden layers for NTP. Joint training and strategies such as weighted hidden states (WHS) narrow this gap but do not close it. The findings suggest that marginalization-based MTP remains a strong baseline, while MTP pretraining or more advanced adaptation methods offer avenues toward stronger, more practical inference acceleration.
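For reference, the standard chain-rule identities behind this marginalization are restated below (our restatement, not copied from the paper; $\mathcal{X}_{t+k}$ denotes the token at position $t+k$ and $\mathcal{V}$ the vocabulary). The joint over the next $K$ tokens factorizes into NTP conditionals, and the marginal for a single future position sums that product over all intermediate tokens:

$$p(\mathcal{X}_{t+1:t+K} \mid \mathcal{X}_{\le t}; \theta) = \prod_{k=1}^{K} p(\mathcal{X}_{t+k} \mid \mathcal{X}_{\le t+k-1}; \theta)$$

$$p(\mathcal{X}_{t+K} \mid \mathcal{X}_{\le t}; \theta) = \sum_{\mathcal{X}_{t+1}, \dots, \mathcal{X}_{t+K-1} \in \mathcal{V}} \prod_{k=1}^{K} p(\mathcal{X}_{t+k} \mid \mathcal{X}_{\le t+k-1}; \theta)$$

The number of intermediate paths in the second sum grows as $|\mathcal{V}|^{K-1}$, which makes exact marginalization expensive beyond small $K$.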
Abstract
We systematically investigate multi-token prediction (MTP) capabilities within LLMs pretrained for next-token prediction (NTP). We first show that such models inherently possess MTP capabilities via numerical marginalization over intermediate token probabilities, though performance is data-dependent and improves with model scale. Furthermore, we explore the challenges of integrating MTP heads into frozen LLMs and find that their hidden layers are strongly specialized for NTP, making adaptation non-trivial. Finally, we show that while joint training of MTP heads with the backbone improves performance, it cannot fully overcome this barrier, motivating further research in this direction. Our findings provide a deeper understanding of MTP applied to pretrained LLMs, informing strategies for accelerating inference through parallel token prediction.
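The numerical marginalization mentioned above can be illustrated with a brute-force sketch. Everything in it is hypothetical: `next_token_probs` is a toy bigram stand-in for a pretrained NTP LLM over a 4-token vocabulary, and `marginal_k_ahead` and `joint_prob` are illustrative names, not the authors' code. Exact marginalization needs on the order of $|\mathcal{V}|^{k-1}$ model calls, so this only shows the arithmetic on a toy vocabulary and is not an efficient implementation.

```python
import numpy as np

VOCAB_SIZE = 4  # hypothetical toy vocabulary size

# Hypothetical toy "NTP model": logits depend only on the last context token.
_rng = np.random.default_rng(0)
_W = _rng.normal(size=(VOCAB_SIZE, VOCAB_SIZE))


def next_token_probs(context: tuple[int, ...]) -> np.ndarray:
    """Return p(x_{t+1} | x_{<=t}) as a softmax over the toy vocabulary."""
    logits = _W[context[-1]]
    e = np.exp(logits - logits.max())  # subtract max for numerical stability
    return e / e.sum()


def joint_prob(context: tuple[int, ...], continuation: tuple[int, ...]) -> float:
    """Chain rule: p(x_{t+1:t+K} | x_{<=t}) as a product of NTP conditionals."""
    p = 1.0
    ctx = tuple(context)
    for tok in continuation:
        p *= float(next_token_probs(ctx)[tok])
        ctx = ctx + (tok,)
    return p


def marginal_k_ahead(context: tuple[int, ...], k: int) -> np.ndarray:
    """p(x_{t+k} | x_{<=t}), summing over every intermediate token x_{t+1..t+k-1}."""
    if k == 1:
        return next_token_probs(context)
    out = np.zeros(VOCAB_SIZE)
    # Weight each (k-1)-step-ahead marginal by the probability of the next token.
    for v, p_v in enumerate(next_token_probs(context)):
        out += p_v * marginal_k_ahead(context + (v,), k - 1)
    return out


ctx = (1, 2)
m2 = marginal_k_ahead(ctx, 2)
print(round(float(m2.sum()), 6))  # 1.0
```

Note that taking the per-position argmax of these marginals is not the same as taking the most likely joint continuation; a drafting scheme built on marginals would still need the usual verification step.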
