Learning to Attribute with Attention

Benjamin Cohen-Wang, Yung-Sung Chuang, Aleksander Madry

TL;DR

Token attribution for language model generations is computationally expensive when it relies on ablations. The paper introduces AT2, which predicts the impact of ablating sources by learning a coefficient for each attention head, combining the heads' attention signals into a surrogate that generalizes across examples. AT2 achieves attribution quality on par with ablation-heavy methods while being substantially more efficient, and it demonstrates practical utility by improving HotpotQA performance through context pruning. By learning how to combine head-wise attention signals, AT2 enables scalable, faithful explanations of how preceding tokens influence LM generations.
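To make the surrogate described above concrete, here is a minimal sketch, assuming sources are individual context tokens, heads from all layers are flattened into a single index, and each head's feature for a source is its attention weight averaged over the generated tokens (the averaging and the per-head max normalization mirror the Figure 1 caption). The function names and the coefficient vector `theta` are hypothetical illustrations, not the API of the AT2 repository.

```python
from typing import List, Sequence

# attn[h][t][j]: weight that head h places on source token j while producing
# generated token t (heads from all layers flattened into one index h).
Attention = Sequence[Sequence[Sequence[float]]]


def head_features(attn: Attention) -> List[List[float]]:
    """feats[h][j]: head h's attention to source j, averaged over the generation."""
    feats = []
    for head in attn:
        n_gen = len(head)
        feats.append([sum(row[j] for row in head) / n_gen for j in range(len(head[0]))])
    return feats


def normalize_per_head(feats: List[List[float]]) -> List[List[float]]:
    """Divide each head's features by that head's maximum weight (visualization only)."""
    return [[f / max(head) for f in head] if max(head) > 0 else list(head) for head in feats]


def surrogate_scores(feats: List[List[float]], theta: Sequence[float]) -> List[float]:
    """score_j = sum_h theta[h] * feats[h][j]: a learned linear combination of heads."""
    n_src = len(feats[0])
    return [sum(theta[h] * feats[h][j] for h in range(len(feats))) for j in range(n_src)]


if __name__ == "__main__":
    attn = [
        [[0.2, 0.5, 0.3], [0.4, 0.5, 0.1]],  # head 0: 2 generated tokens x 3 sources
        [[0.6, 0.2, 0.2], [0.2, 0.2, 0.6]],  # head 1
    ]
    feats = head_features(attn)  # roughly [[0.3, 0.5, 0.2], [0.4, 0.2, 0.4]]
    print(surrogate_scores(feats, [1.0, -0.5]))  # roughly [0.1, 0.4, 0.0]
```

Because `theta` is shared across examples, scoring a new example in this sketch needs only the attention weights from one forward pass, with no per-example ablations.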

Abstract

Given a sequence of tokens generated by a language model, we may want to identify the preceding tokens that influence the model to generate this sequence. Performing such token attribution is expensive; a common approach is to ablate preceding tokens and directly measure their effects. To reduce the cost of token attribution, we revisit attention weights as a heuristic for how a language model uses previous tokens. Naive approaches to attributing model behavior with attention (e.g., averaging attention weights across attention heads to estimate a token's influence) have been found to be unreliable. To attain faithful attributions, we propose treating the attention weights of different attention heads as features. This way, we can learn how to effectively leverage attention weights for attribution (using signal from ablations). Our resulting method, Attribution with Attention (AT2), reliably performs on par with approaches that involve many ablations, while being significantly more efficient. To showcase the utility of AT2, we use it to prune less important parts of a provided context in a question answering setting, improving answer quality. We provide code for AT2 at https://github.com/MadryLab/AT2.
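The abstract describes learning how to weight each head's attention using signal from ablations. The sketch below is one way to do that under stated assumptions: (i) the surrogate predicts the effect of ablating a set of sources as the sum of their scores, so each head contributes the total attention it places on the ablated sources; and (ii) the coefficients are fit with plain squared error and gradient descent. The paper's actual training objective and ablation scheme may differ; `ablation_features` and `fit_coefficients` are hypothetical names, and the data in the demo is synthetic.

```python
import random
from typing import List, Sequence, Tuple


def ablation_features(feats: Sequence[Sequence[float]], ablated: Sequence[bool]) -> List[float]:
    """x[h]: total attention that head h places on the ablated sources."""
    return [sum(f for f, a in zip(head, ablated) if a) for head in feats]


def fit_coefficients(data: Sequence[Tuple[Sequence[float], float]],
                     lr: float = 0.05, steps: int = 1000) -> List[float]:
    """Fit theta so that sum_h theta[h] * x[h] approximates the measured effect y
    (e.g., the drop in log-probability of the generation after the ablation).
    Minimizes mean squared error with full-batch gradient descent."""
    n_heads = len(data[0][0])
    theta = [0.0] * n_heads
    for _ in range(steps):
        grad = [0.0] * n_heads
        for x, y in data:
            err = sum(t * xi for t, xi in zip(theta, x)) - y
            for h in range(n_heads):
                grad[h] += 2.0 * err * x[h] / len(data)
        theta = [t - lr * g for t, g in zip(theta, grad)]
    return theta


if __name__ == "__main__":
    # Each row comes from a different synthetic "example", so a single theta is
    # shared across examples instead of being fit per example.
    rng = random.Random(0)
    true_theta = [2.0, -1.0]  # hypothetical per-head coefficients to recover
    data = []
    for _ in range(200):
        feats = [[rng.random() for _ in range(6)] for _ in range(2)]  # 2 heads, 6 sources
        ablated = [rng.random() < 0.5 for _ in range(6)]
        x = ablation_features(feats, ablated)
        data.append((x, sum(t * xi for t, xi in zip(true_theta, x))))
    print(fit_coefficients(data))  # should approach true_theta on this noiseless data
```

The ablations are only needed while fitting `theta`; afterwards the learned coefficients are reused on new examples.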

Paper Structure

This paper contains 44 sections, 8 equations, 5 figures, 1 algorithm.

Figures (5)

  • Figure 1: Attention heads vary in their usefulness for attribution. When visualizing the attention weights of three individual heads, we observe that certain heads appear to be more useful for attribution than others. In particular, layer #19, head #4 assigns high attention weights to "27", "21" and "25", which are the field goal distances mentioned in the generation of interest. Meanwhile, layer #6, head #22 and layer #25, head #10 assign high attention weights to other, seemingly unrelated parts of the context. This example is from DROP dua2019drop with a generation from Phi-3.5-mini abdin2024phi. Attention weights are averaged across the generation of interest and normalized by dividing by the maximum weight for each head.
  • Figure 2: AT2 identifies attention heads that appear useful for attribution. The attention head with the highest coefficient (as learned by AT2) attends to tokens that seem very relevant to the generation of interest. On the other hand, the head with the lowest-magnitude coefficient attends to other, seemingly unrelated parts of the context. To learn these coefficients, AT2 is trained on Dolly 15k DatabricksBlog2023DollyV2 (see the evaluations section for details). The details are otherwise identical to Figure 1. See the section on learned coefficients for the coefficients themselves.
  • Figure 3: Qualitative comparison of token attributions. We visualize the attribution scores (blue) of different methods for a particular generated statement (yellow) in a context attribution setting for Phi-3.5-mini (with individual tokens as sources). AT2 (trained on a generic dataset) and ESM (with $256$ ablations) yield similar attributions, with high scores for tokens related to the generated statement of interest, while average attention assigns the highest score to a seemingly arbitrary token. See the section on attribution examples for additional examples of attributions.
  • Figure 4: Evaluating token attributions. We report performance metrics for different attribution methods in context attribution and thought attribution settings (one panel each). AT2 performs similarly when trained on examples from the task of interest (task-specific) or on examples from a generic task (general). It consistently outperforms the gradient and average attention baselines and performs comparably to example-specific surrogate modeling (ESM) with a substantial number of ablations. In the efficiency comparison panel, we compare the running time of AT2 to that of a single forward pass (ESM requires at least 32) and a single backward pass (which the gradient method requires). See the section on additional evaluations for more results.
  • Figure 5: Improving response quality by pruning the context. Using attributions from AT2 to prune away less important parts of the context improves response quality on HotpotQA. We compare AT2 to example-specific surrogate modeling (ESM) with $32$ ablations and to average attention for different numbers of retained passages $k$, and find that AT2 yields larger improvements while costing less than a single inference pass. "Baseline" denotes the performance of the model without context pruning, while "Oracle" denotes the performance of the model when provided only with the ground-truth sources. The F1 score is averaged over $4,000$ examples from the HotpotQA validation set.
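The pruning step in Figure 5 can be sketched as follows, assuming per-token attribution scores are summed within each passage and the $k$ highest-scoring passages are kept in their original order. `prune_context` is a hypothetical helper, not the AT2 repository's API, and the paper's aggregation rule may differ.

```python
from typing import List, Sequence


def prune_context(passages: Sequence[str], token_scores: Sequence[Sequence[float]],
                  k: int) -> List[str]:
    """Keep the k passages with the largest total attribution score.
    token_scores[i] holds the per-token scores for passages[i]."""
    passage_scores = [sum(scores) for scores in token_scores]
    ranked = sorted(range(len(passages)), key=lambda i: passage_scores[i], reverse=True)
    keep = sorted(ranked[:k])  # restore the original order so the context stays coherent
    return [passages[i] for i in keep]


if __name__ == "__main__":
    passages = ["P0", "P1", "P2", "P3"]
    token_scores = [[0.1, 0.0], [0.5], [0.05, 0.05, 0.3], [0.2]]
    print(prune_context(passages, token_scores, k=2))  # ['P1', 'P2']
```

Summing favors passages with many moderately scored tokens; averaging would instead favor short passages with a few high-scoring tokens. Which aggregation works better is an empirical question this sketch does not settle.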

Theorems & Definitions (3)

  • Definition 2.1: Token attribution method.
  • Definition 2.2: Top-$k$ drop
  • Definition 2.3: Linear datamodeling score