Next-Latent Prediction Transformers Learn Compact World Models
Jayden Teoh, Manan Tomar, Kwangjun Ahn, Edward S. Hu, Pratyusha Sharma, Riashat Islam, Alex Lamb, John Langford
TL;DR
This work tackles the tendency of autoregressive transformers to undercompress history, which harms generalization. It introduces Next-Latent Prediction (NextLat), a lightweight auxiliary objective that trains a latent-transition model to predict the transformer's next latent state given the next token, thereby shaping the transformer's latent beliefs without changing its architecture or inference. The authors prove that NextLat induces belief-state-like representations and demonstrate strong empirical gains across world modeling, reasoning, planning, and language modeling benchmarks, with more compact and predictive internal representations. The approach offers a practical, scalable path to improving generalization in autoregressive sequence modeling, with broad potential for pretraining, fine-tuning, and integration with hybrid architectures.
Abstract
Transformers replace recurrence with a memory that grows with sequence length and self-attention that enables ad hoc lookups over past tokens. Consequently, they lack an inherent incentive to compress history into compact latent states with consistent transition rules. This often leads them to learn solutions that generalize poorly. We introduce Next-Latent Prediction (NextLat), which extends standard next-token training with self-supervised predictions in the latent space. Specifically, NextLat trains a transformer to learn latent representations that are predictive of its next latent state given the next output token. Theoretically, we show that these latents provably converge to belief states: compressed summaries of the history that retain the information necessary to predict the future. This simple auxiliary objective also injects a recurrent inductive bias into transformers, while leaving their architecture, parallel training, and inference unchanged. NextLat effectively encourages the transformer to form compact internal world models with its own belief states and transition dynamics, a crucial property absent from standard next-token prediction transformers. Empirically, across benchmarks targeting core sequence modeling competencies (world modeling, reasoning, planning, and language modeling), NextLat demonstrates significant gains over standard next-token training in downstream accuracy, representation compression, and lookahead planning. NextLat stands as a simple and efficient paradigm for shaping transformer representations toward stronger generalization.
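To make the auxiliary objective concrete, here is a minimal, dependency-free sketch of how a next-latent loss could be computed on top of the next-token loss. Everything specific in it is an assumption for illustration, not the authors' implementation: the tanh-linear transition model, the mean-squared-error distance, the weight `lam`, and the function names `transition`, `next_latent_loss`, and `nextlat_objective` are all hypothetical. The only part taken from the abstract is the structure: the latent at position t, together with the next token, is used to predict the latent at position t+1.

```python
import math
from typing import Dict, List, Sequence

Vector = List[float]


def matvec(W: Sequence[Sequence[float]], v: Sequence[float]) -> Vector:
    """Multiply a (rows x cols) matrix W by a length-cols vector v."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in W]


def transition(h_t: Sequence[float], token_emb: Sequence[float],
               W_h: Sequence[Sequence[float]]) -> Vector:
    """Hypothetical latent-transition model: h_hat_{t+1} = tanh(W_h h_t + e(x_{t+1}))."""
    z = matvec(W_h, h_t)
    return [math.tanh(a + b) for a, b in zip(z, token_emb)]


def next_latent_loss(hiddens: Sequence[Sequence[float]], tokens: Sequence[int],
                     embed: Dict[int, Sequence[float]],
                     W_h: Sequence[Sequence[float]]) -> float:
    """Average per-position MSE between predicted and actual next latents.

    hiddens[t] is the transformer's latent after reading tokens[0..t], so the
    "next output token" for position t is tokens[t + 1]. In an autodiff
    framework one might stop gradients through the target hiddens[t + 1];
    that choice is an assumption, not something this sketch can express.
    """
    total, count = 0.0, 0
    for t in range(len(hiddens) - 1):
        pred = transition(hiddens[t], embed[tokens[t + 1]], W_h)
        target = hiddens[t + 1]
        total += sum((p - y) ** 2 for p, y in zip(pred, target)) / len(target)
        count += 1
    return total / count if count else 0.0


def nextlat_objective(ntp_loss: float, hiddens, tokens, embed, W_h,
                      lam: float = 1.0) -> float:
    """Next-token loss plus a weighted next-latent auxiliary term."""
    return ntp_loss + lam * next_latent_loss(hiddens, tokens, embed, W_h)


if __name__ == "__main__":
    W = [[0.0, 0.0], [0.0, 0.0]]
    embed = {0: [0.0, 0.0], 1: [0.0, 0.0]}
    hs = [[1.0, -1.0], [1.0, 1.0], [0.0, 0.0]]
    toks = [0, 1, 0]
    print(next_latent_loss(hs, toks, embed, W))  # latents inconsistent at t=0
```

In this toy setup the loss is zero only when each latent is exactly what the transition model predicts from the previous latent and the next token, which is the sense in which the auxiliary term rewards latents with consistent transition rules; a real model would learn both the transformer and the transition network jointly.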
