
From Memories to Maps: Mechanisms of In-Context Reinforcement Learning in Transformers

Ching Fang, Kanaka Rajan

TL;DR

This work investigates how episodic memory enables rapid in-context reinforcement learning in transformers by training decision-pretrained transformers on gridworld and tree-maze planning tasks. It shows that rapid adaptation arises from computations stored in the model's context memory tokens rather than from standard model-free or model-based strategies, and that representations exhibit in-context structure learning and cross-context alignment. Using gradient-attribution and decoding analyses, the authors show that decisions rely on cached intermediate computations rather than full path planning, and that memory tokens encode coordinates and path information in a manner reminiscent of the hippocampus. The authors propose that memory acts as a computational workspace supporting flexible, rapid adaptation in novel environments, with implications for understanding natural cognition and for designing memory-augmented architectures.
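In-context structure learning is quantified in Figure 3C as kernel alignment between model representations and the latent graph structure. The sketch below illustrates that kind of measurement under stated assumptions: it uses uncentered kernel alignment, a linear (dot-product) kernel on representations, and exp(-d) of shortest-path distance d as the graph kernel. These kernel choices and all function names are hypothetical; the paper's exact definitions may differ.

```python
import math
from collections import deque


def shortest_path_lengths(adjacency, source):
    """Breadth-first search distances from `source` in an unweighted, connected graph."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        for nbr in adjacency[node]:
            if nbr not in dist:
                dist[nbr] = dist[node] + 1
                queue.append(nbr)
    return dist


def graph_kernel(adjacency, nodes):
    """Assumed latent-structure kernel: K[u][v] = exp(-shortest-path distance)."""
    kernel = []
    for u in nodes:
        dist = shortest_path_lengths(adjacency, u)
        kernel.append([math.exp(-dist[v]) for v in nodes])
    return kernel


def linear_kernel(reps):
    """Gram matrix of dot products between representation vectors."""
    return [[sum(a * b for a, b in zip(x, y)) for y in reps] for x in reps]


def kernel_alignment(k1, k2):
    """Uncentered kernel alignment: <K1, K2>_F / (||K1||_F * ||K2||_F)."""
    inner = sum(a * b for r1, r2 in zip(k1, k2) for a, b in zip(r1, r2))
    norm1 = math.sqrt(sum(a * a for row in k1 for a in row))
    norm2 = math.sqrt(sum(b * b for row in k2 for b in row))
    return inner / (norm1 * norm2)
```

With real data, `reps` would hold one layer's query-token representation per state, listed in the same order as `nodes`; an alignment closer to 1 means the representational geometry more closely tracks graph distance.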

Abstract

Humans and animals show remarkable learning efficiency, adapting to new environments with minimal experience. This capability is not well captured by standard reinforcement learning algorithms that rely on incremental value updates. Rapid adaptation likely depends on episodic memory -- the ability to retrieve specific past experiences to guide decisions in novel contexts. Transformers provide a useful setting for studying these questions because of their ability to learn rapidly in-context and because their key-value architecture resembles episodic memory systems in the brain. We train a transformer to in-context reinforcement learn in a distribution of planning tasks inspired by rodent behavior. We then characterize the learning algorithms that emerge in the model. We first find that representation learning is supported by in-context structure learning and cross-context alignment, where representations are aligned across environments with different sensory stimuli. We next demonstrate that the reinforcement learning strategies developed by the model are not interpretable as standard model-free or model-based planning. Instead, we show that in-context reinforcement learning is supported by caching intermediate computations within the model's memory tokens, which are then accessed at decision time. Overall, we find that memory may serve as a computational resource, storing both raw experience and cached computations to support flexible behavior. Furthermore, the representations developed in the model resemble computations associated with the hippocampal-entorhinal system in the brain, suggesting that our findings may be relevant for natural cognition. Taken together, our work offers a mechanistic hypothesis for the rapid adaptation that underlies in-context learning in artificial and natural settings.
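To make the training setup concrete (see also Figure 1A), the sketch below builds one hypothetical decision-pretraining example for an open gridworld. Each state is assigned a random Gaussian vector, the context is a set of (s, a, r, s') transition tuples, and the supervised target is an optimal action from a query state. The grid size, four-action move set, uniform transition sampling, encoding dimension, and all function names are assumptions for illustration, not the authors' implementation.

```python
import random
from collections import deque

# Assumed action set: up, down, left, right.
ACTIONS = {0: (-1, 0), 1: (1, 0), 2: (0, -1), 3: (0, 1)}


def step(state, action, size):
    """Deterministic grid move; bumping into the boundary leaves the agent in place."""
    r, c = state
    dr, dc = ACTIONS[action]
    nr, nc = r + dr, c + dc
    if 0 <= nr < size and 0 <= nc < size:
        return (nr, nc)
    return state


def distances_to_goal(goal, size):
    """BFS distances from every cell to the goal (grid moves are reversible)."""
    dist = {goal: 0}
    queue = deque([goal])
    while queue:
        s = queue.popleft()
        for a in ACTIONS:
            n = step(s, a, size)
            if n not in dist:
                dist[n] = dist[s] + 1
                queue.append(n)
    return dist


def make_example(size=5, context_len=20, enc_dim=8, seed=0):
    """Hypothetical example: (context, query encoding, optimal action, query, goal)."""
    rng = random.Random(seed)
    states = [(r, c) for r in range(size) for c in range(size)]
    # Fresh random Gaussian encodings per task, so held-out tasks look novel.
    encoding = {s: [rng.gauss(0.0, 1.0) for _ in range(enc_dim)] for s in states}
    goal = rng.choice(states)
    context = []
    for _ in range(context_len):
        s = rng.choice(states)  # assumed: uniformly sampled transitions
        a = rng.randrange(len(ACTIONS))
        s_next = step(s, a, size)
        reward = 1.0 if s_next == goal else 0.0  # reward on entering the goal
        context.append((encoding[s], a, reward, encoding[s_next]))
    dist = distances_to_goal(goal, size)
    query = rng.choice([s for s in states if s != goal])
    # Supervised label: an action that moves one step closer to the goal.
    label = min(ACTIONS, key=lambda a: dist[step(query, a, size)])
    return context, encoding[query], label, query, goal
```

A model trained on many such examples, each with a new graph labeling and goal, must infer the task from its context rather than memorize it, which is the sense in which learning happens in-context.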

Paper Structure

This paper contains 27 sections, 1 equation, 23 figures.

Figures (23)

  • Figure 1: A transformer is trained to in-context reinforcement learn in diverse planning tasks. A. Diagram of meta-learning setup. For each task, the model is trained via supervision to predict the optimal action from a query state $s_{\text{query}}$, given memories of RL transition tuples sampled in-context. B. Illustration of three training tasks (orange) and one test task (gray) from the gridworld distribution. In each task, the underlying graph structure is fixed, but the reward location (red star) can vary. Each state is encoded as a random Gaussian vector (bottom). Importantly, test task state encodings are novel. C. As in (B), but for the tree maze distribution. The training set graph structures are drawn from probabilistically branching trees, while the test set structure is a full binary tree.
  • Figure 2: Transformers can rapidly learn and plan in new tasks. A. Average max-normalized return in two held-out gridworld environments as a function of context length. For each context length, $20$ query states are sampled with test horizon $15$. B. As in (A), but return is plotted against the number of rewards experienced in-context and averaged over $50$ held-out environments. Blue: meta-learned transformer; Green: tabular Q-learning; Pink: DQN. C. Example of shortcut behavior in a held-out gridworld. The model experiences a circuitous trajectory (orange), but can infer a more efficient path (blue). D., E. As in (A, B), but for tree mazes and test horizon $100$. F. As in (D, E), but shown only for context length $800$ and subdivided by query type: states seen before reward (Pre-$\star$), after reward (Post-$\star$), or never seen in context (Novel; not used in E). Error bars show 95% C.I.
  • Figure 3: Model representations are shaped by in-context structure learning. A. An example test gridworld. Query token representations are visualized for each state after projection onto the first two principal components, for layer 1 (left), layer 2 (middle), and layer 3 (right). Context length is 1. Points are colored by graph location; gray lines indicate true connectivity. Reward is marked with a red star. B. As in (A), but with context length 250. C. Kernel alignment between model representations and latent graph structure as a function of context length, across $100$ environments. Shading shows 95% C.I.; colors denote model layer. Dashed line shows baseline from raw inputs. D. As in (C), but for context length 250, with reward ablation (shaded bars). E, F. As in (A, B), but for test tree mazes. Points are colored by maze depth. G, H. As in (C, D), but for test tree mazes.
  • Figure 4: As context grows, representations across environments with similar structure are aligned. A. Diagram of cross-environment alignment (Whittington et al., 2020). Although sensory inputs differ, environments share latent structure, and representations of matching latent states should be similar. B. Average pairwise Pearson correlation coefficient of node representations across $100$ gridworld environments, as a function of context length. Solid lines: same-node comparisons. Dashed lines: different-node comparisons. Shading shows 95% C.I. Line color denotes model layer. C. PCA visualization of representations pooled from 15 randomly selected environments. D. As in (B), but for $50$ tree mazes. E. As in (C), but for tree mazes. F. Summary of (D) at context length 1600, averaged across layers. X-axis denotes the maze layer of the comparison node.
  • Figure 5: Memory retrieval at decision time shows limited expansion from the query state and the goal state. A. Example tree maze environment, context length 800. Edge color indicates gradient attribution strength for each transition. B. Average attribution strength of each context memory vs. distance from query state and goal state, across 50 environments.
  • ...and 18 more figures
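The cross-environment comparison in Figure 4B contrasts representations of the same latent node across environments with representations of different nodes. A minimal sketch of that statistic follows, assuming each environment's representations are stored as a dict from a shared latent node id to a vector, and that only pairs drawn from two different environments are counted. The function names and data layout are hypothetical.

```python
import math
from itertools import combinations


def pearson(x, y):
    """Pearson correlation coefficient between two equal-length, non-constant vectors."""
    n = len(x)
    mx, my = sum(x) / n, sum(y) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(x, y))
    vx = sum((a - mx) ** 2 for a in x)
    vy = sum((b - my) ** 2 for b in y)
    return cov / math.sqrt(vx * vy)


def cross_env_alignment(reps_by_env):
    """Mean correlation for same-node vs. different-node pairs across environments.

    reps_by_env: list (at least two entries) of dicts mapping a latent node id,
    assumed shared across environments, to that node's representation vector.
    Returns (mean same-node correlation, mean different-node correlation).
    """
    same, diff = [], []
    for env_a, env_b in combinations(reps_by_env, 2):
        for u, x in env_a.items():
            for v, y in env_b.items():
                (same if u == v else diff).append(pearson(x, y))
    return sum(same) / len(same), sum(diff) / len(diff)
```

Under this reading, alignment corresponds to the same-node mean rising above the different-node mean as context grows, even though the raw sensory encodings of matching nodes are unrelated.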