Partially Rewriting a Transformer in Natural Language

Gonçalo Paulo, Nora Belrose

TL;DR

The paper tackles mechanistic interpretability by attempting to partially rewrite a transformer using natural language explanations to define interpretable latent features. It trains a sparse transcoder to approximate a layer's MLP and uses an LLM-based simulator, guided by explanations, to predict neuron activations, with quantile normalization calibrating the predictions. Evaluation shows that the loss increase from these substitutions is close to zero-vector ablation, indicating that current explanations are not precise enough to preserve performance. The work highlights the need for more detailed, contrastive, and calibrated explanations to advance faithful latent-level rewrites in large language models.

Abstract

The greatest ambition of mechanistic interpretability is to completely rewrite deep neural networks in a format that is more amenable to human understanding, while preserving their behavior and performance. In this paper, we attempt to partially rewrite a large language model using simple natural language explanations. We first approximate one of the feedforward networks in the LLM with a wider MLP with sparsely activating neurons - a transcoder - and use an automated interpretability pipeline to generate explanations for these neurons. We then replace the first layer of this sparse MLP with an LLM-based simulator, which predicts the activation of each neuron given its explanation and the surrounding context. Finally, we measure the degree to which these modifications distort the model's final output. With our pipeline, the model's increase in loss is statistically similar to entirely replacing the sparse MLP output with the zero vector. We employ the same protocol, this time using a sparse autoencoder, on the residual stream of the same layer and obtain similar results. These results suggest that more detailed explanations are needed to improve performance substantially above the zero ablation baseline.
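
The substitution step described above can be sketched in a few lines. This is a minimal illustration under stated assumptions, not the authors' implementation: the function name `partial_rewrite_output`, the plain-list representation, and the presence of a decoder bias are all hypothetical. The idea it shows is that the transcoder's second layer (the decoder) is kept, while the activations of a chosen subset of latents come from the explanation-driven simulator instead of the trained encoder.

```python
from typing import Sequence, Set, List


def partial_rewrite_output(
    true_acts: Sequence[float],        # activations from the trained encoder
    simulated_acts: Sequence[float],   # activations predicted from explanations
    replaced: Set[int],                # indices of latents handed to the simulator
    decoder: Sequence[Sequence[float]],  # one decoder direction per latent
    bias: Sequence[float],             # decoder bias (assumed to exist)
) -> List[float]:
    """Reconstruct the MLP output, using simulated activations for `replaced` latents."""
    out = [float(b) for b in bias]
    for i, direction in enumerate(decoder):
        act = simulated_acts[i] if i in replaced else true_acts[i]
        if act == 0.0:
            continue  # inactive latents contribute nothing to the reconstruction
        for d, weight in enumerate(direction):
            out[d] += act * weight
    return out


# Toy example: 3 latents writing into a 2-dimensional output.
decoder = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
bias = [0.0, 0.0]
true_acts = [2.0, 0.0, 1.0]
simulated_acts = [0.0, 3.0, 1.0]

original = partial_rewrite_output(true_acts, simulated_acts, set(), decoder, bias)
rewritten = partial_rewrite_output(true_acts, simulated_acts, {0, 1}, decoder, bias)
zero_ablated = partial_rewrite_output([0.0] * 3, simulated_acts, set(), decoder, bias)
```

Comparing the downstream loss obtained with `rewritten` against the loss obtained with `zero_ablated` corresponds to the comparison the abstract reports, with zero ablation serving as the baseline.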

Paper Structure

This paper contains 12 sections, 1 equation, 9 figures.

Figures (9)

  • Figure 1: Distribution of predicted activations for all latents. On the left we compare the distribution of predicted activations before normalization, and on the right we show what the distribution looks like after quantile normalization. Before normalization, the predictor model systematically over-predicts high activation values by multiple orders of magnitude. Quantile normalization primarily has the effect of enforcing a prior in favor of features not being active.
  • Figure 2: Partially rewriting an LLM. After training a transcoder, or any type of SAE, we generate explanations for all the latents using the contexts where each latent is active. An LLM is tasked with summarizing or otherwise finding patterns in the activations and outputting a simple, single-sentence explanation for that latent. These explanations are used by another instance of an LLM to predict whether the latent should be active on a given token. After some post-processing of those predictions, a reconstruction vector is calculated using the decoder directions of the latents that are considered to be active for that token.
  • Figure 3: Cross entropy loss increase for different fractions of transcoder and SAE substitution. We compute the CE loss over 10K prompts, for the transcoder (left) and SAE (right) respectively, by substituting parts of the encoder with natural language explanations. Bars in green show the average loss increase when choosing the top scoring latents for replacement. Bars in orange show the average loss increase when randomly selecting a subset of latents to replace. Bars in blue show the average loss increase caused by zeroing out a part of the transcoder. Bar heights represent the median value of the absolute difference, because the distribution is heavy-tailed, and error bars are $95\%$ confidence intervals computed using bootstrapping. The interpretability score used for selecting latents is detection scoring (Paulo et al., 2024), computed over 100 positive and 100 negative samples. Over this set of prompts, Pythia had a cross entropy loss of $3.19 \pm 0.09$ nats per token.
  • Figure 4: Detection score predicts sensitivity and specificity. Binning explanations by their scores makes it evident that high-scoring explanations are more specific and sensitive.
  • Figure A1: Distribution of predicted activations for all latents over a smaller sample size. If we use only 1K prompts for the predicted activations and the 10M prompts for the target distribution, the mismatch with the empirical activation distribution is higher.
  • ...and 4 more figures
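
The quantile normalization mentioned in the Figure 1 caption can be illustrated with a rank-based mapping. This is a sketch of the generic technique, not necessarily the exact procedure used in the paper: the function name `quantile_normalize`, the mid-rank quantile convention, and the toy numbers are assumptions made for illustration. Each predicted activation is replaced by the value sitting at the same quantile of a reference (empirical) activation distribution, so if most reference activations are zero, most predictions are mapped to zero as well.

```python
from typing import List, Sequence


def quantile_normalize(predicted: Sequence[float], reference: Sequence[float]) -> List[float]:
    """Map each predicted value to the reference value at the same quantile."""
    ref_sorted = sorted(reference)
    n, m = len(predicted), len(ref_sorted)
    # Indices of the predictions, from smallest to largest value.
    # sorted() is stable, so ties are broken by original position.
    order = sorted(range(n), key=lambda i: predicted[i])
    out = [0.0] * n
    for rank, idx in enumerate(order):
        q = (rank + 0.5) / n  # mid-rank quantile, strictly between 0 and 1
        out[idx] = ref_sorted[min(int(q * m), m - 1)]
    return out


# Toy example: the simulator over-predicts, while the reference
# distribution says the latent is inactive three times out of four.
predicted = [9.0, 0.1, 5.0, 7.0]
reference = [0.0, 0.0, 0.0, 2.0]
normalized = quantile_normalize(predicted, reference)
```

In this toy case only the largest prediction keeps a nonzero value after normalization, which mirrors the caption's point that the procedure mainly enforces a prior in favor of features not being active.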