Identifying Sparsely Active Circuits Through Local Loss Landscape Decomposition

Brianna Chrisman, Lucius Bushnaq, Lee Sharkey

TL;DR

The paper approaches the interpretation of large neural networks by shifting focus from activation space to parameter space, introducing Local Loss Landscape Decomposition (L3D). L3D learns low-rank subnetworks that reconstruct the gradient of the divergence between a sample's output and a reference output, which supports circuit-level interpretability and targeted interventions. Using progressively more complex toy models and preliminary real-world experiments on a transformer and a CNN, the authors show that the method can recover interpretable subnetworks and examine how far each subnetwork's effects are localized to the relevant samples. The result is a framework for understanding, and potentially steering, model behavior by manipulating compact, interpretable parameter directions.

Abstract

Much of mechanistic interpretability has focused on understanding the activation spaces of large neural networks. However, activation space-based approaches reveal little about the underlying circuitry used to compute features. To better understand the circuits employed by models, we introduce a new decomposition method called Local Loss Landscape Decomposition (L3D). L3D identifies a set of low-rank subnetworks: directions in parameter space of which a subset can reconstruct the gradient of the loss between any sample's output and a reference output vector. We design a series of progressively more challenging toy models with well-defined subnetworks and show that L3D can nearly perfectly recover the associated subnetworks. Additionally, we investigate the extent to which perturbing the model in the direction of a given subnetwork affects only the relevant subset of samples. Finally, we apply L3D to a real-world transformer model and a convolutional neural network, demonstrating its potential to identify interpretable and relevant circuits in parameter space.
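
As a reading aid, the quantity being decomposed can be written in the notation of Figure 1 ($D$ a loss or divergence, $f$ the model, $W$ its parameters, $x_i$ a sample input, $y_r$ a reference output). The symbols $v_k$ (subnetwork directions), $K$ (their number), $S_i$ (the subset active on sample $i$), and $a_{ik}$ (scalar coefficients) are assumed here for illustration; the paper's own formulation may differ, and the low-rank structure of each $v_k$ is not shown.

$$
\nabla_W D\big(f(x_i; W),\, y_r\big) \;\approx\; \sum_{k \in S_i} a_{ik}\, v_k, \qquad S_i \subseteq \{1, \dots, K\}, \quad |S_i| \ll K
$$

Read this way, "sparsely active" means that each sample's gradient is reconstructed from only a few of the $K$ directions.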

Paper Structure

This paper contains 49 sections, 11 equations, 18 figures, 1 table, 1 algorithm.

Figures (18)

  • Figure 1: Decomposing a loss landscape into a set of parameter directions, or subnetworks, where a smaller subset of directions can approximately reconstruct the gradient of divergence/loss between any sample's output and a reference output. Here, $D$ is a loss, or divergence measure, $f$ is our model, $W$ is the set of parameters in the model, $x_i$ is a sample input, and $y_r$ is a reference output.
  • Figure 2: L3D subnetwork decomposition of TMS. Each subnetwork corresponds to the encoding/decoding of a single input feature.
  • Figure 3: The encoder/decoder directions of the original model (solid lines) and each subnetwork (dashed lines). The directions learned by each subnetwork are nearly perfectly parallel to the encoding for each input feature. The colors of the lines refer to the input index each embedding represents.
  • Figure 4: The effect of intervening on the TMS model in the direction of each subnetwork. We generated 1000 inputs from the TMS input distribution (x-axis), intervened on each subnetwork $v_i$ with magnitude $\delta$ and measured the change in outputs (y-axis) for each sample. The outputs corresponding to the index relevant to each subnetwork experienced the most change.
  • Figure 5: The effect of intervening at various values of $\delta$ in the direction of each subnetwork. The y-axis represents the average amount by which an output changed (data points colored by output index) when the model was perturbed by an amount $\delta$ in the direction of a subnetwork.
  • ...and 13 more figures
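
The intervention measurement described in the captions for Figures 4 and 5 can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' code: the tied-weight toy model, the helper names `model` and `intervention_effect`, and the hand-picked direction `v` are all hypothetical, and `v` is not a direction learned by L3D.

```python
import numpy as np

def model(W, b, x):
    """Hypothetical toy model: ReLU(x W^T W + b), with W of shape (hidden, n_features)."""
    return np.maximum(x @ W.T @ W + b, 0.0)

def intervention_effect(W, b, v, delta, x):
    """Absolute change in every output when W is moved by delta along unit-normalized v."""
    v_unit = v / np.linalg.norm(v)
    return np.abs(model(W + delta * v_unit, b, x) - model(W, b, x))

rng = np.random.default_rng(0)
n_features, hidden = 5, 2
W = rng.normal(size=(hidden, n_features))
b = np.zeros(n_features)

# Assumed stand-in for a subnetwork direction: only the column of W
# that embeds input feature 0 (chosen by hand, not learned).
v = np.zeros_like(W)
v[:, 0] = W[:, 0]

# 1000 random inputs, perturbed with magnitude delta along v.
x = rng.uniform(size=(1000, n_features))
effect = intervention_effect(W, b, v, delta=0.5, x=x)

# Average change per output index, the kind of quantity plotted against delta.
mean_change = effect.mean(axis=0)
print(mean_change)
```

Sweeping `delta` over a range and recording `mean_change` at each value would give curves of the kind Figure 5 describes. Which outputs change most depends on the direction used, so this sketch makes no claim about reproducing the paper's results.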