A Mechanistic Interpretation of Arithmetic Reasoning in Language Models using Causal Mediation Analysis

Alessandro Stolfo, Yonatan Belinkov, Mrinmaya Sachan

TL;DR

This paper develops a causal mediation framework to mechanistically interpret arithmetic reasoning in Transformer LMs. By intervening on mediators such as the MLP output $m^{(l)}_t$ and the attention output $a^{(l)}_t$ (at layer $l$ and token position $t$) and quantifying their indirect effects, it shows that information about operands and operators is routed by attention from mid-sequence positions to the last token, where late MLPs write result-related information into the residual stream. The authors identify four key activation sites, demonstrate that three-operand queries require fine-tuning before similar late-stage dynamics emerge, and show that the observed information-flow patterns are specific to arithmetic when compared with number retrieval and factual knowledge tasks. The findings may inform targeted training, pruning, and inference-time corrections for math-oriented reasoning with language models.

Abstract

Mathematical reasoning in large language models (LMs) has garnered significant attention in recent work, but there is a limited understanding of how these models process and store information related to arithmetic tasks within their architecture. In order to improve our understanding of this aspect of language models, we present a mechanistic interpretation of Transformer-based LMs on arithmetic questions using a causal mediation analysis framework. By intervening on the activations of specific model components and measuring the resulting changes in predicted probabilities, we identify the subset of parameters responsible for specific predictions. This provides insights into how information related to arithmetic is processed by LMs. Our experimental results indicate that LMs process the input by transmitting the information relevant to the query from mid-sequence early layers to the final token using the attention mechanism. Then, this information is processed by a set of MLP modules, which generate result-related information that is incorporated into the residual stream. To assess the specificity of the observed activation dynamics, we compare the effects of different model components on arithmetic queries with other tasks, including number retrieval from prompts and factual knowledge questions.
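
The intervention procedure described in the abstract can be sketched on a toy model. The following is a minimal illustration, not the authors' implementation: the tiny residual network, its random weights, the use of the argmax token as the "result", and the particular relative-change metric are all assumptions made for this example. It runs the model on a clean input and an alternate input, then replaces one MLP output in the clean run with its value from the alternate run and measures how the predicted probabilities shift.

```python
import numpy as np

# Toy stand-in for a Transformer: L residual layers, each adding an MLP output.
# All sizes and weights are hypothetical and chosen only for illustration.
rng = np.random.default_rng(0)
D, V, L = 8, 5, 3  # hidden size, vocabulary size, number of layers
W_in = [rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(L)]
W_out = [rng.normal(size=(D, D)) / np.sqrt(D) for _ in range(L)]
W_unembed = rng.normal(size=(D, V))


def softmax(z):
    z = z - z.max()  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum()


def forward(h, patch=None):
    """Run the toy model on hidden state h.

    patch = (layer, vector) overrides the MLP output at that layer.
    Returns the output distribution and the list of MLP outputs used.
    """
    mlp_outputs = []
    for layer in range(L):
        m = np.maximum(h @ W_in[layer], 0.0) @ W_out[layer]
        if patch is not None and patch[0] == layer:
            m = patch[1]  # the intervention: swap in a foreign activation
        mlp_outputs.append(m)
        h = h + m  # residual update
    return softmax(h @ W_unembed), mlp_outputs


def indirect_effects(h_clean, h_alt):
    """Per-layer effect of patching the alternate run's MLP output into the clean run."""
    p_clean, _ = forward(h_clean)
    p_alt, m_alt = forward(h_alt)
    # Assumption: treat each run's most likely token as its "result".
    r, r_alt = int(p_clean.argmax()), int(p_alt.argmax())
    effects = []
    for layer in range(L):
        p_patched, _ = forward(h_clean, patch=(layer, m_alt[layer]))
        # Assumed metric: average of the relative gain in the alternate result
        # and the relative drop in the original result.
        gain = (p_patched[r_alt] - p_clean[r_alt]) / p_clean[r_alt]
        drop = (p_clean[r] - p_patched[r]) / p_patched[r]
        effects.append(0.5 * (gain + drop))
    return effects


h1, h2 = rng.normal(size=D), rng.normal(size=D)
print(indirect_effects(h1, h2))
```

A component whose patched activation moves probability mass toward the alternate result receives a large effect, which is the sense in which such a component "mediates" the prediction. In a real LM the same loop would run over layers and token positions, for attention outputs as well as MLP outputs.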

Paper Structure

This paper contains 30 sections, 7 equations, 19 figures, and 4 tables.

Figures (19)

  • Figure 1: Visualization of our findings. We trace the flow of numerical information within Transformer-based LMs: given an input query, the model processes the representations of numbers and operators with early layers (A). Then, the relevant information is conveyed by the attention mechanism to the end of the input sequence (B). Here, it is processed by late MLP modules, which output result-related information into the residual stream (C).
  • Figure 2: By intervening on the activation values of specific components within a language model and computing the corresponding effects, we identify the subset of parameters responsible for specific predictions.
  • Figure 3: Indirect effect (IE) measured within GPT-J. Panels (a) and (b) illustrate the flow of information related to both the operands and the result of the queries, while the effect displayed in panels (c) and (d) is related to the operands only (the result is kept unchanged). Panels (e–h) show a re-scaled visualization of the effects at the last token for each of the four heatmaps (a–d). The difference in the effect registered for the MLPs at layers 15–25 between panels (a) and (c) illustrates the role of these components in producing result-related information.
  • Figure 4: Indirect effect (IE) on three-operand queries for different MLP modules in Pythia 2.8B before and after fine-tuning. The effect produced by the last-token mid-late MLP activation site emerges with fine-tuning. Results for the attention modules are reported in the appendix (additional results).
  • Figure 5: Indirect effect measured on the MLPs of GPT-J for predictions on the number retrieval task.
  • ...and 14 more figures