Causal Head Gating: A Framework for Interpreting Roles of Attention Heads in Transformers

Andrew Nam, Henry Conklin, Yukang Yang, Thomas Griffiths, Jonathan Cohen, Sarah-Jane Leslie

TL;DR

Causal Head Gating (CHG) is a scalable, data-driven method for interpreting attention heads: it learns soft gates over heads and assigns each head a causal role (facilitating, interfering, or irrelevant) based on its impact on next-token prediction. CHG fits a gating matrix $G \in [0,1]^{L \times H}$ (one gate per head in each of $L$ layers with $H$ heads) under regularization that pushes the gates in opposite directions, producing two masks, $G^+$ and $G^-$, that reveal each head's causal contribution. The method is validated against ablations and causal mediation analysis, and is extended with contrastive CHG (CCHG) to isolate sub-circuits for sub-tasks. Across LLMs (Llama 3 variants) and diverse tasks (math, syntax, commonsense), CHG reveals sparse, distributed, task-sufficient sub-circuits with low modularity, and CCHG shows that instruction following and in-context learning rely on separable, context-dependent head mechanisms. The method is lightweight, requires no task labels or prompt templates, and supports bootstrapped exploration of head configurations, making it a practical first-pass diagnostic that can guide deeper mechanistic investigations of distributed transformer computation.
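
Each entry of $G$ can be read as a multiplier on one head's output. The following is a minimal single-layer sketch in plain Python, assuming the gate scales each head's causal-attention output before the output projection $W_O$; that placement is an assumption loosely following the schematic in Figure 1(a), and the name `gated_mha` and the nested-list tensor layout are hypothetical.

```python
import math
from typing import List

Vec = List[float]
Mat = List[List[float]]


def softmax(xs: Vec) -> Vec:
    m = max(xs)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def matvec(w: Mat, x: Vec) -> Vec:
    # w has shape [out][in]
    return [sum(wi * xi for wi, xi in zip(row, x)) for row in w]


def gated_mha(q: List[Mat], k: List[Mat], v: List[Mat],
              gates: Vec, w_o: Mat) -> Mat:
    """Causal multi-head attention with one soft gate per head (hypothetical sketch).

    q, k, v: [H][T][d_head]; gates: [H] in [0, 1]; w_o: [d_model][H * d_head].
    Returns [T][d_model].
    """
    n_heads, seq_len, d_head = len(q), len(q[0]), len(q[0][0])
    out = []
    for t in range(seq_len):
        concat: Vec = []
        for h in range(n_heads):
            # scaled dot-product scores over positions s <= t (causal mask)
            scores = [sum(a * b for a, b in zip(q[h][t], k[h][s])) / math.sqrt(d_head)
                      for s in range(t + 1)]
            weights = softmax(scores)
            head_out = [sum(weights[s] * v[h][s][i] for s in range(t + 1))
                        for i in range(d_head)]
            # gating attenuation: scale this head's output by its gate
            concat.extend(gates[h] * x for x in head_out)
        out.append(matvec(w_o, concat))
    return out
```

Because the output projection is linear, a gate of 0 removes a head's contribution entirely, a gate of 1 leaves the block unchanged, and intermediate values interpolate between ablating and keeping the head.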

Abstract

We present causal head gating (CHG), a scalable method for interpreting the functional roles of attention heads in transformer models. CHG learns soft gates over heads and assigns each head a role from a causal taxonomy (facilitating, interfering, or irrelevant) based on its impact on task performance. Unlike prior approaches in mechanistic interpretability, which are hypothesis-driven and require prompt templates or target labels, CHG applies directly to any dataset using standard next-token prediction. We evaluate CHG across multiple large language models (LLMs) in the Llama 3 model family and diverse tasks, including syntax, commonsense, and mathematical reasoning, and show that CHG scores yield causal, not merely correlational, insight validated via ablation and causal mediation analyses. We also introduce contrastive CHG, a variant that isolates sub-circuits for specific task components. Our findings reveal that LLMs contain multiple sparse task-sufficient sub-circuits, that individual head roles depend on interactions with others (low modularity), and that instruction following and in-context learning rely on separable mechanisms.
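
The way opposite-signed regularization separates the three roles (the bifurcation shown in Figure 1(b)) can be reproduced on a toy problem. The sketch below replaces the language model's next-token loss with a hypothetical linear objective in which each head has a fixed effect on target log-probability, parameterizes gates as $g = \sigma(\theta)$, and adds $\lambda$ times the mean gate value to the loss, so $\lambda < 0$ rewards open gates (giving $G^+$) and $\lambda > 0$ rewards closed ones (giving $G^-$). The regularizer form, the sigmoid parameterization, the plain gradient-descent loop, and the irrelevance formula are assumptions for illustration, not the paper's exact setup.

```python
import math
from typing import List, Tuple


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


def fit_gates(head_effects: List[float], lam: float,
              steps: int = 2000, lr: float = 1.0) -> List[float]:
    """Fit soft gates g_h = sigmoid(theta_h) by gradient descent on a toy loss.

    HYPOTHETICAL stand-in for the next-token loss:
        loss(g) = -sum_h effect_h * g_h + lam * mean_h(g_h)
    effect_h > 0: opening the gate raises target log-prob (facilitating);
    effect_h < 0: it lowers it (interfering); 0: no effect (irrelevant).
    lam < 0 rewards open gates; lam > 0 rewards closed gates.
    """
    n = len(head_effects)
    theta = [0.0] * n  # every gate starts at 0.5
    for _ in range(steps):
        for h in range(n):
            g = sigmoid(theta[h])
            dloss_dg = -head_effects[h] + lam / n
            theta[h] -= lr * dloss_dg * g * (1.0 - g)  # chain rule through the sigmoid
    return [sigmoid(t) for t in theta]


def chg_scores(g_plus: List[float],
               g_minus: List[float]) -> List[Tuple[float, float, float]]:
    """(facilitation, interference, irrelevance) for each head.

    Facilitation = G^- and interference = 1 - G^+ follow the Figure 3 caption;
    irrelevance = G^+ - G^- is an ASSUMPTION (the remaining mass, so the
    three scores sum to 1).
    """
    return [(gm, 1.0 - gp, gp - gm) for gp, gm in zip(g_plus, g_minus)]
```

With effects `[2.0, -2.0, 0.0]` (facilitating, interfering, irrelevant), `fit_gates(effects, lam=-1.0)` drives the gates toward roughly `[1, 0, 1]` and `fit_gates(effects, lam=1.0)` toward roughly `[1, 0, 0]`: only the irrelevant head depends on the sign of $\lambda$, which is what lets `chg_scores` tell it apart from the other two.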

Paper Structure

This paper contains 26 sections, 4 equations, 5 figures, and 2 tables.

Figures (5)

  • Figure 1: (a) Schematic of a single multihead attention block with CHG-determined gating attenuation (in red). (b) Gate fitting trajectories for three heads on L3.2-3BI with OpenMathInstruct2. When fitting with $\lambda < 0$ and $\lambda > 0$, $G^+$ and $G^-$ both stay near 1 for facilitating heads and near 0 for interfering heads, but bifurcate to 1 and 0 respectively for irrelevant heads. (c) Gate values after fitting.
  • Figure 2: Difference in target log-probability when sequentially setting individual gates in $G^+$ to 1 and 0 in order of facilitation, irrelevance, and interference scores. The horizontal axis shows the number of heads ablated in descending score order. Positive values indicate task improvement, negative values indicate degradation, and values near zero indicate no effect. Note that not all heads in the top 50 necessarily have high absolute scores.
  • Figure 3: CHG score distributions and consistency. (a) Empirical cumulative distribution of CHG scores across all attention heads, showing the proportion of heads with scores below a given threshold for facilitation, irrelevance, and interference. (b) Aggregated CHG scores on L3.2-3BI, where the red and green color channels represent interference ($1 - G^+$) and facilitation ($G^-$), respectively. Colors combine additively in RGB: black indicates irrelevance (low in both), and yellow indicates both facilitation and interference (high in both). "Always" aggregates by taking the minimum across seeds (highlighting consistent effects); "Any" takes the maximum (highlighting an effect in any seed).
  • Figure 4: Task-facilitation scores versus (a) average indirect effect for function vector tasks and (b) CMA scores for symbolic reasoning tasks, showing significant heads by type (abstraction, induction, retrieval) and using the maximum CMA score across types for insignificant heads.
  • Figure 5: Task accuracy under CCHG. Columns indicate held-out evaluation tasks and rows indicate the retained prompt format. Bar color shows the evaluation prompt format. "Default" and "gated" indicate whether CCHG is applied during evaluation. Error bars indicate 95% CI.
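
The cumulative curves in Figure 2 can be sketched as follows, assuming a callable that returns the target log-probability for a given gate vector. Heads are sorted by descending score and modified one at a time, with earlier changes kept in place. Which target value (0 or 1) goes with which score ordering is left as a parameter because the caption does not spell out the mapping; the function name `sequential_gate_curve` and the toy linear `logprob` in the usage note are hypothetical.

```python
from typing import Callable, List, Sequence


def sequential_gate_curve(
    logprob: Callable[[List[float]], float],
    base_gates: Sequence[float],
    scores: Sequence[float],
    set_to: float,
    k_max: int,
) -> List[float]:
    """Change in target log-prob as the top-k heads (by score) are set to `set_to`.

    Entry k-1 of the result is logprob(after modifying k heads) - logprob(base).
    Positive = task improvement, negative = degradation (as in Figure 2).
    """
    # descending score order; ties keep their original order (stable sort)
    order = sorted(range(len(scores)), key=lambda h: scores[h], reverse=True)
    baseline = logprob(list(base_gates))
    gates = list(base_gates)
    curve = []
    for h in order[:k_max]:
        gates[h] = set_to  # cumulative: earlier heads stay modified
        curve.append(logprob(gates) - baseline)
    return curve
```

For example, with `logprob = lambda g: sum(w * x for w, x in zip([2.0, 1.0, 0.0, -1.5], g))`, base gates `[1.0, 1.0, 1.0, 0.0]`, and facilitation scores `[0.9, 0.8, 0.0, 0.0]`, setting the top two heads to 0 returns `[-2.0, -3.0]`, a growing degradation.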
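
The seed aggregation and color mapping described for Figure 3 are short to express. This sketch follows the caption: "Always" takes the per-head minimum across seeds, "Any" takes the maximum, red encodes interference ($1 - G^+$), green encodes facilitation ($G^-$), and panel (a) is an empirical cumulative distribution. The function names, the zero blue channel, and the use of "at or below" for the threshold are assumptions.

```python
from typing import List, Tuple


def aggregate(per_seed: List[List[float]], mode: str) -> List[float]:
    """Combine per-seed scores for each head: 'always' = min, 'any' = max.

    per_seed: [n_seeds][n_heads] scores from independent CHG fits.
    """
    if mode not in ("always", "any"):
        raise ValueError(f"unknown mode: {mode}")
    reduce = min if mode == "always" else max
    return [reduce(vals) for vals in zip(*per_seed)]


def to_rgb(facilitation: float, interference: float) -> Tuple[float, float, float]:
    """Red = interference, green = facilitation, blue unused.

    (0, 0) -> black (irrelevant); (1, 1) -> yellow (both effects).
    """
    return (interference, facilitation, 0.0)


def ecdf(scores: List[float], threshold: float) -> float:
    """Fraction of heads whose score is at or below `threshold`."""
    return sum(s <= threshold for s in scores) / len(scores)
```

Taking the minimum means a head is colored only if every seed agrees on its effect, while the maximum surfaces any head that any seed flagged, which is the consistent-versus-any contrast the caption describes.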