Start Making Sense(s): A Developmental Probe of Attention Specialization Using Lexical Ambiguity

Pamela D. Rivière, Sean Trott

TL;DR

This work investigates how Transformer-style self-attention heads develop specialized roles in lexical disambiguation by tracking developmental trajectories across Pythia checkpoints. It combines psycholinguistic stimuli (RAW-C) with targeted QK perturbations and causal ablations to identify heads that covary with disambiguation performance, and then assesses their robustness to stimulus perturbations. The study finds early developmental milestones where attention to disambiguation cues increases, with larger models showing more robust and generalizable disambiguation heads; ablations confirm causal contributions, especially in smaller models. The results highlight a developmental perspective as a powerful lens to understand contextualization mechanisms and raise questions about generalization across seeds and model scales.

Abstract

Despite an in-principle understanding of self-attention matrix operations in Transformer language models (LMs), it remains unclear precisely how these operations map onto interpretable computations or functions--and how or when individual attention heads develop specialized attention patterns. Here, we present a pipeline to systematically probe attention mechanisms, and we illustrate its value by leveraging lexical ambiguity--where a single word has multiple meanings--to isolate attention mechanisms that contribute to word sense disambiguation. We take a "developmental" approach: first, using publicly available Pythia LM checkpoints, we identify inflection points in disambiguation performance for each LM in the suite; in 14M and 410M, we identify heads whose attention to disambiguating words covaries with overall disambiguation performance across development. We then stress-test the robustness of these heads to stimulus perturbations: in 14M, we find limited robustness, but in 410M, we identify multiple heads with surprisingly generalizable behavior. Then, in a causal analysis, we find that ablating the target heads demonstrably impairs disambiguation performance, particularly in 14M. We additionally reproduce developmental analyses of 14M across all of its random seeds. Together, these results suggest: that disambiguation benefits from a constellation of mechanisms, some of which (especially in 14M) are highly sensitive to the position and part-of-speech of the disambiguating cue; and that larger models (410M) may contain heads with more robust disambiguation behavior. They also join a growing body of work that highlights the value of adopting a developmental perspective when probing LM mechanisms.
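As a rough illustration of the quantity tracked above (attention from the ambiguous word to the disambiguating word), the following is a minimal sketch of one causally masked attention head. It is not the authors' pipeline: the query and key matrices are random stand-ins, the example sentence is a toy, and the function names `causal_attention` and `attention_to_cue` are hypothetical. In practice these weights would be read from a Pythia checkpoint rather than computed from random vectors.

```python
import numpy as np

def causal_attention(q, k):
    """Softmax attention weights for one head with a causal mask.

    q, k: arrays of shape (seq_len, d_head). Row i of the result holds
    the attention that token i pays to tokens 0..i.
    """
    seq_len, d_head = q.shape
    scores = q @ k.T / np.sqrt(d_head)
    # Mask out future positions (strict upper triangle).
    future = np.triu(np.ones((seq_len, seq_len), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    # Numerically stable softmax over each row.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention_to_cue(weights, target_idx, cue_idx):
    """Attention from the ambiguous (target) token to the disambiguating cue."""
    return float(weights[target_idx, cue_idx])

# Toy sentence (hypothetical); the cue precedes the target, so a causal
# head at the target position can attend to it.
tokens = ["He", "liked", "the", "marinated", "lamb"]
rng = np.random.default_rng(0)
q = rng.normal(size=(len(tokens), 8))  # random stand-in for real queries
k = rng.normal(size=(len(tokens), 8))  # random stand-in for real keys

weights = causal_attention(q, k)
score = attention_to_cue(weights, tokens.index("lamb"), tokens.index("marinated"))
print(f"attention from 'lamb' to 'marinated': {score:.3f}")
```

Repeating this measurement per head and per training checkpoint would give the kind of developmental attention curve the abstract describes.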

Paper Structure

This paper contains 33 sections, 1 equation, 11 figures.

Figures (11)

  • Figure 1: Disambiguation performance at final checkpoint for Pythia language models (LMs). (a) Sample RAW-C sentences, evoking different senses of the target ambiguous word (lamb) via a single differing disambiguating word (marinated, friendly). We obtain: the AttentionScore from the ambiguous word to the disambiguating word, per sentence; the cosine distance between contextualized representations of the target ambiguous word across the sentences in a pair; and the publicly available human relatedness judgments for the target ambiguous word across the sentences in a pair. (b) Max $R^2$ obtained from the final checkpoint of nine Pythia LMs, by number of parameters. Arrows mark the models of interest, Pythia-$14M$ and Pythia-$410M$. The horizontal dashed line represents mean human interannotator agreement. (c) Each subpanel shows the head index by layer index for a given LM; warmer colors indicate higher z-scored mean attention scores to disambiguating words, for the final checkpoint of each LM.
  • Figure 2: Identifying candidate attention heads. (a) Maximum $R^2$ from the final step, by layer depth (current layer$/$max number of layers), for nine Pythia LMs. (b) "Developmental" view of $R^2$, obtained from each training step for nine Pythia LMs. Depicted $R^2$s are from the layer with the max $R^2$ at the final checkpoint (e.g. for $14M$, $R^2$ is from Layer 3; for $410M$, $R^2$ is from Layer 24). (c) "Developmental" view of attention to the disambiguating word, for all head indices in $14M$'s Layer 3. Superimposed is the $R^2$ from Layer 3. Attention scores for Heads $(3,1)$ and $(3,2)$ covary with disambiguation performance. (d) Same as in c, but for select layers in $410M$. Layers were selected if they contained at least one head whose attention scores rose during training. Superimposed $R^2$ curves were drawn from each layer depicted. The vertical dashed line in sub-panels b-d marks training step $1000$, corresponding to $2.1M$ tokens seen cumulatively over training.
  • Figure 3: Stress testing Pythia-14M candidate heads' attention to disambiguating word. Attention Head color coding scheme applies to all panels. (a) Layer 3 difference in average attention scores for disambiguating word against that of all 1-back tokens, over pre-training. Ticks indicate training steps with significant difference in attention scores (p < 0.05), adjusted for multiple comparisons. Only Head (3,2) remains significant at the final training checkpoint. (b) Each sub-panel corresponds to a different layer. (top) Attention to disambiguating word when it is separated from the target ambiguous word via inserted string. (bottom) Attention to last token of inserted string. The square surrounding Layer 3 highlights the attentional robustness of at least one of the two candidate heads. (c) Attention to disambiguating word when its part-of-speech changes to a verb (top) or a noun (bottom).
  • Figure 4: Stress testing Pythia-$410M$. Warmer colors indicate larger Disambiguation Composite scores for Pythia-410M heads and layers. Larger scores reflect greater attentional robustness to the disambiguating word despite stimulus perturbations, and greater degree of attention covariance with disambiguation performance throughout pre-training.
  • Figure 5: Target head ablations decrease disambiguation performance relative to intact models. (a) Mean difference in $R^2$ across all layers and all combinations of Pythia-$14M$'s target head ablations, for all training steps. Values $>0$ indicate that the intact model's $R^2$ exceeded the ablated model's $R^2$, reflecting a causal effect of ablation. Target manipulations refer to zero-ablations of previously identified Layer 3 heads. Baseline manipulations refer to zero-ablations of Layer 3 heads whose attention to disambiguating words fails to increase with disambiguation performance. (b) Same as in a, but parcelled out by layer, to illustrate localization of ablation effects, which remain robust throughout training in Layer 3. Dashed square marks the only layer (Layer 3) to suffer head ablations. (c) Same as in a, but for Pythia-$410M$. Target head ablations causally decrease model performance. (d) Same as in c, but parcelled out by early versus late layers, to illustrate the selectivity of target-head ablation to earlier layer representations. By the end of training, the effects of target-head ablations are indistinguishable from those of baseline-head ablations.
  • ...and 6 more figures
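
The measures named in the Figure 1 caption (cosine distance between contextualized representations, compared against human relatedness judgments via $R^2$) can be sketched as follows. This is a minimal illustration under stated assumptions, not the paper's analysis code: it assumes $R^2$ is the variance in relatedness explained by a simple linear fit on cosine distance, which this section does not specify, and it uses randomly generated vectors and synthetic ratings in place of Pythia representations and RAW-C judgments. The helper names `cosine_distance` and `r_squared` are hypothetical.

```python
import numpy as np

def cosine_distance(u, v):
    """1 - cosine similarity between two representation vectors."""
    return 1.0 - float(u @ v / (np.linalg.norm(u) * np.linalg.norm(v)))

def r_squared(x, y):
    """R^2 of an ordinary least-squares line predicting y from x."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    slope, intercept = np.polyfit(x, y, deg=1)
    residuals = y - (slope * x + intercept)
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

# Synthetic stand-ins (assumptions): one vector per sentence in each pair,
# representing the target ambiguous word in that sentence.
rng = np.random.default_rng(0)
n_pairs, d_model = 20, 16
reps_a = rng.normal(size=(n_pairs, d_model))
reps_b = rng.normal(size=(n_pairs, d_model))
distances = np.array([cosine_distance(a, b) for a, b in zip(reps_a, reps_b)])

# Synthetic "human relatedness" ratings: lower for more distant pairs, plus noise.
relatedness = 4.0 - 2.0 * distances + rng.normal(scale=0.3, size=n_pairs)

r2 = r_squared(distances, relatedness)
print(f"R^2 between cosine distance and relatedness: {r2:.3f}")
```

Computing this per layer and per checkpoint, then taking the maximum over layers, would yield a curve analogous to the "Max $R^2$" panels; comparing intact and head-ablated models on the same quantity corresponds to the $R^2$ differences reported in Figure 5.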