
HyperDAS: Towards Automating Mechanistic Interpretability with Hypernetworks

Jiuding Sun, Jing Huang, Sidharth Baskaran, Karel D'Oosterlinck, Christopher Potts, Michael Sklar, Atticus Geiger

TL;DR

HyperDAS takes a step toward automating mechanistic interpretability. A transformer-based hypernetwork locates where a target concept is realized in a model's residual stream and identifies a low-rank subspace of features that mediates that concept. It combines three components to perform distributed interchange interventions: concept encoding, dynamic token-position selection, and a Householder-based subspace rotation. It is trained with a RAVEL-specific loss plus a sparsity penalty. On Llama3-8B, HyperDAS achieves state-of-the-art disentanglement on the RAVEL benchmark. The paper also discusses fidelity, potential pitfalls, and computational trade-offs relative to prior approaches such as MDAS. The work moves toward scalable, end-to-end automated interpretability for large language models. Its analysis of the learned Householder vectors indicates that HyperDAS learns a distinct subspace for each attribute.
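A minimal NumPy sketch of how a Householder-based subspace can be built is shown below. It rests on an assumption that this summary does not confirm: the hypernetwork emits one Householder vector v per example, and that vector reflects a fixed, learned orthonormal basis.

Because the reflection H = I - 2 v v^T / (v^T v) is orthogonal, the reflected basis keeps orthonormal columns. It can therefore serve directly as the concept subspace. The names `householder`, `conditioned_basis` and `base_basis` are hypothetical and are not taken from the HyperDAS code.

```python
import numpy as np

def householder(v: np.ndarray) -> np.ndarray:
    """Reflection H = I - 2 v v^T / (v^T v); H is symmetric and orthogonal."""
    v = np.asarray(v, dtype=float)
    return np.eye(v.size) - 2.0 * np.outer(v, v) / np.dot(v, v)

def conditioned_basis(base_basis: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Reflect a fixed (d x k) orthonormal basis with a per-example vector v.

    Since H^T H = I, (H B)^T (H B) = B^T B = I_k, so the columns stay orthonormal.
    """
    return householder(v) @ base_basis

rng = np.random.default_rng(0)
d, k = 16, 4
# Hypothetical fixed basis: orthonormal columns from a reduced QR decomposition.
base_basis, _ = np.linalg.qr(rng.normal(size=(d, k)))
# Hypothetical hypernetwork output for one (concept, base, counterfactual) input.
v = rng.normal(size=d)
R = conditioned_basis(base_basis, v)
print(np.allclose(R.T @ R, np.eye(k)))  # True
```

With a reflection, the hypernetwork only has to output d numbers per example rather than a full d x k matrix. Orthonormality then holds by construction, with no extra re-orthogonalization step.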

Abstract

Mechanistic interpretability has made great strides in identifying neural network features (e.g., directions in hidden activation space) that mediate concepts (e.g., the birth year of a person) and enable predictable manipulation. Distributed alignment search (DAS) leverages supervision from counterfactual data to learn concept features within hidden states, but DAS assumes we can afford to conduct a brute-force search over potential feature locations. To address this, we present HyperDAS, a transformer-based hypernetwork architecture that (1) automatically locates the token-positions of the residual stream in which a concept is realized and (2) constructs features of those residual stream vectors for the concept. In experiments with Llama3-8B, HyperDAS achieves state-of-the-art performance on the RAVEL benchmark for disentangling concepts in hidden states. In addition, we review the design decisions we made to mitigate the concern that HyperDAS (like all powerful interpretability methods) might inject new information into the target model rather than faithfully interpreting it.
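For reference, the distributed interchange intervention that DAS builds on can be written as follows. The notation is the one commonly used for DAS and may not match the paper's symbols.

- h_b is the base hidden vector at the intervened position.
- h_s is the counterfactual (source) hidden vector.
- R has orthonormal columns that span the concept subspace.

The intervention swaps the subspace coordinates of h_b for those of h_s and leaves the orthogonal complement untouched:

```latex
\begin{align}
\mathrm{DII}(\mathbf{h}_b, \mathbf{h}_s) &= \mathbf{h}_b + R R^\top (\mathbf{h}_s - \mathbf{h}_b),
  \qquad R \in \mathbb{R}^{d \times k},\ R^\top R = I_k \\
R^\top \mathrm{DII}(\mathbf{h}_b, \mathbf{h}_s) &= R^\top \mathbf{h}_b + R^\top \mathbf{h}_s - R^\top \mathbf{h}_b = R^\top \mathbf{h}_s \\
(I - R R^\top)\, \mathrm{DII}(\mathbf{h}_b, \mathbf{h}_s) &= (I - R R^\top)\, \mathbf{h}_b
  \qquad \text{since } (I - R R^\top) R = 0
\end{align}
```

DAS learns R for one fixed layer and token position at a time. Covering many candidate locations therefore requires the brute-force search mentioned above. HyperDAS instead predicts both the positions and the subspace from the concept description.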

Paper Structure

This paper contains 46 sections, 16 equations, 13 figures, 1 table.

Figures (13)

  • Figure 1: The HyperDAS framework, used here to find the features that mediate the concept "country". (1) Concept Encoding: a natural-language description of the concept to localize, "The country of a city", is encoded by a transformer hypernetwork. The hypernetwork has two additional cross-attention blocks that attend to the hidden states of the target LM, prompted with the base text "Vienna is in" and the counterfactual text "I love Paris". (2) Selecting Token-Positions: using the encoding from step 1 as a query, HyperDAS selects the tokens "nna" and "Paris" as the locations of the concept "country" in the base and the counterfactual, respectively. (3) Identifying a Subspace: from the same encoding, HyperDAS constructs a matrix whose orthogonal columns are the features for "country". (4) Interchange Intervention: using the token-positions from step 2 and the subspace from step 3, HyperDAS performs an intervention. It patches the subspace of the hidden vector for the token "nna" with the value that subspace takes in the hidden vector for the token "Paris". This leads the model to predict "France" from the base prompt "Vienna is in".
  • Figure 2: The "intervention score" matrix $G$ for the counterfactual prompt "The city of Macheng's official language is" and the base prompt "Springfield is a city in the country of", for which the target model will output "China". The attribute targeted for intervention is country, so the output should be "The United States". The raw intervention scores (left) are produced by the token-position selection described in the section on token-position selection. At inference time, a column-wise and row-wise argmax is applied to enforce a one-to-one correspondence between base and counterfactual tokens, as detailed in the evaluation section.
  • Figure 3: RAVEL benchmark results. HyperDAS establishes a new state-of-the-art.
  • Figure 4: The intervention locations in the counterfactual and base sentences picked by HyperDAS when targeting shallow (7), middle (15), and deep (29) decoder layers.
  • Figure 5: Relative positions of the Householder vectors (after PCA) for each attribute, over all correct predictions in the city domain. The clustering indicates that HyperDAS learns a different subspace for each attribute.
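The sketch below covers two steps from the figure captions: the inference-time token alignment described for Figure 2 and the subspace patch in step 4 of Figure 1. It is a minimal NumPy sketch under two assumptions:

- "Column-wise and row-wise argmax" is read as keeping only the mutual maxima of G. This is one plausible way to obtain a one-to-one correspondence.
- The patch is the standard DAS update h_b + R R^T (h_s - h_b).

The functions `mutual_argmax_pairs` and `patch_subspace` and the toy arrays are hypothetical and are not taken from the HyperDAS code.

```python
import numpy as np

def mutual_argmax_pairs(G: np.ndarray) -> list[tuple[int, int]]:
    """Return (base_pos, cf_pos) pairs that are the argmax of both their row and column.

    G has shape (num_base_tokens, num_cf_tokens). Keeping only mutual maxima
    ensures each base token is paired with at most one counterfactual token,
    and vice versa.
    """
    row_best = G.argmax(axis=1)  # best counterfactual token for each base token
    col_best = G.argmax(axis=0)  # best base token for each counterfactual token
    return [(i, int(j)) for i, j in enumerate(row_best) if col_best[j] == i]

def patch_subspace(h_base: np.ndarray, h_cf: np.ndarray, R: np.ndarray,
                   pairs: list[tuple[int, int]]) -> np.ndarray:
    """Distributed interchange intervention at the aligned positions.

    h_base: (T_b, d), h_cf: (T_c, d), R: (d, k) with orthonormal columns.
    Only the component inside span(R) is copied from the counterfactual.
    """
    out = h_base.copy()
    P = R @ R.T  # orthogonal projector onto the concept subspace
    for i, j in pairs:
        out[i] = h_base[i] + P @ (h_cf[j] - h_base[i])
    return out

# Toy score matrix: 3 base tokens (rows) x 3 counterfactual tokens (columns).
G = np.array([[0.1, 0.9, 0.0],
              [0.8, 0.2, 0.1],
              [0.3, 0.7, 0.05]])
pairs = mutual_argmax_pairs(G)
print(pairs)  # [(0, 1), (1, 0)]

rng = np.random.default_rng(0)
h_base = rng.normal(size=(3, 8))
h_cf = rng.normal(size=(3, 8))
R, _ = np.linalg.qr(rng.normal(size=(8, 2)))  # hypothetical 2-dim concept subspace
out = patch_subspace(h_base, h_cf, R, pairs)
```

Under this reading, some base tokens may be left unmatched (base token 2 in the toy example). An optimal assignment, such as `scipy.optimize.linear_sum_assignment` with `maximize=True`, would instead pair as many tokens as possible, including low-scoring ones.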