Weight-sparse transformers have interpretable circuits

Leo Gao, Achyuta Rajaram, Jacob Coxon, Soham V. Govande, Bowen Baker, Dan Mossing

TL;DR

This work introduces weight-sparse transformers to achieve highly interpretable internal circuits, enabling the isolation of compact, task-specific subcircuits that are faithful to model behavior. By enforcing strong $L_0$ weight sparsity (and activation sparsity) and applying a novel pruning method, the authors extract minimal circuits for simple Python-code tasks and validate their faithfulness via mean ablations and targeted perturbations. They demonstrate that sparse models yield ~$16$-fold smaller circuits than dense counterparts at comparable pretraining loss, and that scaling the total parameter count improves the capability-interpretability frontier. Additionally, they extend the approach with bridges to relate sparse circuits to dense models, enabling interpretation without retraining frontier models. The work highlights both the promise and the practical challenges of mechanistic interpretability via sparse circuits, including computational inefficiency and incomplete faithfulness, while suggesting directions toward scalable, automated interpretability frameworks.
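The mean-ablation check mentioned above can be illustrated with a minimal sketch. It assumes that each node's activation is a single scalar and that a node outside the circuit is frozen to its mean over a reference dataset. The function names (`mean_activations`, `mean_ablate`) and the toy numbers are hypothetical and are not taken from the paper.

```python
def mean_activations(dataset_acts):
    """Per-node mean over a list of activation vectors (one vector per example)."""
    n = len(dataset_acts)
    return [sum(col) / n for col in zip(*dataset_acts)]


def mean_ablate(acts, keep_mask, means):
    """Keep a node's activation if it is in the circuit; otherwise use its mean."""
    return [a if keep else m for a, keep, m in zip(acts, keep_mask, means)]


# Toy reference activations for three nodes over two examples (made-up values).
reference_acts = [[1.0, 0.0, 2.0], [3.0, 0.0, 4.0]]
means = mean_activations(reference_acts)

# Only node 0 is kept in the circuit; nodes 1 and 2 are frozen to their means.
ablated = mean_ablate([5.0, 7.0, 9.0], [True, False, False], means)
print(means, ablated)
```

If task loss stays low when every node outside the candidate circuit is replaced this way, the kept nodes are plausibly sufficient for the behavior under this toy framing.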

Abstract

Finding human-understandable circuits in language models is a central goal of the field of mechanistic interpretability. We train models to have more understandable circuits by constraining most of their weights to be zeros, so that each neuron only has a few connections. To recover fine-grained circuits underlying each of several hand-crafted tasks, we prune the models to isolate the part responsible for the task. These circuits often contain neurons and residual channels that correspond to natural concepts, with a small number of straightforwardly interpretable connections between them. We study how these models scale and find that making weights sparser trades off capability for interpretability, and scaling model size improves the capability-interpretability frontier. However, scaling sparse models beyond tens of millions of nonzero parameters while preserving interpretability remains a challenge. In addition to training weight-sparse models de novo, we show preliminary results suggesting our method can also be adapted to explain existing dense models. Our work produces circuits that achieve an unprecedented level of human understandability and validates them with considerable rigor.
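As a rough illustration of constraining most weights to zero, the sketch below keeps only the `k` largest-magnitude entries of a weight matrix and zeroes the rest. This is one common way to impose an $L_0$ budget and is an assumption here, since the abstract does not say how the constraint is enforced. The helper names and example values are hypothetical.

```python
def sparsify_top_k(weights, k):
    """Keep the k largest-magnitude entries of a 2-D weight matrix; zero the rest."""
    entries = [
        (abs(w), i, j)
        for i, row in enumerate(weights)
        for j, w in enumerate(row)
    ]
    # Sort by magnitude, largest first; ties are broken by position for determinism.
    entries.sort(key=lambda t: (-t[0], t[1], t[2]))
    keep = {(i, j) for _, i, j in entries[:k]}
    return [
        [w if (i, j) in keep else 0.0 for j, w in enumerate(row)]
        for i, row in enumerate(weights)
    ]


def l0_norm(weights):
    """Count the nonzero entries of a 2-D weight matrix."""
    return sum(1 for row in weights for w in row if w != 0.0)


W = [[0.9, -0.05, 0.1], [0.02, -1.2, 0.3]]
W_sparse = sparsify_top_k(W, k=2)
print(W_sparse, l0_norm(W_sparse))
```

In a training loop, a projection like this would typically be reapplied after each optimizer step so that the nonzero count never exceeds the budget.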

Paper Structure

This paper contains 66 sections, 4 equations, 41 figures, 2 tables.

Figures (41)

  • Figure 1: An illustration of our overall setup. We first train weight-sparse models. Then, for each of a curated suite of simple behaviors, we prune the model down to the subset of nodes required to perform the task. We ablate nodes by pruning to their mean activation value over the pretraining distribution.
  • Figure 2: Our weight-sparse models learn simpler task-specific circuits than dense models. We examine a sparse model and a dense model with the same pretraining loss. We sweep target loss, and find the size of the minimal circuit in each model that can achieve that loss, averaged across tasks. Sparse model circuits are roughly 16-fold smaller at any given loss.
  • Figure 3: Scaling the total parameter count of weight-sparse models improves the capability-interpretability Pareto frontier. Making models sparser (i.e. decreasing the $L_0$ norm of weights) while holding total parameter count fixed trades off the two, harming capability but improving interpretability. We define capability as pretraining loss; see the section on measuring interpretability for our definition of interpretability. Down and to the left is better.
  • Figure 4: The string closing circuit. We omit no detail, showing all 12 nodes and 9 edges needed to complete the task near perfectly. First, 0.mlp converts the token embeddings into "quote detector" and "quote type classifier" residual channels, which are read by key and value channels respectively in 10.attn. Subsequent tokens attend to the key and copy the value to predict the corresponding closing quote. In the diagram, the vertical bundle of lines under each input token is its residual stream. Activations of important nodes on cherry-picked task examples are shown on the left. Dashed horizontal lines mark layer boundaries. $\otimes$ denotes scalar multiplication; directly merged lines denote scalar addition. Black numbers indicate channel or neuron indices. Red and blue numbers mark positive and negative weights (or biases). This diagram only shows the relevant attention path. Inactive parts of the circuit are greyed out, and irrelevant layers are omitted.
  • Figure 5: A simplified illustration of the circuit for counting nesting depth, using the conventions from Figure 4. A single attention value channel functions as an "open bracket detector" derived from the embedding of the token [. The attention head then averages the value of this detector over the context and writes it to the residual stream at each token (the "nesting depth"). A subsequent attention head reads out the nesting depth using a query channel, and thresholds it to only activate inside nested lists. This circuit uses 7 nodes and 4 edges. Understanding this algorithm allows us to adversarially attack the model with "distractors."
  • ...and 36 more figures
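
The nesting-depth mechanism described in the Figure 5 caption can be mimicked with a toy model: a detector that fires on `[`, a running average of that detector over the causal context, and a threshold on the result. The threshold value of 0.25, the uniform averaging, and the example token sequences are assumptions made for illustration and do not come from the paper. Under these assumptions, the sketch also suggests why padding the context might act as a distractor: the average is diluted as the context grows.

```python
def nesting_signal(tokens):
    """Running mean of an open-bracket detector over the causal context."""
    signal, opens = [], 0
    for pos, tok in enumerate(tokens, start=1):
        opens += 1 if tok == "[" else 0  # detector: 1 on "[", 0 elsewhere
        signal.append(opens / pos)       # uniform average over positions so far
    return signal


def is_nested(tokens, threshold=0.25):
    """Threshold the averaged detector at the last position (hypothetical threshold)."""
    return nesting_signal(tokens)[-1] > threshold


short_ctx = ["[", "[", "a", ",", "b"]
long_ctx = ["[", "[", "a"] + [",", "a"] * 10  # same brackets, many more tokens

print(is_nested(short_ctx), is_nested(long_ctx))
```

Both sequences open two brackets, but the longer one drives the averaged signal below the threshold, so the toy readout no longer reports a nested list.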