Interpreting Key Mechanisms of Factual Recall in Transformer-Based Language Models

Ang Lv, Yuhan Chen, Kaiyi Zhang, Yulong Wang, Lifeng Liu, Ji-Rong Wen, Jian Xie, Rui Yan

TL;DR

The paper investigates how Transformer-based language models perform factual recall and proposes a residual-stream-driven, three-stage mechanism: task-specific heads extract the argument, an MLP amplifies or erases individual head signals (acting as an activation), and deeper MLPs apply a function-like redirection toward the correct answer. It uses causal mediation analysis and activation patching to identify influential modules, and develops a regression-based method that decomposes MLP outputs into interpretable components, revealing mover heads and a task-aware intercept that carries out the function application. It also uncovers a universal anti-overconfidence mechanism in the final layer, along with mitigation strategies that improve recall confidence across multiple models (GPT-2, OPT, Llama-2) in zero- and few-shot settings. The findings give mechanistic insight into how knowledge is retrieved and directed within deep Transformers, with implications for model design and for alignment work that depends on reliable factual recall.

Abstract

In this paper, we delve into several mechanisms employed by Transformer-based language models (LLMs) for factual recall tasks. We outline a pipeline consisting of three major steps: (1) Given a prompt ``The capital of France is,'' task-specific attention heads extract the topic token, such as ``France,'' from the context and pass it to subsequent MLPs. (2) As attention heads' outputs are aggregated with equal weight and added to the residual stream, the subsequent MLP acts as an ``activation,'' which either erases or amplifies the information originating from individual heads. As a result, the topic token ``France'' stands out in the residual stream. (3) A deep MLP takes ``France'' and generates a component that redirects the residual stream towards the direction of the correct answer, i.e., ``Paris.'' This procedure is akin to applying an implicit function such as ``get\_capital($X$),'' where the argument $X$ is the topic token information passed by attention heads. To support the above quantitative and qualitative analysis of MLPs, we propose a novel analytic method that decomposes the outputs of the MLP into components understandable by humans. Additionally, we observe a universal anti-overconfidence mechanism in the final layer of models, which suppresses correct predictions. We mitigate this suppression by leveraging our interpretation to improve factual recall confidence. The above interpretations are evaluated across diverse tasks spanning various domains of factual knowledge, using various language models from the GPT-2 families, 1.3B OPT, up to 7B Llama-2, and in both zero- and few-shot setups.
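
Step (3) of the abstract can be made concrete with a small numerical sketch. The example below is only an illustration under stated assumptions: logits are taken to be plain dot products between the residual stream and unembedding vectors (the final layer normalization is ignored), and all vectors, as well as the names `w_x`, `w_y`, `r_before`, and `alpha`, are hypothetical toy values rather than quantities from any real model. It shows that adding a component along $W_U[Y] - W_U[X]$ shifts the logit gap between $Y$ and $X$ by exactly `alpha` times the squared norm of that direction.

```python
import numpy as np

# Toy unembedding vectors (hypothetical values, not from a real model).
w_x = np.array([1.0, 0.0, 0.0])  # stand-in for the unembedding of X ("France")
w_y = np.array([0.0, 1.0, 0.0])  # stand-in for the unembedding of Y ("Paris")

# Residual stream before the deep MLP: the topic token X dominates.
r_before = np.array([2.0, 0.5, 1.0])


def logit_gap(r):
    """Logit of Y minus logit of X, assuming logits are plain dot products."""
    return float(r @ w_y - r @ w_x)


# Assumed form of the MLP contribution: a multiple of W_U[Y] - W_U[X].
direction = w_y - w_x
alpha = 1.0
r_after = r_before + alpha * direction

gap_before = logit_gap(r_before)
gap_after = logit_gap(r_after)

# Algebraically, gap_after = gap_before + alpha * ||direction||^2,
# so Y overtakes X as soon as alpha exceeds this threshold.
alpha_min = -gap_before / float(direction @ direction)

print(f"gap before: {gap_before:+.2f}")
print(f"gap after:  {gap_after:+.2f}")
print(f"smallest alpha that flips the prediction: {alpha_min:.2f}")
```

In this toy setting the gap is negative before the update ($X$ is preferred) and positive afterwards ($Y$ is preferred), which mirrors the "function application" described in the abstract without claiming anything about how a trained MLP produces the component.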

Paper Structure (39 sections, 7 equations, 19 figures, 8 tables)

Figures (19)

  • Figure 1: The key mechanisms of factual recall employed by Transformer-based language models. Please refer to §\ref{sec:background} for detailed notations.
  • Figure 2: Subfigures (a), (b), and (c) are toy diagrams of causal mediation analysis in discovering important circuits in a neural network. Colors distinguish different node values. Subfigure (d) illustrates an activation patching example in studying the node affecting the correct capital city prediction, as detailed in §\ref{sec:background2}.
  • Figure 3: (a) Probabilities of $X$ and $Y$ decoded at each layer. The shaded regions indicate variances. (b) The effect of patching $\textbf{a}^{l,h}_{t=-1} \rightarrow \textbf{r}^{11}_{\text{post},t=-1}$. (c) Value weighted attention pattern of L9H8 and L10H0. (d) L9H8 and L10H0's attention to $X$ is proportional to the projection value of their output along $\textbf{W}_{U}[X]$, indicating they are moving country names to the final position.
  • Figure 4: (a) The probabilities of $X$ and $Y$ decoded from various vectors. (b) Cosine similarity between the vectors. Note that although the cosine similarity between $\textbf{r}^{10}_{\text{post}}$ and $\textbf{W}_{U}[X]$ remains higher than that with $\textbf{W}_{U}[Y]$, when taking into account the norm, the logit value of $Y$ is higher. (c) A simplified example illustrating the relationships among vectors in 3D space. As MLP10 generates a component $\textbf{b}^{10,\text{proj}}$ that aligns with $\textbf{W}_{U}[Y] - \textbf{W}_{U}[X]$ in direction, adding $\textbf{b}^{10,\text{proj}}$ to the residual stream achieves the "function application," causing the probability of $Y$ to surpass that of $X$ for the first time during the forward pass.
  • Figure 5: The sub-figures illustrate the probability dynamics of $X$ and $Y$, alongside influential heads impacting final logits across zero, one, and two-shot settings. The fundamental mechanisms detected in the zero-shot scenario still work in few-shot settings.
  • ...and 14 more figures
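
The activation patching idea referenced in the Figure 2 caption can be sketched on a toy network. This is a minimal illustration, not the paper's setup: the two-layer ReLU network, its weights, the "clean" and "corrupted" inputs, and the choice to patch clean activations into the corrupted run are all assumptions made for the example. Each hidden node is overwritten in turn with its value from the clean run, and the resulting change in the output is taken as that node's effect.

```python
import numpy as np

# Hypothetical toy weights: 2 inputs -> 3 hidden nodes -> 1 output.
W1 = np.array([[1.0, 0.0],
               [0.0, 1.0],
               [1.0, 1.0]])
W2 = np.array([3.0, -1.0, 0.5])


def forward(x, patch=None):
    """Run the toy network; `patch` maps a hidden index to a value to overwrite."""
    h = np.maximum(W1 @ x, 0.0)  # hidden activations (ReLU)
    if patch:
        for idx, value in patch.items():
            h[idx] = value
    return float(W2 @ h), h


x_clean = np.array([1.0, 0.0])    # input for which the output is "correct"
x_corrupt = np.array([0.0, 1.0])  # perturbed input

out_clean, h_clean = forward(x_clean)
out_corrupt, _ = forward(x_corrupt)

# Patch one hidden node at a time with its clean value and record
# how much the corrupted output moves.
effects = []
for i in range(len(h_clean)):
    out_patched, _ = forward(x_corrupt, patch={i: h_clean[i]})
    effects.append(out_patched - out_corrupt)

most_influential = int(np.argmax(np.abs(effects)))
print("clean output:    ", out_clean)
print("corrupted output:", out_corrupt)
print("patching effects:", effects)
print("most influential hidden node:", most_influential)
```

Nodes whose patched value moves the output most are the candidates for "important" components; in a real Transformer the same loop would run over attention heads or MLP outputs at a chosen position instead of hidden units.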