Thinking into the Future: Latent Lookahead Training for Transformers

Lorenzo Noci, Gregor Bachmann, Seyed-Mohsen Moosavi-Dezfooli, Moin Nabi

Abstract

Autoregressive language models trained with next-token prediction generate text by sampling one discrete token at a time. Although very scalable, this objective forces the model to commit at every step, preventing it from exploring or reflecting upon multiple plausible continuations. Furthermore, the compute allocation across tokens is uniform: every token is produced by a single forward pass, potentially limiting the model's expressiveness in cases where difficult tokens inherently require more compute. To address these limitations, we introduce latent lookahead, a training strategy that enables models to "think" before generating: at selected positions in the sequence, before committing to the next token, the model performs a multi-step lookahead in latent space. More precisely, instead of sampling future tokens, we leverage the network's latent space by recursively feeding its hidden states back into the context for $\tau$ steps, investing more compute in predicting that token. This produces $\tau$ latent predictions that are supervised against the next $\tau$ ground-truth tokens, encouraging the model to look ahead and refine its prediction. We show that latent lookahead substantially outperforms both autoregressive and non-autoregressive baselines on planning tasks such as maze solving, Sudoku, and ProsQA, where foresight is essential.
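
The supervision described in the abstract can be illustrated with a small sketch: each selected position gets $\tau$ latent predictions, and these are paired with the next $\tau$ ground-truth tokens. The function name `lookahead_targets`, the 0-based indexing, and the choice to drop targets that fall past the end of the sequence are assumptions made for illustration, not details taken from the paper.

```python
def lookahead_targets(tokens, selected, tau):
    """Pair each selected position with the targets of its tau latent predictions.

    Assumptions (illustrative only): positions are 0-based, and the j-th
    latent prediction at position i (j = 1..tau) is paired with tokens[i + j].
    Targets that would fall past the end of the sequence are dropped.
    """
    targets = {}
    for i in selected:
        targets[i] = [
            tokens[i + j] for j in range(1, tau + 1) if i + j < len(tokens)
        ]
    return targets


tokens = ["x0", "x1", "x2", "x3", "x4", "x5"]
example = lookahead_targets(tokens, selected=[1, 4], tau=3)
print(example)  # {1: ['x2', 'x3', 'x4'], 4: ['x5']}
```

With `tau=0` the mapping is empty for every selected position, which corresponds to plain next-token prediction with no latent supervision.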

Paper Structure (22 sections, 8 equations, 8 figures, 8 tables, 1 algorithm)


Figures (8)

  • Figure 1: Standard autoregressive inference vs. latent lookahead. Left: in standard next-token prediction, the model applies the final unembedding head to the hidden state of the latest generated token, samples from the result, and appends the sampled token to the context. Right: in our approach, the model enters latent lookahead thinking, where hidden states are fed directly into the context in place of sampled visible tokens. This procedure is repeated $\tau$ times, and only then is the visible token sampled from the first latent position. In the figure, the tokens $x_2$ and $x_4$ are selected, $\tau=3$, and $z_{i,j}$ denotes the $j$-th latent token relative to the $i$-th visible token. See also Figure 3.
  • Figure 2: Lookahead behaviour when solving a Sudoku. In the first slot, both $1$ and $3$ are viable options. However, when thinking ahead to the second empty slot, where $3$ is the only plausible entry, it is easy to realize that $1$ is the right choice for the first slot.
  • Figure 3: Left: the visible tokens (in blue) are supervised with the standard next-token prediction objective. The latent tokens are supervised explicitly with the token the corresponding number of steps ahead, i.e., the latent token $z_{i,j}$ is supervised with the visible token $x_{i+j}$. Right: to ensure that the model can refine its next-token prediction based on its own assessment of future steps, we allow full (non-causal) attention among the latent tokens of a block. The visible tokens follow standard causal masking. Finally, the latent thoughts attend causally to previous visible tokens, but not to previous latent tokens. This allows the latent thoughts to be generated in parallel during training.
  • Figure 4: Effect of increasing the number of latent tokens in three tasks. As the compute allocated increases, the <pause> token method either saturates or fails to improve over the NTP baseline. In contrast, latent lookahead continues to benefit from larger $\tau$.
  • Figure 5: (a) Ablation comparing latent lookahead with the Looped variant and MTP (mini and full Sudoku). (b) Comparison between sequential and random latent position strategies on mini Sudoku.
  • ...and 3 more figures
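
The masking rules in the Figure 3 caption can be sketched as a boolean attention mask. This is a minimal illustration, not the authors' implementation: the function `build_mask`, the position encoding `(kind, i, j)`, and the layout that places latent tokens after all visible tokens are hypothetical. Two further points are assumptions because the caption does not settle them: visible tokens are taken to attend only to visible tokens, and a latent token $z_{i,j}$ is allowed to attend to $x_i$ itself.

```python
def build_mask(num_visible, selected, tau):
    """Return (positions, mask), where mask[q][k] is True if q may attend to k.

    Each position is (kind, i, j): kind "v" is visible token x_i (j unused),
    kind "z" is latent token z_{i,j} attached to visible token x_i.
    """
    positions = [("v", i, 0) for i in range(num_visible)]
    positions += [("z", i, j) for i in selected for j in range(1, tau + 1)]

    def can_attend(query, key):
        q_kind, q_i, _ = query
        k_kind, k_i, _ = key
        if k_kind == "v":
            # Causal access to visible tokens, for visible and latent queries.
            return k_i <= q_i
        # Latent keys are seen only by latent queries of the same block:
        # full (non-causal) attention within a block, none across blocks.
        return q_kind == "z" and k_i == q_i

    mask = [[can_attend(q, k) for k in positions] for q in positions]
    return positions, mask


positions, mask = build_mask(num_visible=4, selected=[1, 3], tau=2)
# Row for z_{1,1}: sees x_0, x_1 and both latents of its own block.
print(mask[4])  # [True, True, False, False, True, True, False, False]
```

Because no latent token depends on a latent token from an earlier block, all blocks can be evaluated in one pass, which matches the caption's remark about parallel generation during training.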