Reversed Attention: On The Gradient Descent Of Attention Layers In GPT
Shahar Katz, Lior Wolf
TL;DR
This work reveals Reversed Attention (RA), the backward-pass counterpart to forward attention in Transformer-based models: the implicit attention map produced by the softmax derivative during backpropagation. By deriving the vector-Jacobian products (VJPs) and the gradient flow for $W_q$, $W_k$, $W_v$, and $W_o$, the authors show that RA is sparse and lower-triangular, mirroring forward attention while remaining distinct from it, and that it offers a lens into gradient-driven editing dynamics. They introduce attention patching, a method that injects RA-informed attention into the forward pass to alter predictions without updating weights, and show that it performs competitively with causal mediation while being faster and more interpretable. Across GPT2, OPT, and Llama2 variants and a suite of tasks, RA serves as a practical tool for identifying influential heads and steering model behavior, although the study is limited to decoder-only architectures and frame-specific patching.
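The claim that backpropagating through the softmax yields an attention-shaped matrix can be checked numerically. The sketch below is not the authors' code: it assumes a single causal attention head and takes RA to be the VJP of the softmax, i.e., the gradient of the loss with respect to the pre-softmax scores. The toy loss $L = \sum O \odot G$, the function names, and the shapes are illustrative assumptions. It checks that RA is lower-triangular, that each of its rows sums to zero, and that $X^\top (R K)/\sqrt{d}$ matches a finite-difference estimate of the gradient for $W_q$.

```python
import numpy as np

def causal_attention_forward(X, W_q, W_k, W_v):
    """Single-head causal self-attention (illustrative); returns output and cached tensors."""
    Q, K, V = X @ W_q, X @ W_k, X @ W_v
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                              # pre-softmax scores
    n = S.shape[0]
    future = np.triu(np.ones((n, n), dtype=bool), k=1)    # True where a token would see the future
    S = np.where(future, -np.inf, S)
    E = np.exp(S - S.max(axis=-1, keepdims=True))         # numerically stable softmax
    A = E / E.sum(axis=-1, keepdims=True)
    return A @ V, (Q, K, V, A, d)

def reversed_attention(dO, V, A):
    """Assumed definition of RA: softmax VJP, i.e. dL/dS given dL/dO."""
    dA = dO @ V.T                                         # dL/dA, since O = A V
    return A * (dA - (dA * A).sum(axis=-1, keepdims=True))

rng = np.random.default_rng(0)
n, d_model, d_head = 5, 4, 3
X = rng.normal(size=(n, d_model))
W_q, W_k, W_v = (rng.normal(size=(d_model, d_head)) for _ in range(3))
G = rng.normal(size=(n, d_head))                          # dL/dO for the toy loss L = sum(O * G)

O, (Q, K, V, A, d) = causal_attention_forward(X, W_q, W_k, W_v)
R = reversed_attention(G, V, A)
grad_W_q = X.T @ (R @ K) / np.sqrt(d)                     # gradient flowing into W_q through RA

def loss(Wq):
    out, _ = causal_attention_forward(X, Wq, W_k, W_v)
    return float((out * G).sum())

# Central finite differences as an independent check of grad_W_q.
eps = 1e-6
num = np.zeros_like(W_q)
for i in range(d_model):
    for j in range(d_head):
        E_ij = np.zeros_like(W_q)
        E_ij[i, j] = eps
        num[i, j] = (loss(W_q + E_ij) - loss(W_q - E_ij)) / (2 * eps)

print(np.allclose(R, np.tril(R)))                         # RA is lower-triangular
print(np.allclose(R.sum(axis=1), 0.0))                    # each RA row sums to zero
print(np.allclose(grad_W_q, num, atol=1e-6))              # analytic and numeric gradients agree
```

The lower-triangular shape follows directly from the mask: wherever forward attention is zero, the softmax VJP is zero too, so RA inherits the causal structure of $A$.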
Abstract
The success of Transformer-based Language Models (LMs) stems from their attention mechanism. While this mechanism has been extensively studied in explainability research, particularly through the attention values obtained during the forward pass of LMs, the backward pass of attention has been largely overlooked. In this work, we study the mathematics of the backward pass of attention, revealing that it implicitly calculates an attention matrix we refer to as "Reversed Attention". We examine the properties of Reversed Attention and demonstrate its ability to elucidate the models' behavior and editing dynamics. In an experimental setup, we showcase the ability of Reversed Attention to directly alter the forward pass of attention, without modifying the model's weights, using a novel method called "attention patching". In addition to enhancing the understanding of how LMs configure attention layers during backpropagation, Reversed Attention maps contribute to a more interpretable backward pass.
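The idea of altering the forward pass without touching the weights can be illustrated with a toy example. This sketch assumes, as a simplification that is not taken from the paper, that patching amounts to one gradient-descent step on a single head's pre-softmax scores, $S' = S - \eta R$, where $R$ is the RA for a target loss. The random scores, the toy loss, and the strength `eta` are all hypothetical. For a small step, the patched attention should lower the toy loss while still being a valid causal attention map.

```python
import numpy as np

def masked_softmax(S):
    """Row-wise softmax with a causal mask (future positions get zero weight)."""
    n = S.shape[0]
    S = np.where(np.triu(np.ones((n, n), dtype=bool), k=1), -np.inf, S)
    E = np.exp(S - S.max(axis=-1, keepdims=True))
    return E / E.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(1)
n, d_head = 6, 4
S = rng.normal(size=(n, n))          # hypothetical pre-softmax scores of one head
V = rng.normal(size=(n, d_head))     # value vectors of that head
G = rng.normal(size=(n, d_head))     # dL/dO for a toy loss L = sum(O * G)

def toy_loss(A):
    return float(((A @ V) * G).sum())

A = masked_softmax(S)
dA = G @ V.T                                             # dL/dA, since O = A V
R = A * (dA - (dA * A).sum(axis=-1, keepdims=True))      # reversed attention (softmax VJP)

eta = 1e-3                                               # hypothetical patching strength
A_patched = masked_softmax(S - eta * R)                  # patch the scores; weights are untouched

print(toy_loss(A), toy_loss(A_patched))                  # patched loss should be lower
```

Because $R$ is exactly $\partial L / \partial S$ in this toy setting, the first-order change in loss is $-\eta \lVert R \rVert^2$, which is why a small step is expected to help; larger steps carry no such guarantee.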
