Why Can't Transformers Learn Multiplication? Reverse-Engineering Reveals Long-Range Dependency Pitfalls

Xiaoyan Bai, Itamar Pres, Yuntian Deng, Chenhao Tan, Stuart Shieber, Fernanda Viégas, Martin Wattenberg, Andrew Lee

TL;DR

The paper tackles why Transformers struggle with multi-digit multiplication, showing that a model trained with implicit chain-of-thought (ICoT) learns essential long-range dependencies that standard fine-tuning (SFT) lacks. By reverse-engineering ICoT, the authors demonstrate that attention organizes into a sparse, binary-tree-like graph that caches and retrieves pairwise partial products, while digits are represented with Fourier bases, yielding a pentagonal-prism geometry unseen in SFT. They formalize the key intermediate $\hat{c}_k = s_k + r_{k-1}$ with $s_k = \sum_{i+j=k} a_i b_j$ and $c_k = \hat{c}_k \bmod 10$, and show that a simple auxiliary loss predicting the running sum $\hat{c}_k$ provides the inductive bias needed for SFT to succeed on 4$\times$4-digit multiplication. This work highlights a fundamental pitfall in gradient-descent learning for long-range tasks and suggests that task-specific inductive biases can unlock robust long-range reasoning in Transformer models, informing future approaches to arithmetic and other long-horizon commands.
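
The intermediate quantities above can be traced by hand for a concrete product. The following is a minimal sketch, not the authors' code: it assumes digits are stored least-significant first, that the initial carry $r_{-1}$ is 0, and that the carry is $r_k = \lfloor \hat{c}_k / 10 \rfloor$. The function name `running_sums` is hypothetical.

```python
def running_sums(a, b):
    """Return (c_hat, c, r) for digit lists a and b, least-significant digit first.

    c_hat[k] = s_k + r_{k-1}, with s_k = sum of a_i * b_j over i + j = k
    c[k]     = c_hat[k] mod 10   (output digit)
    r[k]     = c_hat[k] // 10    (carry into position k + 1; assumed definition)
    """
    n_out = len(a) + len(b)
    c_hat, c, r = [], [], []
    carry = 0  # r_{-1} = 0 (assumption)
    for k in range(n_out):
        # Sum of all pairwise partial products that land on output position k.
        s_k = sum(a[i] * b[k - i] for i in range(len(a)) if 0 <= k - i < len(b))
        c_hat_k = s_k + carry
        c_hat.append(c_hat_k)
        c.append(c_hat_k % 10)
        carry = c_hat_k // 10
        r.append(carry)
    return c_hat, c, r


# 1234 x 5678, written least-significant digit first.
c_hat, c, r = running_sums([4, 3, 2, 1], [8, 7, 6, 5])
print(c_hat)  # [32, 55, 66, 66, 40, 20, 7, 0]
print(c)      # [2, 5, 6, 6, 0, 0, 7, 0]
```

Reading `c` in reverse gives 07006652, i.e. 1234 × 5678 = 7006652. Each $\hat{c}_k$ depends on every operand digit at positions up to $k$, which is the long-range dependency the paper refers to.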

Abstract

Language models are increasingly capable, yet still fail at a seemingly simple task of multi-digit multiplication. In this work, we study why, by reverse-engineering a model that successfully learns multiplication via implicit chain-of-thought, and report three findings: (1) Evidence of long-range structure: Logit attributions and linear probes indicate that the model encodes the necessary long-range dependencies for multi-digit multiplication. (2) Mechanism: the model encodes long-range dependencies using attention to construct a directed acyclic graph to "cache" and "retrieve" pairwise partial products. (3) Geometry: the model implements partial products in attention heads by forming Minkowski sums between pairs of digits, and digits are represented using a Fourier basis, both of which are intuitive and efficient representations that the standard fine-tuning model lacks. With these insights, we revisit the learning dynamics of standard fine-tuning and find that the model converges to a local optimum that lacks the required long-range dependencies. We further validate this understanding by introducing an auxiliary loss that predicts the "running sum" via a linear regression probe, which provides an inductive bias that enables the model to successfully learn multi-digit multiplication. In summary, by reverse-engineering the mechanisms of an implicit chain-of-thought model we uncover a pitfall for learning long-range dependencies in Transformers and provide an example of how the correct inductive bias can address this issue.
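
To make the Fourier-basis claim concrete, the sketch below shows how a small set of Fourier features can place the ten digits on the vertices of a pentagonal prism. The choice of features here (a period-5 cosine/sine pair plus a period-2 parity term) is an assumption made for illustration; this summary does not state which frequencies the model uses, and `fourier_features` is a hypothetical name.

```python
import math


def fourier_features(n):
    """Map a digit n in 0..9 to three illustrative Fourier features.

    (cos(2*pi*n/5), sin(2*pi*n/5)) places n on one of five pentagon vertices;
    (-1)^n selects one of two parallel pentagons (even vs. odd digits).
    The frequency choice is an assumption, not taken from the paper summary.
    """
    angle = 2 * math.pi * n / 5
    parity = 1.0 if n % 2 == 0 else -1.0
    return (math.cos(angle), math.sin(angle), parity)


for n in range(10):
    x, y, z = fourier_features(n)
    print(f"digit {n}: ({x:+.3f}, {y:+.3f}, {z:+.0f})")
```

Under these assumed features, digits $n$ and $n+5$ share the same pentagon vertex but sit on opposite faces, so all ten digits occupy distinct vertices of the prism.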

Paper Structure

This paper contains 23 sections, 17 equations, 10 figures, 1 table.

Figures (10)

  • Figure 1: Multiplication has long-range dependencies, which can be captured by an intermediate value $\hat{c}_i$, from which both the solution ($c_i$) and carries ($r_i$) can be derived.
  • Figure 2: Logit Attribution. We test whether each model has correctly learned long-range dependencies by measuring how sensitive the logits of output digits $c_i$ are to each operand digit (i.e., $a_i, b_j$). This is done by measuring the change in $c_i$'s logits when a single operand digit is perturbed.
  • Figure 3: Linear regression probing results for $\hat{c}$. We probe from the middle of the last Transformer block, after attention heads but before MLPs.
  • Figure 4: Visualization of attention tree to compute $\textbf{c}_2$. Left: Attention maps for selected heads show the first layer "caching" pairwise products ($a_ib_j$) across earlier timesteps, from which the second layer reads (not all tree paths are shown). Right: A visualization of the attention tree. Each arc indicates tokens being attended to at specific timesteps. Colored arcs above and below the digits indicate attention patterns from the first and second layers respectively. Example: the orange arc indicates that at timestep $b_3$, the model attends to $a_0$ and $b_1$, from which the second layer reads.
  • Figure 5: 3D PCA of attention head outputs can form Minkowski sums, which in turn can form nested representations. Each color represents a different digit.
  • ...and 5 more figures
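
The perturbation test described for Figure 2 can be illustrated without a trained model. In the sketch below, exact integer multiplication stands in for the model, and "did the output digit change?" stands in for the change in logits; both substitutions are assumptions made for illustration, and `sensitivity` is a hypothetical helper rather than the authors' procedure. It estimates how often perturbing a single digit $a_i$ changes each output digit $c_k$.

```python
import random


def to_int(digits):
    """Digits are least-significant first."""
    return sum(d * 10**i for i, d in enumerate(digits))


def to_digits(value, n):
    return [(value // 10**i) % 10 for i in range(n)]


def sensitivity(num_samples=500, n=4, seed=0):
    """Return sens[i][k]: fraction of samples where perturbing a_i changed c_k."""
    rng = random.Random(seed)
    counts = [[0] * (2 * n) for _ in range(n)]
    for _ in range(num_samples):
        a = [rng.randrange(10) for _ in range(n)]
        b = [rng.randrange(10) for _ in range(n)]
        base = to_digits(to_int(a) * to_int(b), 2 * n)
        for i in range(n):
            perturbed = list(a)
            # Replace a_i with a different digit.
            perturbed[i] = (a[i] + rng.randrange(1, 10)) % 10
            out = to_digits(to_int(perturbed) * to_int(b), 2 * n)
            for k in range(2 * n):
                counts[i][k] += base[k] != out[k]
    return [[c / num_samples for c in row] for row in counts]


for i, row in enumerate(sensitivity()):
    print(f"a_{i}:", " ".join(f"{v:.2f}" for v in row))
```

In this stand-in, $c_k$ is never affected by $a_i$ with $i > k$, while low-order operand digits influence many higher output digits. A model that has learned the task would be expected to show a comparable dependency pattern in its logits.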