Efficient Joint Prediction of Multiple Future Tokens
Kwangjun Ahn, Alex Lamb, John Langford
TL;DR
This report introduces joint multi-token prediction (JTP), a lightweight extension that enriches Transformer hidden states by jointly predicting a short sequence of future tokens through a representational bottleneck and a Fetch module guided by teacher forcing. By factorizing the joint future loss and constraining information flow, JTP yields a short-horizon belief state that improves representation quality with little training overhead. The approach outperforms existing multi-token prediction methods on the star graph navigation task, with robust gains across small and mid-sized graphs and a compact parameter footprint. A preliminary language-modeling sanity check suggests that JTP is compatible with the standard next-token objective, which motivates further exploration of JTP in broader text-modeling settings. Overall, JTP provides a principled, efficient mechanism for giving models forward-looking predictive information that can support planning-like reasoning in sequence tasks.
Abstract
In this short report, we introduce joint multi-token prediction (JTP), a lightweight modification of standard next-token prediction designed to enrich hidden state representations by jointly predicting multiple future tokens. Unlike previous multi-token prediction approaches, JTP strategically employs teacher forcing of future tokens through a carefully designed representation bottleneck, allowing the model to encode rich predictive information with minimal computational overhead during training. We show that JTP achieves a short-horizon belief state representation, while popular alternatives for multi-token prediction fail to do so. We demonstrate the effectiveness of our method on the synthetic star graph navigation task from Bachmann and Nagarajan [2024], where it yields a significant performance improvement over existing methods. This manuscript presents promising preliminary results intended to stimulate further research.
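To make the idea of a factorized, teacher-forced joint objective concrete, here is a minimal sketch. It assumes that the joint loss over a horizon of D future tokens is factorized by the chain rule, with each factor conditioned on the hidden state and on the ground-truth (teacher-forced) earlier future tokens. The names `joint_future_loss`, `Predictor`, and `uniform_predict` are hypothetical illustrations. The sketch does not model the paper's representation bottleneck or Fetch module; it only shows the shape of the loss.

```python
import math
from typing import Callable, List, Sequence

# Hypothetical interface (not from the paper): given a hidden state h and the
# ground-truth future tokens seen so far, return a probability distribution
# over the vocabulary for the next future token.
Predictor = Callable[[Sequence[float], Sequence[int]], List[float]]


def joint_future_loss(h: Sequence[float], future: Sequence[int], predict: Predictor) -> float:
    """Negative log-likelihood of future = (x_{t+1}, ..., x_{t+D}) given h.

    Chain rule: -log p(x_{t+1:t+D} | h) = -sum_k log p(x_{t+k} | h, x_{t+1:t+k-1}).
    Each factor is conditioned on the *true* earlier future tokens
    (teacher forcing), so no sampling is needed to compute the D terms.
    """
    loss = 0.0
    for k, target in enumerate(future):
        probs = predict(h, future[:k])  # teacher-forced prefix of length k
        loss -= math.log(probs[target])
    return loss


VOCAB_SIZE = 4  # toy vocabulary size, chosen only for illustration


def uniform_predict(h: Sequence[float], prefix: Sequence[int]) -> List[float]:
    # Toy predictor that ignores its inputs; used only to check the arithmetic.
    return [1.0 / VOCAB_SIZE] * VOCAB_SIZE


# Horizon D = 3 with a uniform predictor: loss = 3 * log(4), about 4.159
print(joint_future_loss([0.0], [1, 2, 3], uniform_predict))
```

Because every factor conditions on ground-truth tokens rather than on the model's own samples, all D terms are available at once during training. In a real model they would be computed in parallel rather than in a Python loop.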
