
Efficient Training-Free Multi-Token Prediction via Embedding-Space Probing

Raghavv Goel, Mukul Gagrani, Mingu Lee, Chris Lott

Abstract

Large language models (LLMs) exhibit latent multi-token prediction (MTP) capabilities despite being trained solely for next-token generation. We propose a simple, training-free MTP approach that probes an LLM using on-the-fly mask tokens drawn from its embedding space, enabling parallel prediction of future tokens without modifying model weights or relying on auxiliary draft models. Our method constructs a speculative token tree by sampling top-K candidates from mask-token logits and applies a lightweight pruning strategy to retain high-probability continuations. During decoding, candidate predictions are verified in parallel, resulting in lossless generation while substantially reducing the number of model calls and improving token throughput. Across benchmarks, our probing-based MTP consistently outperforms existing training-free baselines, increasing acceptance length by approximately 12\% on LLaMA3 and 8--12\% on Qwen3, and achieving throughput gains of up to 15--19\%. Finally, we provide theoretical insights and empirical evidence showing that decoder layers naturally align mask-token representations with next-token states, enabling accurate multi-step prediction without retraining or auxiliary models.
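The tree-construction and verification steps described in the abstract can be illustrated with a minimal sketch. Everything here is an assumption made for illustration, not the authors' implementation: the names `build_pruned_tree` and `longest_accepted` are hypothetical, a path is scored by the product of its per-position probabilities, pruning is a best-first search under a node budget, and verification uses greedy exact matching against a stand-in `target_next` function rather than a single batched forward pass.

```python
import heapq
from typing import Callable, Hashable, List, Sequence, Tuple

Path = Tuple[Hashable, ...]


def build_pruned_tree(
    topk_per_depth: Sequence[Sequence[Tuple[Hashable, float]]],
    budget: int,
) -> List[Tuple[Path, float]]:
    """Keep the `budget` highest-scoring nodes of a speculative token tree.

    topk_per_depth[d] holds the (token, probability) top-K candidates read
    from the d-th mask token. A path's score is the product of its token
    probabilities (an independence assumption made only for this sketch).
    """
    if not topk_per_depth or budget <= 0:
        return []
    # Max-heap via negated scores; start with the depth-1 candidates.
    heap = [(-p, (tok,)) for tok, p in topk_per_depth[0]]
    heapq.heapify(heap)
    kept: List[Tuple[Path, float]] = []
    while heap and len(kept) < budget:
        neg_score, path = heapq.heappop(heap)
        kept.append((path, -neg_score))
        depth = len(path)
        if depth < len(topk_per_depth):
            # Children never outscore their parent because p <= 1, so every
            # kept node has its parent kept too and the result is a tree.
            for tok, p in topk_per_depth[depth]:
                heapq.heappush(heap, (neg_score * p, path + (tok,)))
    return kept


def longest_accepted(
    paths: Sequence[Path],
    target_next: Callable[[Path], Hashable],
) -> Path:
    """Return the longest candidate prefix the target model agrees with.

    target_next(prefix) stands in for the model's own next-token choice
    after `prefix`; a real system would obtain all of these in one forward
    pass using a tree attention mask.
    """
    best: Path = ()
    for path in paths:
        accepted = 0
        for i, tok in enumerate(path):
            if target_next(path[:i]) != tok:
                break
            accepted += 1
        if accepted > len(best):
            best = path[:accepted]
    return best
```

Because a token is accepted only when it equals what the model itself would have produced, the output matches ordinary greedy decoding in this sketch; the speed-up comes from accepting more than one token per model call.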

Paper Structure

This paper contains 30 sections, 2 theorems, 24 equations, 8 figures, 12 tables, 1 algorithm.

Key Result

Lemma 3.1

Let $h_m, h_v \in \mathbb{R}^d$ be the hidden states of the mask token and the true next token after the last decoder layer, and let $W \in \mathbb{R}^{d \times V}$ be the LM head with columns $w_r \in \mathbb{R}^{d}$. Assume $\|h_m\|_{2}, \|h_v\|_{2} \leq c_h$ and $\|w_r\|_{2} \leq c_w$ for all $r$. We

Figures (8)

  • Figure 1: (Left) Standard next-token prediction setup for autoregressive models; (middle) multi-token prediction during the prefill stage by probing mask tokens appended to the prompt tokens; (right) multi-token prediction with parallel verification and generation. Mask tokens are associated with the last generated token ($x_{s}$) and the future tokens ($\hat{x}_{s+1}, \hat{x}_{s+2}$) through a custom tree attention mask.
  • Figure 2: We use 100 Dolly-Databricks (creative-writing) samples [DatabricksBlog2023DollyV2] to measure the average cosine similarity, across layers, between mask-token and true-future-token hidden states. For Llama3.2-3B-Instruct, higher cosine similarity in later layers (15 onwards) correlates with token acceptance (green), while lower similarity correlates with rejection (red).
  • Figure 3: Mask tokens are present for each input token, both the last generated token (blue) and the future tokens (orange). When processed by the model, all tokens are flattened and the mask tokens are placed at the end with appropriate position indices.
  • Figure 4: Evaluation on Spec-Bench using LLaMA3.1-8B-Instruct across block complexities (BC = 10, 30, 60). Our method (green) achieves the highest average number of accepted tokens across most tasks and BC settings.
  • Figure 5: Evaluation on Spec-Bench using Qwen3-32B across block complexities (BC = 10, 30, 60). Our method (green) achieves the highest average number of accepted tokens across most tasks and BC settings.
  • ...and 3 more figures
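
The layer-wise measurement described in the Figure 2 caption can be sketched as follows. This is a hedged illustration rather than the authors' code: the function names and the nested-list layout (`states[sample][layer]` is one hidden-state vector) are assumptions, and extracting hidden states from a model is left out.

```python
import math
from typing import List, Sequence

Vector = Sequence[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    if norm_u == 0.0 or norm_v == 0.0:
        return 0.0
    return dot / (norm_u * norm_v)


def layerwise_mean_cosine(
    mask_states: Sequence[Sequence[Vector]],
    true_states: Sequence[Sequence[Vector]],
) -> List[float]:
    """Average, per layer, the mask-token vs. true-future-token similarity.

    mask_states[s][l] and true_states[s][l] are the hidden states of sample
    s at layer l for the mask token and the true future token respectively.
    """
    if not mask_states:
        return []
    n_layers = len(mask_states[0])
    n_samples = len(mask_states)
    return [
        sum(cosine(m[layer], t[layer]) for m, t in zip(mask_states, true_states))
        / n_samples
        for layer in range(n_layers)
    ]
```

Splitting the samples by whether the predicted token was accepted or rejected before calling `layerwise_mean_cosine` would give the two curves the caption contrasts.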

Theorems & Definitions (3)

  • Lemma 3.1
  • Lemma 1.1
  • proof