Transformer Mechanisms Mimic Frontostriatal Gating Operations When Trained on Human Working Memory Tasks

Aaron Traylor, Jack Merullo, Michael J. Frank, Ellie Pavlick

TL;DR

The paper addresses whether vanilla Transformer models can develop brain-like gating mechanisms when trained on working-memory tasks. It trains a small, two-layer decoder-only Transformer on a textual reference-back task and uses mechanistic interpretability (path-patching) to identify emergent gating policies. The results show that input gating arises via key vectors and output gating via query vectors, with robust gating policies correlating with high task performance; the study also observes phase-transition–like improvements during training. These findings provide a concrete link between AI architectures and human memory processes, suggesting new directions for brain-inspired analyses of modern neural networks.

Abstract

Models based on the Transformer neural network architecture have seen success on a wide variety of tasks that appear to require complex "cognitive branching" (the ability to maintain pursuit of one goal while accomplishing others). In cognitive neuroscience, success on such tasks is thought to rely on sophisticated frontostriatal mechanisms for selective gating, which enable role-addressable updating, and later readout, of information to and from distinct "addresses" of memory, in the form of clusters of neurons. However, Transformer models have no such mechanisms intentionally built-in. It is thus an open question how Transformers solve such tasks, and whether the mechanisms that emerge to help them to do so bear any resemblance to the gating mechanisms in the human brain. In this work, we analyze the mechanisms that emerge within a vanilla attention-only Transformer trained on a simple sequence modeling task inspired by a task explicitly designed to study working memory gating in computational cognitive neuroscience. We find that, as a result of training, the self-attention mechanism within the Transformer specializes in a way that mirrors the input and output gating mechanisms which were explicitly incorporated into earlier, more biologically-inspired architectures. These results suggest opportunities for future research on computational similarities between modern AI architectures and models of the human brain.

Paper Structure

This paper contains 13 sections and 4 figures.

Figures (4)

  • Figure 1: Graphical diagram of the path-patching process. Attention heads are represented as circles (layer,head index), and contextual representations of each token (as well as the next token prediction) are represented as rectangles.
  • Figure 2: Above: example of textual reference-back task as model input. Below: step-by-step task process; models do not view task-internal grey words. "Update Instruction" executes after "Answer" despite appearing earlier sequentially.
  • Figure 3: Model behavior when predicting same/different (token 15). Attention is visualized in shades of purple, with a deeper shade corresponding to higher attention to that token. We create "corrupted" minimal pairs in which changing a token (light blue) either changes the correct label at index 15 (examples b, c, e) or does not (d, f). We make small path-patching edits with the minimal pair to targeted network components (layer 1 keys for b, c, d, f; queries for e, f). In other words, we replace specific components (denoted with red text) with their corresponding representation from the "corrupted" sequence while holding all other representations constant, then rerun the model to obtain a new same/different prediction. In all test examples, making the small patch successfully results in the model's prediction changing to align with the "corrupted" example.
  • Figure 4: Model performance over training on patching subtasks. Each row corresponds to an individual model's training loss (solid line) and subtask accuracy (dashed line) over time; blue and orange lines respectively correspond to models which reach 100% test accuracy and to those that do not.
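
The component-level patching described in the Figure 3 caption can be illustrated with a toy example. The sketch below is not the authors' implementation. It assumes a single causal attention head with random weights rather than the trained two-layer model, and it swaps the head's input for one component rather than tracing a full path to the output logits. The names `causal_head` and `patched_head`, the dimensions, and the choice of token 2 as the corrupted position are all hypothetical.

```python
import math
import random

def matvec(x, W):
    """Project vector x (length d_in) through W (d_in rows, d_out columns)."""
    return [sum(x[i] * W[i][j] for i in range(len(x))) for j in range(len(W[0]))]

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def causal_head(x_q, x_k, x_v, W_q, W_k, W_v):
    """Single causal attention head; each x_* is a list of token vectors.

    Queries, keys and values take separate inputs so that any one of
    them can be sourced from a different (corrupted) run.
    """
    q = [matvec(x, W_q) for x in x_q]
    k = [matvec(x, W_k) for x in x_k]
    v = [matvec(x, W_v) for x in x_v]
    scale = math.sqrt(len(q[0]))
    outputs = []
    for t in range(len(q)):
        # Position t may only attend to positions 0..t.
        scores = [sum(a * b for a, b in zip(q[t], k[s])) / scale
                  for s in range(t + 1)]
        attn = softmax(scores)
        outputs.append([sum(w * v[s][j] for s, w in enumerate(attn))
                        for j in range(len(v[0]))])
    return outputs

def patched_head(x_clean, x_corrupt, weights, component):
    """Run the head on clean inputs, taking one component
    ('q', 'k' or 'v') from the corrupted sequence instead."""
    sources = {"q": x_clean, "k": x_clean, "v": x_clean}
    sources[component] = x_corrupt
    return causal_head(sources["q"], sources["k"], sources["v"], *weights)

# Hypothetical toy setup: random weights and inputs, not a trained model.
random.seed(0)
seq_len, d_model, d_head = 6, 8, 4
weights = tuple([[random.gauss(0, 1) for _ in range(d_head)]
                 for _ in range(d_model)] for _ in range(3))
x_clean = [[random.gauss(0, 1) for _ in range(d_model)] for _ in range(seq_len)]
# Minimal pair: the corrupted sequence differs from the clean one at token 2 only.
x_corrupt = [row[:] for row in x_clean]
x_corrupt[2] = [random.gauss(0, 1) for _ in range(d_model)]

clean_out = causal_head(x_clean, x_clean, x_clean, *weights)
key_patched = patched_head(x_clean, x_corrupt, weights, "k")
query_patched = patched_head(x_clean, x_corrupt, weights, "q")
```

In this toy setting, patching the keys changes the output at every position that can attend to the corrupted token (positions 2 and later), while patching the queries changes only the output at the corrupted position itself. This asymmetry reflects the different roles keys and queries play in attention, which the paper's analysis associates with input gating (keys) and output gating (queries).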