Efficient Joint Prediction of Multiple Future Tokens

Kwangjun Ahn, Alex Lamb, John Langford

TL;DR

Efficient Joint Prediction of Multiple Future Tokens introduces joint multi-token prediction (JTP), a lightweight extension that enriches Transformer hidden states by jointly predicting a short sequence of future tokens through a representational bottleneck and a Fetch module guided by teacher forcing. By factorizing the joint future loss and constraining information flow, JTP yields a short-horizon belief state that improves representation quality without substantial overhead. The approach outperforms existing multi-token prediction methods on the star graph navigation task, demonstrating robust gains across small and mid-sized graphs while maintaining a compact parameter footprint. A preliminary language-modeling sanity check suggests compatibility with standard next-token objectives, motivating further exploration of JTP in broader text-modeling settings. Overall, JTP provides a principled, efficient mechanism to imbue models with forward-looking predictive information that can bolster planning-like reasoning in sequence tasks.
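One schematic way to read the "factorized joint future loss" is the chain-rule decomposition below, set against the two prior designs described in Figure 1. The notation is our assumption, not the paper's: $h_t$ is the backbone hidden state after reading $x_{\le t}$, and $D$ is the number of future tokens predicted at position $t$.

```latex
% Schematic per-position objectives (notation assumed, not taken from the paper).
\begin{aligned}
\mathcal{L}^{\text{indep}}_t &= -\sum_{k=1}^{D} \log p_\theta\bigl(x_{t+k} \mid h_t\bigr)
  && \text{(independent heads, Figure 1 left)}\\
\mathcal{L}^{\text{full}}_t &= -\sum_{k=1}^{D} \log p_\theta\bigl(x_{t+k} \mid x_{\le t+k-1}\bigr)
  && \text{(extra layers over the whole prefix, middle)}\\
\mathcal{L}^{\text{JTP}}_t &= -\sum_{k=1}^{D} \log p_\theta\bigl(x_{t+k} \mid h_t,\, x_{t+1},\dots,x_{t+k-1}\bigr)
  = -\log p_\theta\bigl(x_{t+1},\dots,x_{t+D} \mid h_t\bigr)
  && \text{(bottleneck + teacher forcing, right)}
\end{aligned}
```

Under this reading, only the last objective is a true joint likelihood of the window whose access to history passes solely through $h_t$: the first drops dependencies among the future tokens, and the second lets the extra layers read $x_{\le t}$ directly, bypassing the bottleneck.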

Abstract

In this short report, we introduce joint multi-token prediction (JTP), a lightweight modification of standard next-token prediction designed to enrich hidden state representations by jointly predicting multiple future tokens. Unlike previous multi-token prediction approaches, JTP strategically employs teacher forcing of future tokens through a carefully designed representation bottleneck, allowing the model to encode rich predictive information with minimal computational overhead during training. We show that JTP achieves a short-horizon belief state representation, while popular alternatives for multi-token prediction fail to do so. We demonstrate the effectiveness of our method on the synthetic star graph navigation task from Bachmann and Nagarajan [2024], highlighting a significant performance improvement over existing methods. This manuscript presents promising preliminary results intended to stimulate further research.
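A minimal pure-Python sketch of how a teacher-forced joint loss can be routed through a single bottleneck. It assumes a hypothetical one-layer recurrent update (`W_fetch`, `tanh`) as a stand-in for the paper's Fetch module, whose architecture is not given here; the names `jtp_loss_at_position`, `jtp_loss`, `embed`, `W_fetch`, and `W_out` are illustrative. History enters only through `h_t`, and each step is fed the ground-truth previous future token rather than a sample.

```python
import math


def matvec(W, v):
    """Multiply matrix W (list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def log_softmax(z):
    """Numerically stable log-softmax over a list of logits."""
    m = max(z)
    lse = m + math.log(sum(math.exp(x - m) for x in z))
    return [x - lse for x in z]


def jtp_loss_at_position(h_t, future, embed, W_fetch, W_out):
    """Negative log-likelihood of future = [x_{t+1}, ..., x_{t+D}] given h_t.

    Hypothetical stand-in for the Fetch module: a tanh recurrence that mixes
    the running state with the embedding of the teacher-forced previous token.
    """
    state = list(h_t)  # the only path from the history x_{<=t} into the loss
    loss = 0.0
    for k, target in enumerate(future):
        if k > 0:
            prev = embed[future[k - 1]]  # teacher forcing: true token, not a sample
            mixed = matvec(W_fetch, state)
            state = [math.tanh(a + b) for a, b in zip(mixed, prev)]
        logp = log_softmax(matvec(W_out, state))
        loss -= logp[target]
    return loss


def jtp_loss(hidden, tokens, depth, embed, W_fetch, W_out):
    """Sum the per-position loss over every t that has `depth` future tokens."""
    total = 0.0
    for t in range(len(tokens) - depth):
        window = tokens[t + 1:t + 1 + depth]
        total += jtp_loss_at_position(hidden[t], window, embed, W_fetch, W_out)
    return total


if __name__ == "__main__":
    V, d = 5, 3
    embed = [[0.1 * i] * d for i in range(V)]
    W_fetch = [[0.2] * d for _ in range(d)]
    W_out = [[0.3 * ((i + j) % 3) for j in range(d)] for i in range(V)]
    hidden = [[0.1, -0.2, 0.3]] * 6
    print(jtp_loss(hidden, [0, 1, 4, 2, 3, 1], 2, embed, W_fetch, W_out))
```

With `depth=1` the sum reduces to ordinary next-token cross-entropy; larger depths add teacher-forced terms without giving the predictor any direct view of the prefix.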

Paper Structure

This paper contains 15 sections, 1 theorem, 13 equations, 5 figures, 2 tables.

Key Result

Theorem 4.1

Suppose the main Transformer has $L$ layers, and let the input sequence length be $T$. Consider a JTP with depth $D$. Then the following holds: Consequently, the flops per gradient scale on the order of $O\bigl(\tfrac{T\,L}{D} \;+\; D\bigr)$.
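Taking the bound of Theorem 4.1 at face value and dropping the constants hidden by $O(\cdot)$, the expression $TL/D + D$ trades a term that shrinks with depth against one that grows with it; setting the derivative $-TL/D^2 + 1$ to zero puts the continuous minimizer at $D = \sqrt{TL}$. The helper names below are hypothetical, and this arithmetic on the asymptotic expression says nothing about the depth the authors actually use.

```python
import math


def jtp_cost_bound(T, L, D):
    """Evaluate T*L/D + D, the expression inside O(.) in Theorem 4.1 (constants dropped)."""
    return T * L / D + D


def best_depth(T, L):
    """Brute-force the integer depth minimizing the bound; close to sqrt(T*L)."""
    return min(range(1, T * L + 1), key=lambda D: jtp_cost_bound(T, L, D))


if __name__ == "__main__":
    print(best_depth(100, 4))           # 20
    print(jtp_cost_bound(100, 4, 20))   # 40.0
    print(math.sqrt(100 * 4))           # 20.0
```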

Figures (5)

  • Figure 1: Illustration of multi-token prediction mechanisms at position $t=3$. The method of Gloeckle et al. [2024] (left) predicts future tokens independently through a single representation bottleneck, neglecting dependencies between the future tokens. DeepSeek-V3 [DeepSeek-AI, 2024] (middle) processes each token prediction through multiple layers that consider the entire historical context, bypassing the desired bottleneck and thus diminishing representation enrichment. Dotted arrows indicate teacher-forcing dependencies. In contrast, our proposed method (right) efficiently funnels predictive information through a single representation bottleneck while using teacher-forced tokens (dotted arrows), thus preserving dependencies between future tokens without compromising representation richness or computational efficiency.
  • Figure 2: Illustration of the star graph problem due to Bachmann and Nagarajan [2024].
  • Figure 3: Performance comparison of different methods. Our approach consistently solves star graph tasks across different configurations, whereas prior methods struggle, especially for larger graphs.
  • Figure 4: Performance with small prediction windows. Even at minimal depth, our JTP approach outperforms the next-token prediction baseline, demonstrating its effectiveness in shallow-depth settings.
  • Figure 5: Performance on larger graphs. Our approach remains effective, but for $G(7,7)$, test accuracy did not reach $100\%$ within 20,000 training steps.
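To make the star graph task of Figure 2 concrete, here is a small instance generator. It assumes $G(d, l)$ denotes a center node with $d$ arms, each a simple path of $l$ further nodes, with labels drawn from a pool of `num_nodes` tokens; this reading of the notation, the pool size, and the name `make_star_graph` are assumptions, not taken from the paper. The model sees the shuffled edge list plus the start (center) and goal (a leaf) and must output the path between them.

```python
import random


def make_star_graph(degree, arm_len, num_nodes=50, seed=None):
    """Sample one hypothetical star-graph instance G(degree, arm_len).

    Returns (shuffled edge list, start node, goal node, target path).
    """
    if num_nodes < 1 + degree * arm_len:
        raise ValueError("num_nodes too small for this graph")
    rng = random.Random(seed)
    # Distinct random labels so node identity carries no structural hint.
    labels = rng.sample(range(num_nodes), 1 + degree * arm_len)
    center, rest = labels[0], labels[1:]
    arms = [rest[i * arm_len:(i + 1) * arm_len] for i in range(degree)]
    edges = []
    for arm in arms:
        prev = center
        for node in arm:
            edges.append((prev, node))  # directed edge outward from the center
            prev = node
    rng.shuffle(edges)  # edge order reveals nothing about the goal arm
    goal_arm = rng.choice(arms)
    path = [center] + goal_arm
    return edges, center, goal_arm[-1], path


if __name__ == "__main__":
    edges, start, goal, path = make_star_graph(degree=3, arm_len=4, seed=0)
    print(edges)
    print(start, goal, path)
```

Because every arm looks alike locally, the only hard decision is the first step out of the center; once on the correct arm, each subsequent node can be read off the edge list.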

Theorems & Definitions (2)

  • Theorem 4.1
  • Proof of Theorem 4.1