Understanding Addition in Transformers

Philip Quirke, Fazl Barez

TL;DR

This work probes how a deliberately small Transformer learns to perform $n$-digit addition, revealing that the model decomposes the problem into digit-specific, parallel sub-tasks and employs distinct algorithms for different digit positions. By combining a formal mathematical framework with training and prediction analyses and targeted ablations, the authors show how Base Add, Make Carry 1, and Make Sum 9 underlie the computation, with Use Carry 1 and Use Sum 9 propagating carries in a time-ordered fashion; a rare US9 cascade explains high-loss edge cases. The study demonstrates a robust, reusable digit-wise addition circuit that emerges across seeds and training regimes, contributing to mechanistic interpretability and providing a blueprint for analyzing more complex tasks in shallow and deeper transformers. The findings have implications for safety, reliability, and the deliberate construction of interpretable, modular computations within neural networks, and point toward extending the framework to subtraction, multiplication, and other symbolic domains.

Abstract

Understanding the inner workings of machine learning models like Transformers is vital for their safe and ethical use. This paper provides a comprehensive analysis of a one-layer Transformer model trained to perform n-digit integer addition. Our findings suggest that the model dissects the task into parallel streams dedicated to individual digits, employing varied algorithms tailored to different positions within the digits. Furthermore, we identify a rare scenario characterized by high loss, which we explain. By thoroughly elucidating the model's algorithm, we provide new insights into its functioning. These findings are validated through rigorous testing and mathematical modeling, thereby contributing to the broader fields of model understanding and interpretability. Our approach opens the door for analyzing more complex tasks and multi-layer Transformer models.
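
As a rough illustration of the digit-wise decomposition described above, the following Python sketch adds two n-digit integers using only the five named sub-tasks. The sub-task definitions are assumptions inferred from their names (Base Add as the digit sum modulo 10, Make Carry 1 as a digit sum of at least 10, Make Sum 9 as a digit sum of exactly 9), and the function names are hypothetical. It mirrors the arithmetic only, not the trained model's weights or attention heads.

```python
def digits_lsd_first(x, n):
    """Return the n digits of x; index i is digit position i (index 0 = units)."""
    return [(x // 10**i) % 10 for i in range(n)]


def add_by_subtasks(a, b, n=5):
    """Add two n-digit integers via BA, MC1, MS9, UC1 and US9 (assumed definitions).

    Returns (sum, longest_us9_cascade).
    """
    d = digits_lsd_first(a, n)
    dp = digits_lsd_first(b, n)

    # Per-digit sub-tasks: each depends on a single digit pair only,
    # so they can be computed independently (in parallel).
    ba = [(x + y) % 10 for x, y in zip(d, dp)]   # Base Add
    mc1 = [x + y >= 10 for x, y in zip(d, dp)]   # Make Carry 1
    ms9 = [x + y == 9 for x, y in zip(d, dp)]    # Make Sum 9

    answer = []
    carry_in = False
    cascade = longest = 0
    for i in range(n):
        # Use Carry 1: add the incoming carry to the Base Add result.
        answer.append((ba[i] + int(carry_in)) % 10)
        if mc1[i]:
            carry_out, cascade = True, 0
        elif ms9[i] and carry_in:
            # Use Sum 9: a digit pair summing to 9 passes the carry along.
            carry_out, cascade = True, cascade + 1
        else:
            carry_out, cascade = False, 0
        longest = max(longest, cascade)
        carry_in = carry_out

    answer.append(int(carry_in))  # the extra leading answer digit is 0 or 1
    total = sum(digit * 10**i for i, digit in enumerate(answer))
    return total, longest


if __name__ == "__main__":
    print(add_by_subtasks(54321, 77779))
    print(add_by_subtasks(99999, 1))
```

In this sketch the carry has to pass through each digit in turn, so a long run of digit pairs that sum to 9 is the only case where one answer digit depends on many lower digits. That matches the rare cascade scenario the summary associates with high loss.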

Paper Structure (21 sections, 26 figures, 6 tables, 1 algorithm)

This paper contains 21 sections, 26 figures, 6 tables, 1 algorithm.

Figures (26)

  • Figure 1: Illustration of the transformer model's attention pattern when adding two 5-digit integers. The model attends to digit pairs sequentially from left to right, resulting in a “double staircase” pattern across rows. A: The 5-digit question is revealed token by token. The “10s of thousands” digit is revealed first. B: From the “=” token, the model attention heads focus on successive pairs of digits, giving a “double staircase” attention pattern. C: The 3 heads are time-offset from each other by 1 token such that, in each row, data from 3 tokens is available. D: To calculate A3, the 3 heads do independent simple calculations on D3, D2 and D1. The results are combined by the MLP layer using trigrams. A3 is calculated one token before it is needed. This approach applies to all answer digits, with the first and last digits using slight variations of the approach.
  • Figure 2: For 5-digit integer addition, these per-digit training loss curves show that the model trains each answer digit semi-independently. The first answer digit, A5, which is always 1 or 0, is learnt much more quickly than the other digits.
  • Figure 3: We refer to individual tokens in a 5-digit addition question as D4, ..., D0 and D'4, ..., D'0, and to the answer tokens as A5, ..., A0.
  • Figure 4: The attention pattern for a model with 3 attention heads performing a single 5-digit addition. The pattern is 18 by 18 squares (as 54321+77779=132100 is 18 tokens). Time proceeds vertically downwards, with one additional token being revealed horizontally at each row, giving the overall triangle shape. After the question is fully revealed (at row 11), each head starts attending to pairs of question digits from left to right (i.e. high-value digits before lower-value digits), giving the “double staircase” shape. The three heads attend to a given digit pair in three different rows, giving a time ordering of heads.
  • Figure 5: The mathematical framework (our method) predicts that during training, tasks are learnt for each digit independently, progressively increasing per-digit accuracy (i.e. decreasing loss), shown as percentages. Mathematical rules cause dependencies between digits, giving a predicted ordering for perfect (i.e. zero loss) addition. The chain of blue squares relates to questions like 99999 + 00001 = 100000, where the MC1 in digit 0 causes US9 cascades through multiple other digits.
  • ...and 21 more figures
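
The token naming from Figure 3 and the 18-token count from Figure 4 can be checked with a small Python sketch. It assumes that each token is a single character and that operands and answer are zero-padded to fixed widths, as the example 99999 + 00001 = 100000 suggests. The function name `label_tokens` is hypothetical.

```python
def label_tokens(a, b, n=5):
    """Pair each token of 'a+b=answer' with its D / D' / A position label.

    Assumes one character per token, operands zero-padded to n digits and
    the answer zero-padded to n + 1 digits.
    """
    text = f"{a:0{n}d}+{b:0{n}d}={a + b:0{n + 1}d}"
    tokens = list(text)

    # Highest-value digit first, matching the left-to-right reveal order.
    d_labels = [f"D{i}" for i in range(n - 1, -1, -1)]
    dp_labels = [f"D'{i}" for i in range(n - 1, -1, -1)]
    a_labels = [f"A{i}" for i in range(n, -1, -1)]
    labels = d_labels + ["+"] + dp_labels + ["="] + a_labels

    return list(zip(labels, tokens))


if __name__ == "__main__":
    pairs = label_tokens(54321, 77779)
    print(len(pairs))  # number of tokens in the sequence
    for label, token in pairs:
        print(f"{label:>3}  {token}")
```

Under these assumptions, the question occupies the first 11 tokens, which is consistent with the caption's remark that the question is fully revealed at row 11.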