Backward Lens: Projecting Language Model Gradients into the Vocabulary Space

Shahar Katz, Yonatan Belinkov, Mor Geva, Lior Wolf

TL;DR

The authors prove that a gradient matrix can be cast as a low-rank linear combination of its forward and backward passes' inputs, and develop methods to project these gradients into vocabulary items and explore the mechanics of how new information is stored in the LMs' neurons.

Abstract

Understanding how Transformer-based Language Models (LMs) learn and recall information is a key goal of the deep learning community. Recent interpretability methods project weights and hidden states obtained from the forward pass to the models' vocabularies, helping to uncover how information flows within LMs. In this work, we extend this methodology to LMs' backward pass and gradients. We first prove that a gradient matrix can be cast as a low-rank linear combination of its forward and backward passes' inputs. We then develop methods to project these gradients into vocabulary items and explore the mechanics of how new information is stored in the LMs' neurons.

Paper Structure (40 sections, 3 theorems, 13 equations, 24 figures, 7 tables)

Key Result

Lemma 4.1

Given a sequence of inputs of length $n$, a parametric matrix $W$ and a loss function $L$, the gradient $\frac{\partial L}{\partial W}$ produced by a backward pass is a matrix with a rank of $n$ or lower.
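A short numerical check of Lemma 4.1. It assumes a single linear map $y_i = x_i W$ applied to each of the $n$ token inputs, which is the setting Figure 2 illustrates for one token. Under this assumption, $\frac{\partial L}{\partial W}$ is the sum of $n$ outer products $x_i^\top \delta_i$, i.e. $X^\top \Delta$, so its rank cannot exceed $n$. The sizes and random inputs below are hypothetical. `Delta` stands in for the VJPs $\delta_i$ that a real backward pass would deliver.

```python
import numpy as np

rng = np.random.default_rng(0)

n, d_in, d_out = 4, 16, 32  # hypothetical sizes: prompt length, input dim, output dim

# Forward-pass inputs to W, one row per token (shape n x d_in).
X = rng.standard_normal((n, d_in))
# Stand-in VJPs delta_i = dL/dy_i from the backward pass (shape n x d_out).
Delta = rng.standard_normal((n, d_out))

# For y_i = x_i W, dL/dW is the sum over tokens of the outer products x_i^T delta_i.
grad = sum(np.outer(X[i], Delta[i]) for i in range(n))

# The same gradient in matrix form: X^T Delta.
assert np.allclose(grad, X.T @ Delta)

# Single-token view (Figure 2): row j is x_j * delta, column k is delta_k * x^T.
g1 = np.outer(X[0], Delta[0])
assert np.allclose(g1[2], X[0, 2] * Delta[0])
assert np.allclose(g1[:, 5], Delta[0, 5] * X[0])

rank = np.linalg.matrix_rank(grad)
print(grad.shape, rank)  # (16, 32) 4, i.e. the rank is at most n
```

With random inputs the rank equals $n$ almost surely. This matches Figure 4, where the rank of $FF_1$'s gradient equals the prompt length in nearly all cases.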

Figures (24)

  • Figure 1: An illustration of the tokens promoted by a single MLP layer of an LM and by its gradients during the forward and backward passes, when editing the model to answer "Paris" for the prompt "Lionel Messi plays for". The gradients (in green) of the first MLP matrix, $FF_1$, attempt to imprint into the model's weights (in blue) the information that $FF_1$ encountered during the forward pass. Using a vocabulary projection method, we reveal that this information represents the token "team". The gradients of the second MLP matrix, $FF_2$, aim to shift the information encoded within $FF_2$ towards the embedding of the new target.
  • Figure 2: The calculation of a gradient matrix as the outer product $x^\top \cdot \delta$. Both depictions contain the same values: above we describe the matrix as a span of $\delta$ (each row is a scaled copy of $\delta$), while below we describe it as a span of $x^\top$ (each column is a scaled copy of $x^\top$). The displayed vectors are transposed to emphasize the spanning effect.
  • Figure 3: The Imprint and Shift mechanism of backpropagation. "grad" represents a single neuron from a gradient matrix. The $FF_1$ grad has the same color as the forward-pass input, while the $FF_2$ grad has the same color as the new target embedding, indicating that each gradient is similar to the vector it shares a color with.
  • Figure 4: The percentage of occurrences where the rank of $FF_1$'s gradient equals the length of the prompt used for editing. To show different models in the same plot, we normalize the layer indices. Except for the last layer, all layers and models exhibit the above equality more than $98.5\%$ of the time.
  • Figure 5: The gradient of GPT2-small $FF_2$ when editing the model to answer "Paris" for the prompt "Obama grew up in". Each cell shows the Logit Lens projection of the gradient's VJP ($\delta_i$) for one input token and one layer. Non-English characters are replaced with a question mark, and long tokens are truncated with "..". Following the 'Trap and Shift' analysis, each cell displays the least probable token instead of the most probable one. The color indicates the norm of the VJP; white cells indicate that almost no editing is done in practice.
  • ...and 19 more figures
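Figure 5 reads each VJP $\delta_i$ through the Logit Lens and shows the least probable token. The minimal NumPy sketch below illustrates that projection. It assumes the Logit Lens amounts to multiplying a vector by the unembedding matrix. Real models usually apply the final layer norm first, and that step is omitted here. The vocabulary, `E`, `delta`, and the helper `logit_lens` are hypothetical toy stand-ins, not the paper's code. Because $\arg\min_v (\delta E)_v = \arg\max_v (-\delta E)_v$, the least probable token of $\delta_i$ is the token most promoted by $-\delta_i$, the direction in which lowering the loss pushes the output.

```python
import numpy as np

rng = np.random.default_rng(1)

# Hypothetical toy setup (not the paper's models): 6-token vocabulary, hidden size 8.
vocab = ["Paris", "London", "team", "the", "plays", "?"]
d_model = 8
# Unembedding matrix with orthonormal columns, so the example is easy to reason about.
E, _ = np.linalg.qr(rng.standard_normal((d_model, len(vocab))))  # shape (8, 6)


def logit_lens(vec, E, vocab, least=False):
    """Project a hidden-size vector onto the vocabulary; return the top or bottom token."""
    logits = vec @ E  # shape (|V|,)
    idx = int(np.argmin(logits)) if least else int(np.argmax(logits))
    return vocab[idx]


# A hypothetical VJP delta_i = dL/dy_i pointing away from the target's direction.
# To first order, lowering the loss moves y_i along -delta_i, i.e. toward "Paris".
target = E[:, vocab.index("Paris")]
delta = -target + 0.01 * rng.standard_normal(d_model)

least_likely = logit_lens(delta, E, vocab, least=True)
print(least_likely)                                  # Paris
print(logit_lens(-delta, E, vocab) == least_likely)  # True: argmin(delta) == argmax(-delta)
# Figure 5 colors each cell by this norm; a near-zero norm means almost no editing.
print(float(np.linalg.norm(delta)))
```

Reading $\delta_i$ by its bottom token therefore gives the same answer as reading $-\delta_i$ by its top token. The sketch uses the first form to mirror the caption.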

Theorems & Definitions (5)

  • Lemma 4.1
  • proof
  • Lemma 5.1
  • Lemma 5.1
  • proof