Reversed Attention: On The Gradient Descent Of Attention Layers In GPT

Shahar Katz, Lior Wolf

TL;DR

This work reveals Reversed Attention (RA), the backward-pass counterpart to forward attention in Transformer-based models, as the implicit attention map produced by the softmax derivative during backpropagation. By deriving Vector-Jacobian Products and the gradient flow for $W_q$, $W_k$, $W_v$, and $W_o$, the authors show that RA forms a sparse, lower-triangular structure that mirrors, but is distinct from, forward attention, providing a lens into gradient-driven editing dynamics. They introduce attention patching, a method that injects RA-informed attention into the forward pass to alter predictions without updating weights, and demonstrate performance competitive with causal mediation while offering faster, interpretable insights. Across GPT2/OPT/Llama2 variants and a suite of tasks, RA serves as a practical tool for identifying influential heads and steering model behavior, albeit within the scope of decoder-only architectures and frame-specific patching.

Abstract

The success of Transformer-based Language Models (LMs) stems from their attention mechanism. While this mechanism has been extensively studied in explainability research, particularly through the attention values obtained during the forward pass of LMs, the backward pass of attention has been largely overlooked. In this work, we study the mathematics of the backward pass of attention, revealing that it implicitly calculates an attention matrix we refer to as "Reversed Attention". We examine the properties of Reversed Attention and demonstrate its ability to elucidate the models' behavior and edit dynamics. In an experimental setup, we showcase the ability of Reversed Attention to directly alter the forward pass of attention, without modifying the model's weights, using a novel method called "attention patching". In addition to enhancing the comprehension of how LMs configure attention layers during backpropagation, Reversed Attention maps contribute to a more interpretable backward pass.
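To make the idea of an attention-shaped matrix arising in the backward pass concrete, here is a minimal NumPy sketch of the standard vector-Jacobian product of a causal softmax. It assumes, following the TL;DR, that the map of interest is the gradient with respect to the pre-softmax attention scores. The names `causal_softmax`, `softmax_vjp`, and the random upstream gradient `grad_attn` are illustrative assumptions, not the authors' code.

```python
import numpy as np

def causal_softmax(scores):
    """Row-wise softmax restricted to a causal (lower-triangular) mask."""
    n = scores.shape[0]
    mask = np.tril(np.ones((n, n), dtype=bool))
    masked = np.where(mask, scores, -np.inf)
    # Subtract the row maximum for numerical stability; the diagonal is
    # always unmasked, so every row has a finite maximum.
    masked = masked - masked.max(axis=-1, keepdims=True)
    exp = np.exp(masked)
    return exp / exp.sum(axis=-1, keepdims=True)

def softmax_vjp(attn, grad_attn):
    """Gradient w.r.t. pre-softmax scores, given the gradient w.r.t. attn.

    For each row: dL/dS = A * (dL/dA - sum_j(dL/dA_j * A_j)).
    """
    inner = (grad_attn * attn).sum(axis=-1, keepdims=True)
    return attn * (grad_attn - inner)

rng = np.random.default_rng(0)
n = 5
scores = rng.normal(size=(n, n))      # toy pre-softmax attention scores
grad_attn = rng.normal(size=(n, n))   # hypothetical upstream gradient dL/dA

attn = causal_softmax(scores)
reversed_attn = softmax_vjp(attn, grad_attn)
```

Because the forward attention is zero above the diagonal, the resulting map is lower-triangular as well, and each of its rows sums to zero, so it redistributes attention within a row rather than adding to it.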

Paper Structure

This paper contains 31 sections, 13 equations, 12 figures, 12 tables.

Figures (12)

  • Figure 1: In this paper we examine the attention maps obtained from the backward pass, which we name "Reversed Attention" (RA). This example presents the forward and backward pass of a single attention head of GPT2-xl when prompted with "I like Italy and France, I visited the city of". After the model answers "Florence", a city in Italy, we apply a backward pass with "Florence" as the target for the loss and produce the RA maps. Among all 1200 attention heads in this model, the presented head has the highest RA map norm. Compared to the forward attention map, the RA map is sparser and more interpretable. This RA demonstrates how backpropagation attempts to amplify the information from the token "Italy" (red) while reducing the influence of "France" (blue).
  • Figure 2: (a) The norms of the attention maps per head and per layer. (b) Forward and (c) Reversed Attention of the same head from GPT2-small (layer 11, head index 2). This is the attention head with the second-highest Reversed Attention norm, and we can see that it focuses on editing the query of "of" (row) and the key of "tomato" (column).
  • Figure 3: RA model editing dynamics. (a) The query matrix $\hat{W}_q$ will be updated with a VJP directed towards the forward pass key of "tomato", while the key matrix $\hat{W}_k$ will be updated with a VJP directed towards the query from the token "of". (b) The latent space of the queries and keys. The circles represent a forward pass query and a key. If their Reversed Attention score is a relatively low negative number, the directions they are moving towards after GD are actually towards one another.
  • Figure 4: Attention patching using Reversed Attention (RA): first we collect the RA maps of the model without applying any model editing (without changing its weights). Later, for each attention head, we add its corresponding RA map to the forward pass attention.
  • Figure 5: The forward and Reversed Attention (RA) maps of an attention head from GPT2-small (layer 11, head index 2), given the editing target "cherry" with the prompt "Cherry tomato is a type of". The pattern presented by the RA map attempts to increase the forward pass attention between the query belonging to "of" and the key of "Cherry", encouraging the model to answer "cherry".
  • ...and 7 more figures
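The patching step described in the Figure 4 caption (collect RA maps, then add each head's map to its forward attention) can be sketched as follows. This is a toy illustration under stated assumptions: the function name `patch_attention`, the `scale` parameter, and the example values are hypothetical, and the sign and magnitude of the scaling used in the paper are not given in this summary.

```python
import numpy as np

def patch_attention(forward_attn, ra_maps, scale=1.0):
    """Add a scaled RA map to each head's forward attention map.

    forward_attn, ra_maps: arrays of shape (heads, seq_len, seq_len).
    scale: hypothetical step size; a negative value would correspond to a
    gradient-descent-style step if the RA map is treated as a gradient.
    """
    if forward_attn.shape != ra_maps.shape:
        raise ValueError("forward attention and RA maps must share a shape")
    return forward_attn + scale * ra_maps

# Toy example: one head, two tokens. Values are made up for illustration.
forward_attn = np.array([[[1.0, 0.0],
                          [0.7, 0.3]]])
ra_maps = np.array([[[0.0, 0.0],
                     [-0.2, 0.2]]])

patched = patch_attention(forward_attn, ra_maps, scale=1.0)
```

No weights are changed here: only the attention values used in the forward pass are modified. When each RA row sums to zero, as it does for a softmax-derived gradient, the patched rows still sum to one, although individual entries are not guaranteed to stay in [0, 1].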