Interpretability Illusions in the Generalization of Simplified Models

Dan Friedman, Andrew Lampinen, Lucas Dixon, Danqi Chen, Asma Ghandeharioun

TL;DR

This work interrogates the reliability of interpretability methods that rely on simplifying Transformer representations, showing that proxies like PCA/SVD, clustering, or one-hot attention can faithfully mimic in-distribution behavior yet fail to capture out-of-distribution generalization on tasks such as Dyck-language bracket matching and code completion. By training small two-layer Transformers and applying multiple proxy strategies, the authors reveal consistent generalization gaps where simplified proxies underestimate or mispredict systematic generalization, including depth and repeated-word phenomena linked to induction-head mechanisms. The findings highlight a fundamental limitation of data-dependent simplifications for understanding model computations and stress the need for out-of-distribution evaluation in interpretability research, while drawing connections to compression, neuroscience, and the complexity-generalization literature. Overall, the work cautions against overreliance on low-rank or discrete proxies for mechanistic explanations and motivates developing more faithful interpretability methods that remain valid across distribution shifts.

Abstract

A common method to study deep learning systems is to use simplified model representations--for example, using singular value decomposition to visualize the model's hidden states in a lower dimensional space. This approach assumes that the results of these simplifications are faithful to the original model. Here, we illustrate an important caveat to this assumption: even if the simplified representations can accurately approximate the full model on the training set, they may fail to accurately capture the model's behavior out of distribution. We illustrate this by training Transformer models on controlled datasets with systematic generalization splits, including the Dyck balanced-parenthesis languages and a code completion task. We simplify these models using tools like dimensionality reduction and clustering, and then explicitly test how these simplified proxies match the behavior of the original model. We find consistent generalization gaps: cases in which the simplified proxies are more faithful to the original model on the in-distribution evaluations and less faithful on various tests of systematic generalization. This includes cases where the original model generalizes systematically but the simplified proxies fail, and cases where the simplified proxies generalize better. Together, our results raise questions about the extent to which mechanistic interpretations derived using tools like SVD can reliably predict what a model will do in novel situations.
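The caveat described in the abstract can be illustrated with a toy example. The sketch below is not the authors' code: it assumes NumPy, uses synthetic "embeddings" rather than Transformer hidden states, and the helper names (`fit_svd_basis`, `simplify`, `relative_error`) are hypothetical. It fits a rank-k SVD basis on in-distribution data only, then measures how well the same basis reconstructs data that varies along directions the fit never saw.

```python
import numpy as np

def fit_svd_basis(train_embeddings, rank):
    """Return the mean and the top-`rank` right singular vectors of the centered data."""
    mean = train_embeddings.mean(axis=0)
    _, _, vt = np.linalg.svd(train_embeddings - mean, full_matrices=False)
    return mean, vt[:rank]

def simplify(embeddings, mean, basis):
    """Project onto the low-rank subspace, then map back to the original space."""
    return (embeddings - mean) @ basis.T @ basis + mean

def relative_error(original, simplified):
    """Frobenius norm of the reconstruction error, relative to the original."""
    return float(np.linalg.norm(original - simplified) / np.linalg.norm(original))

rng = np.random.default_rng(0)
d, k = 16, 4

# Assumption: toy in-distribution embeddings vary only along the first k axes.
in_dist = np.zeros((200, d))
in_dist[:, :k] = rng.normal(size=(200, k))

# Assumption: toy out-of-distribution embeddings vary along all d axes.
ood = rng.normal(size=(200, d))

# The basis is data-dependent: it is fitted on in-distribution data only.
mean, basis = fit_svd_basis(in_dist, rank=k)

err_in = relative_error(in_dist, simplify(in_dist, mean, basis))
err_ood = relative_error(ood, simplify(ood, mean, basis))
print(f"in-distribution error: {err_in:.3g}, out-of-distribution error: {err_ood:.3g}")
```

In this toy setting the rank-k proxy reconstructs the in-distribution data almost exactly while losing most of the out-of-distribution signal. The paper's experiments measure faithfulness through attention patterns and predictions rather than reconstruction error, so this only conveys the general shape of the problem.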

Paper Structure (51 sections, 2 equations, 20 figures, 2 tables)

Figures (20)

  • Figure 1: Accuracy at predicting closing brackets over the course of training, averaged over three random seeds.
  • Figure 2: Second-layer attention embeddings for Dyck sequences, projected onto the first and third singular vectors and colored by bracket depth. The maximum depth during training is 10. At each position, the model can find the most recent unmatched bracket by finding the most recent bracket at the current nesting depth.
  • Figure 3: Approximation quality after applying two simplifications to the key and query embeddings, clustering (left) and SVD (right). JSD is the average Jensen-Shannon Divergence between the attention patterns of the original and simplified models, and Same Prediction measures whether the two models make the same prediction at the final layer. Plots show the mean and 95% CI after applying the simplification to models trained with three random seeds. For both methods, the approximation quality is better on the in-distribution evaluation set (Seen struct) and worse on examples with unseen structures or nesting depths.
  • Figure 4: Approximation quality and accuracy after replacing the second-layer attention pattern with a one-hot attention pattern that assigns all attention to the highest-scoring key, averaged over three models trained with different random seeds. One-hot attention is a faithful approximation on all generalization splits except for the depth generalization split (Fig. \ref{fig:one_hot_attention_approximation_scores}). This difference illustrates that a simplification which is faithful in some out-of-distribution evaluations may fail in others. In depth generalization, the one-hot attention model slightly outperforms the original model (Fig. \ref{fig:one_hot_attention_errors}), particularly at higher depths and in cases where the original model attends to the correct location, thereby overestimating how well the original model will generalize.
  • Figure 5: Errors of the original model and a rank-8 SVD simplification on the depth generalization test set. Figures \ref{fig:depth_generalization_prediction_errors} and \ref{fig:depth_generalization_prediction_errors_pca8} plot the prediction accuracy, broken down by the depth of the query and the maximum depth among the keys. Figures \ref{fig:depth_generalization_attention_errors} and \ref{fig:depth_generalization_attention_errors_pca8} plot the depth of the token with the highest attention score, broken down by the true target depth, considering only incorrect predictions. The models have similar error patterns on shallower depths, attending to depths two higher or two lower than the target depth, but the simplified model diverges on depths greater than ten.
  • ...and 15 more figures
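
The faithfulness measures named in the captions for Figures 3 and 4 can be sketched in a few lines. This is a minimal illustration in plain Python, not the authors' implementation: the function names are hypothetical, the natural-log base for the Jensen-Shannon Divergence is an assumption, and the attention scores and predictions in the example are made up.

```python
import math

def softmax(scores):
    """Turn raw attention scores into a probability distribution."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(p, q):
    """KL(p || q) in nats; terms with p_i == 0 contribute nothing."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)

def jensen_shannon_divergence(p, q):
    """JSD between two attention patterns (assumed natural log, so at most ln 2)."""
    mixture = [(pi + qi) / 2 for pi, qi in zip(p, q)]
    return 0.5 * kl_divergence(p, mixture) + 0.5 * kl_divergence(q, mixture)

def one_hot_attention(scores):
    """Assign all attention to the highest-scoring key (first one wins on ties)."""
    best = max(range(len(scores)), key=scores.__getitem__)
    return [1.0 if i == best else 0.0 for i in range(len(scores))]

def same_prediction_rate(original_preds, simplified_preds):
    """Fraction of positions where both models predict the same token."""
    matches = sum(a == b for a, b in zip(original_preds, simplified_preds))
    return matches / len(original_preds)

# Hypothetical attention scores for one query over three keys.
scores = [2.0, 0.5, -1.0]
original = softmax(scores)
simplified = one_hot_attention(scores)
jsd = jensen_shannon_divergence(original, simplified)

# Hypothetical final-layer predictions from the original and simplified models.
rate = same_prediction_rate([")", "]", ")"], [")", "]", "]"])
print(f"JSD = {jsd:.4f}, same-prediction rate = {rate:.2f}")
```

Computing both quantities separately for each evaluation split (for example, seen structures versus unseen depths) is what exposes a generalization gap: a proxy can score well on one split and poorly on another.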