Sparse Attention Post-Training for Mechanistic Interpretability

Florent Draye, Anson Lei, Ingmar Posner, Bernhard Schölkopf

TL;DR

The paper tackles the interpretability bottleneck of large transformers by introducing a post-training sparsification method for attention using a SPARTAN-based Sparse Transformer and GECO-constrained optimisation to preserve loss. It demonstrates that attention can be pared down to about 0.3% of edges without performance loss, while yielding dramatically simpler, more modular circuits revealed by mechanistic interpretability analyses. Activation patching shows sparse models need far fewer heads and edges to reproduce behavior, indicating a concentrated, task-relevant computation backbone. This work suggests sparsity as a practical prior for building more structured and interpretable transformer models without sacrificing capability.

Abstract

We introduce a simple post-training method that makes transformer attention sparse without sacrificing performance. Applying a flexible sparsity regularisation under a constrained-loss objective, we show on models up to 1B parameters that it is possible to retain the original pretraining loss while reducing attention connectivity to $\approx 0.3 \%$ of its edges. Unlike sparse-attention methods designed for computational efficiency, our approach leverages sparsity as a structural prior: it preserves capability while exposing a more organized and interpretable connectivity pattern. We find that this local sparsity cascades into global circuit simplification: task-specific circuits involve far fewer components (attention heads and MLPs) with up to 100x fewer edges connecting them. These results demonstrate that transformer attention can be made orders of magnitude sparser, suggesting that much of its computation is redundant and that sparsity may serve as a guiding principle for more structured and interpretable models.
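The constrained-loss objective can be illustrated with a GECO-style Lagrange multiplier update (the TL;DR names GECO as the optimiser). The following is a minimal sketch under stated assumptions, not the authors' implementation: the function names `geco_step` and `total_objective`, the moving-average decay, and the step size are all hypothetical, and the sparsity penalty is treated as an opaque scalar because this page does not specify its form.

```python
import math


def geco_step(lam, c_ma, loss, target, decay=0.99, lr=0.01):
    """One GECO-style multiplier update (illustrative, assumed form).

    The constraint is `loss - target`; it is satisfied when <= 0,
    i.e. when the sparse model is at least as good as the target loss.
    """
    constraint = loss - target
    # Smooth the constraint with an exponential moving average.
    c_ma = decay * c_ma + (1.0 - decay) * constraint
    # Multiplicative update keeps lam positive: it grows while the
    # constraint is violated and shrinks once it is satisfied.
    lam = lam * math.exp(lr * c_ma)
    return lam, c_ma


def total_objective(sparsity_penalty, loss, target, lam):
    """Lagrangian: sparsity penalty plus weighted constraint violation."""
    return sparsity_penalty + lam * (loss - target)


# Usage: loss above target pushes lam up, so later steps weight the
# loss constraint more heavily relative to the sparsity penalty.
lam, c_ma = geco_step(lam=1.0, c_ma=0.0, loss=3.5, target=3.0)
```

The design intent of such a scheme is that sparsity pressure is applied only as hard as the loss constraint allows, which matches the abstract's claim of retaining the original pretraining loss while sparsifying.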

Paper Structure

This paper contains 16 sections, 5 equations, 6 figures.

Figures (6)

  • Figure 1: Simple example showing the attention patterns (shown in blue) of sparse and non-sparse transformers trained on a two-digit addition task. Both models correctly predict the sum, but their attention patterns differ sharply. The non-sparse model solves the task with highly dispersed information flow, whereas the sparse model uses a highly interpretable pattern: in Layer 0 it attends to the corresponding digits to be added, and in Layer 1 it attends to the carry bit only when needed (see the middle and right columns, where the model has to carry once and twice, respectively).
  • Figure 2: (a) Cross-entropy loss with respect to the mean proportion of non-zero attention edges across the evaluation dataset. (b) Example of attention patterns in GPT-2 compared with those in a sparse-attention GPT-2 model for a single IOI task instance.
  • Figure 3: (a) Logit attribution per sentence keeping only the top-$k$ attention heads selected individually for each sentence via activation patching scores. (b) Logit attribution using a fixed set of top-$k$ attention heads selected globally across all sentences by averaging activation patching scores.
  • Figure 4: (a) Cumulative distribution of the sorted attribution-patching edge scores. The dotted lines show the cumulative distribution of the mean scores across sentences, reflecting circuit stability. The solid lines show the mean cumulative distribution of the scores across sentences. The plots for LLaMA are included in Appendix \ref{app:additional_results}. (b) Example of the attention-head edges required to reach 0.9 cumulative score based on the averaged scores for the IOI task, representing the general circuit across all sentences. The comparison is shown for GPT-2 and its sparse-attention variant.
  • Figure 5: Cumulative distribution of the sorted attribution-patching edge scores. The dotted lines show the cumulative distribution of the mean scores across sentences, reflecting circuit stability. The solid lines show the mean cumulative distribution of the scores across sentences.
  • ...and 1 more figure
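
The "edges required to reach 0.9 cumulative score" quantity in the Figure 4 caption can be sketched as follows. This is an illustrative reading only: the function name `edges_to_reach` is hypothetical, and the choices to rank edges by absolute score and to normalise by the total are assumptions, since the page does not give the exact procedure.

```python
def edges_to_reach(scores, threshold=0.9):
    """Count the top-ranked edges needed to reach `threshold` of total score.

    Assumes edges are ranked by absolute score and that the cumulative
    curve is normalised by the sum of absolute scores.
    """
    mags = sorted((abs(s) for s in scores), reverse=True)
    total = sum(mags)
    if total == 0:
        return 0  # no edge carries any score
    running = 0.0
    for k, m in enumerate(mags, start=1):
        running += m
        if running / total >= threshold:
            return k
    return len(mags)


# Usage: a concentrated score distribution needs fewer edges than a
# uniform one to reach the same cumulative fraction.
concentrated = edges_to_reach([10, 0, 0, 0])
uniform = edges_to_reach([1, 1, 1, 1])
```

Under this reading, the dotted curves in the caption would correspond to applying the function to scores averaged over sentences, and the solid curves to averaging per-sentence cumulative curves.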