Next-Latent Prediction Transformers Learn Compact World Models

Jayden Teoh, Manan Tomar, Kwangjun Ahn, Edward S. Hu, Pratyusha Sharma, Riashat Islam, Alex Lamb, John Langford

TL;DR

This work tackles the tendency of autoregressive transformers to undercompress history, which harms generalization. It introduces Next-Latent Prediction (NextLat), a lightweight auxiliary objective that trains a latent-transition model to predict the transformer's next latent state given the next token, thereby shaping latent beliefs without changing architecture or inference. The authors prove that NextLat induces belief-state-like representations and demonstrate strong empirical gains across world modeling, reasoning, planning, and language modeling benchmarks, achieving more compact and predictive internal representations. The approach offers a practical, scalable path to improving generalization in autoregressive sequence modeling, with broad potential for pretraining, fine-tuning, and integration with hybrid architectures.

Abstract

Transformers replace recurrence with a memory that grows with sequence length and self-attention that enables ad-hoc lookups over past tokens. Consequently, they lack an inherent incentive to compress history into compact latent states with consistent transition rules. This often leads to learning solutions that generalize poorly. We introduce Next-Latent Prediction (NextLat), which extends standard next-token training with self-supervised predictions in the latent space. Specifically, NextLat trains a transformer to learn latent representations that are predictive of its next latent state given the next output token. Theoretically, we show that these latents provably converge to belief states, compressed information of the history necessary to predict the future. This simple auxiliary objective also injects a recurrent inductive bias into transformers, while leaving their architecture, parallel training, and inference unchanged. NextLat effectively encourages the transformer to form compact internal world models with its own belief states and transition dynamics -- a crucial property absent in standard next-token prediction transformers. Empirically, across benchmarks targeting core sequence modeling competencies -- world modeling, reasoning, planning, and language modeling -- NextLat demonstrates significant gains over standard next-token training in downstream accuracy, representation compression, and lookahead planning. NextLat stands as a simple and efficient paradigm for shaping transformer representations toward stronger generalization.
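
The auxiliary objective described in the abstract can be illustrated with a minimal NumPy sketch. Everything concrete here is an assumption made for illustration: the one-layer `tanh` transition model, the mean-squared-error loss, and the names `next_latent_loss`, `W`, and `b` do not come from the paper, which may use a different transition network and loss.

```python
import numpy as np

def next_latent_loss(h, next_tok_emb, W, b):
    """Mean squared error between predicted and actual next latents.

    h:            (T, d) latents h_1..h_T produced by the transformer
    next_tok_emb: (T-1, e) embeddings of the next tokens x_2..x_T
    W, b:         parameters of a hypothetical one-layer transition model
    """
    # Pair each latent h_t with the embedding of the token that follows it.
    inputs = np.concatenate([h[:-1], next_tok_emb], axis=1)  # (T-1, d+e)
    h_pred = np.tanh(inputs @ W + b)                          # predicted h_{t+1}
    # NumPy has no autograd; whether gradients flow into the target latents
    # is a design choice this sketch does not model.
    target = h[1:]
    return float(np.mean((h_pred - target) ** 2))

# Toy shapes (assumed): 6 positions, latent width 4, embedding width 3.
rng = np.random.default_rng(0)
T, d, e = 6, 4, 3
h = rng.normal(size=(T, d))
emb = rng.normal(size=(T - 1, e))
W = rng.normal(size=(d + e, d)) * 0.1
b = np.zeros(d)
loss = next_latent_loss(h, emb, W, b)
```

In training, a term like this would be added to the usual next-token loss with some weighting; the loss is zero exactly when the latents already follow the transition model's dynamics.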

Paper Structure

This paper contains 45 sections, 2 theorems, 16 equations, 11 figures, 3 tables, 1 algorithm.

Key Result

Theorem 3.2

Consider the joint learning of three components: NextLat optimizes for the following consistency objectives: where the right-hand side of eq:transition_correctness is the transition law induced by the transformer's weights. (Footnote: We adopt a probabilistic formulation to retain generality with respect to stochastic transformer models, e.g. [fleuret2025freetransformer].) For these consistency objectives to b...

Figures (11)

  • Figure 1: Reconstructed maps from sequences generated by transformers trained on Manhattan taxi rides [vafa2024evaluating] with three methods: next-token prediction (GPT), joint multi-token prediction (JTP), and next-latent prediction (NextLat). Generated edges consistent with the true world model are colored black; invalid edges are red. Visibly, the transformer trained with next-latent prediction learns a world model more consistent with reality.
  • Figure 2: Illustration of different predictive mechanisms at time step $t=3$. Other methods supervise only the token-level emissions, leaving intermediate latent representations implicit. In contrast, NextLat explicitly learns latent dynamics that predict the hidden state $\hat{h}_{t+1}$ from $(h_t, x_{t+1})$. Token-level supervision is then applied to $\hat{h}_{t+1}$. Accurate multi-token predictions (beyond the next token) therefore emerge as a consequence of faithful latent dynamics modeling, with the latent acting as the bottleneck.
  • Figure 3: Reconstructed maps from transformers trained on Manhattan taxi rides using different objectives.
  • Figure 4: Performance on Countdown. Best result is bolded, and second best is underlined.
  • Figure 5: Validity of equations (i.e., LHS = RHS) generated on Countdown. All models in this plot use $d=1$.
  • ...and 6 more figures
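
The mechanism in the Figure 2 caption, where token-level supervision is applied to predicted latents rather than to separate multi-token heads, can be sketched as a short rollout. This is a hedged illustration, not the paper's implementation: the `tanh` transition, the linear readout `W_out`, the rollout length `horizon`, and the function names are all assumed.

```python
import numpy as np

def rollout_logits(h_t, future_tok_emb, W, b, W_out):
    """Roll a hypothetical transition model forward and decode each predicted latent."""
    h_hat = h_t
    logits = []
    for emb in future_tok_emb:  # embeddings of x_{t+1}, x_{t+2}, ...
        # Predict the next latent from the current (predicted) latent and next token.
        h_hat = np.tanh(np.concatenate([h_hat, emb]) @ W + b)
        # Token logits are read from the predicted latent, so the latent is the bottleneck.
        logits.append(h_hat @ W_out)
    return np.stack(logits)  # (horizon, vocab)

def cross_entropy(logits, targets):
    """Mean softmax cross-entropy, computed with a max shift for numerical stability."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(len(targets)), targets].mean())

# Toy shapes (assumed): latent width 4, embedding width 3, vocabulary 5, 2 rollout steps.
rng = np.random.default_rng(0)
d, e, vocab, horizon = 4, 3, 5, 2
h_t = rng.normal(size=d)
future_emb = rng.normal(size=(horizon, e))
W = rng.normal(size=(d + e, d)) * 0.1
b = np.zeros(d)
W_out = rng.normal(size=(d, vocab)) * 0.1
targets = np.array([1, 3])  # tokens that follow each predicted latent

logits = rollout_logits(h_t, future_emb, W, b, W_out)
token_loss = cross_entropy(logits, targets)
```

Because every step past the first consumes a predicted latent, errors in the transition model compound across the rollout, which is why accurate multi-token prediction here depends on faithful latent dynamics.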

Theorems & Definitions (5)

  • Definition 3.1: Belief states in sequence modeling
  • Theorem 3.2
  • Definition A.1: $k$-observability for sequences
  • Proposition A.2: JTP forms belief states in $k$-observable systems
  • proof
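
For orientation, the abstract's description of belief states as compressed information of the history necessary to predict the future corresponds to a standard sufficiency condition. The block below is a sketch in assumed notation ($b_t$, $f$, and the horizon $T$ are not taken from this excerpt), and the paper's Definition 3.1 may be stated differently.

```latex
% Assumed notation: x_{1:t} is the observed history and
% b_t = f(x_{1:t}) is a candidate belief state.
\[
  P\bigl(x_{t+1:T} \mid x_{1:t}\bigr)
  \;=\;
  P\bigl(x_{t+1:T} \mid b_t\bigr)
  \qquad \text{for all } t \text{ and all histories } x_{1:t}.
\]
```

Under this reading, a belief state discards everything about the history except what changes the distribution over future tokens.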