Do language models plan ahead for future tokens?

Wilson Wu, John X. Morris, Lionel Levine

TL;DR

The paper addresses whether language models intentionally plan ahead by storing information at time $t$ that benefits future tokens, formalizing two hypotheses: pre-caching and breadcrumbs. It introduces myopic training to suppress gradient flow to past timesteps and uses synthetic data to demonstrate clear pre-caching, while natural-language experiments with GPT-2 indicate breadcrumbs dominate at small scale; scaling up to larger models increases pre-caching. The findings suggest a predominantly breadcrumb-driven pattern in small models, with scalable pre-caching emerging as models grow, indicating a form of future planning in large transformers. These insights have implications for interpretability and safety, and point to avenues for controlling or leveraging future-token planning in practice.

Abstract

Do transformers "think ahead" during inference at a given position? It is known that transformers prepare information in the hidden states of the forward pass at time step $t$ that is then used in future forward passes $t+\tau$. We posit two explanations for this phenomenon: pre-caching, in which off-diagonal gradient terms present during training result in the model computing features at $t$ that are irrelevant to the present inference task but useful for the future, and breadcrumbs, in which the features most relevant to time step $t$ are already the same as those that would most benefit inference at time $t+\tau$. We test these hypotheses by training language models without propagating gradients to past timesteps, a scheme we formalize as myopic training. In a constructed synthetic data setting, we find clear evidence for pre-caching. In the autoregressive language modeling setting, our experiments are more suggestive of the breadcrumbs hypothesis, though pre-caching increases with model scale.
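
To make the "off-diagonal gradient terms" concrete, here is a minimal sketch on a hypothetical two-position scalar model. It is not the paper's implementation. It assumes, following the abstract's description, that vanilla training sums every term $\partial \ell_j / \partial \theta_i$ of the gradient taken with untied per-position weights, while myopic training keeps only the diagonal terms $i = j$. The function names, the toy inputs and targets, and the finite-difference Jacobian are all illustrative assumptions.

```python
# Toy illustration (assumed setup, not from the paper): one shared scalar
# weight theta is "untied" into theta1 (used at position 1) and theta2
# (used at position 2). Position 2 reads position 1's hidden state h1.

def per_position_losses(theta1, theta2, x=(1.0, 2.0), y=(0.5, -1.0)):
    """Return [loss_1, loss_2] for the hypothetical two-position model."""
    h1 = theta1 * x[0]              # hidden state computed at position 1
    loss1 = (h1 - y[0]) ** 2        # loss at position 1 depends on theta1 only
    pred2 = theta2 * x[1] + h1      # position 2 reuses h1 from the past
    loss2 = (pred2 - y[1]) ** 2     # loss at position 2 depends on both weights
    return [loss1, loss2]


def jacobian(theta, eps=1e-6):
    """J[i][j] = d loss_j / d theta_i at theta1 = theta2 = theta,
    estimated with central finite differences."""
    J = [[0.0, 0.0], [0.0, 0.0]]
    for i in range(2):
        plus = [theta, theta]
        minus = [theta, theta]
        plus[i] += eps
        minus[i] -= eps
        losses_plus = per_position_losses(*plus)
        losses_minus = per_position_losses(*minus)
        for j in range(2):
            J[i][j] = (losses_plus[j] - losses_minus[j]) / (2 * eps)
    return J


def vanilla_grad(theta):
    """Tied-weight gradient: sum of all terms, diagonal and off-diagonal."""
    J = jacobian(theta)
    return sum(J[i][j] for i in range(2) for j in range(2))


def myopic_grad(theta):
    """Myopic gradient (as assumed here): diagonal terms only, so no
    gradient from loss_2 flows back into the weight used at position 1."""
    J = jacobian(theta)
    return sum(J[i][i] for i in range(2))


if __name__ == "__main__":
    print("vanilla:", vanilla_grad(0.3))
    print("myopic: ", myopic_grad(0.3))
```

In this toy, the two gradients differ only by the single off-diagonal term $\partial \ell_2 / \partial \theta_1$: at $\theta = 0.3$ the vanilla gradient is approximately 11.0 and the myopic gradient approximately 7.2. That missing term is the only signal that could push position 1 to compute something for the benefit of position 2, which is the pre-caching incentive that myopic training removes.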

Paper Structure (39 sections, 7 theorems, 52 equations, 13 figures, 5 tables, 2 algorithms)

Key Result

Theorem 8

Assume $\ell\colon \Theta^n\to\mathbb{R}$ is $\sigma$-strongly convex and $L$-smooth for some $\sigma,L>0$. Consider ordinary gradient descent with untied weights Then, for ${\bm{\theta}}_1^*,\ldots,{\bm{\theta}}_n^*=\mathop{\mathrm{arg\,min}}\limits_{{\bm{\theta}}_1,\ldots,{\bm{\theta}}_n}\ell({\bm{\theta}}_1,\ldots,{\bm{\theta}}_n)$, for small enough $\eta>0$,

Figures (13)

  • Figure 1: At which position is the computation required to correctly answer this math problem taking place? Cognitive science tells us that humans think ahead while speaking; we investigate the extent to which language models do the same.
  • Figure 2: Empirical $R^2$ of linear probes fit on each layer of vanilla transformer models trained on $\mathcal{D}_p$ for $p\in\{0.01,0.1,0.3,1\}$ to targets $\sin(b\,{\textnormal{x}}_{n-i})$. Computed over $50{,}000$ samples from $\mathcal{D}_1$.
  • Figure 4: Cross-entropy loss of vanilla and myopic GPT-2 models by token position, and their difference. Evaluated on a sliding window over a 100K-token sample text from the PG-19 dataset (Rae et al., 2019). Aggregate cross-entropy losses on this sample are 4.67 (vanilla) and 4.77 (myopic).
  • Figure 5: Benchmarks of Pythia models fine-tuned on the Pile dataset using vanilla and myopic descent.
  • Figure 6: Estimate of $\sin(b\,{\textnormal{x}}_n)$ by a linear probe fit on layer 1 of a transformer with vanilla training on $\mathcal{D}_{0.3}$. Computed over $50{,}000$ samples from $\mathcal{D}_1$.
  • ...and 8 more figures

Theorems & Definitions (21)

  • Definition 1
  • Definition 2
  • Definition 3
  • Definition 4
  • Definition 5
  • Definition 6
  • Definition 7
  • Theorem 8
  • proof
  • Theorem 9
  • ...and 11 more