
Transformers Can Do Arithmetic with the Right Embeddings

Sean McLeish, Arpit Bansal, Alex Stein, Neel Jain, John Kirchenbauer, Brian R. Bartoldson, Bhavya Kailkhura, Abhinav Bhatele, Jonas Geiping, Avi Schwarzschild, Tom Goldstein

TL;DR

The paper tackles the challenge of arithmetic reasoning in transformers by identifying digit-position tracking as a critical bottleneck and proposing Abacus Embeddings, which encode each digit's significance within its number. When combined with input injection and looped transformer architectures, these embeddings enable strong length generalization, achieving up to 99% accuracy on 100-digit addition and extending to 120-digit problems, as well as improving performance on multiplication and sorting. The authors also demonstrate the benefits of recurrence, showing that looped transformers trained with progressive loss can outperform standard architectures in both in-distribution (ID) and out-of-distribution (OOD) evaluations. These results suggest a path toward more capable algorithmic reasoning in transformers without external tools, with potential applicability to a range of numerically intensive tasks and broader relational reasoning challenges.

Abstract

The poor performance of transformers on arithmetic tasks seems to stem in large part from their inability to keep track of the exact position of each digit within a long span of digits. We mend this problem by adding an embedding to each digit that encodes its position relative to the start of the number. In addition to the boost these embeddings provide on their own, we show that this fix enables architectural modifications such as input injection and recurrent layers to improve performance even further. With positions resolved, we can study the logical extrapolation ability of transformers. Can they solve arithmetic problems that are larger and more complex than those in their training data? We find that, by training on only 20-digit numbers with a single GPU for one day, we can reach state-of-the-art performance, achieving up to 99% accuracy on 100-digit addition problems. Finally, we show that these gains in numeracy also unlock improvements on other multi-step reasoning tasks, including sorting and multiplication.
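The core idea of giving each digit an embedding indexed by its position relative to the start of its number can be sketched in a few lines. This is a minimal illustration, not the authors' implementation. It assumes single-character digit tokens and operands written least-significant digit first, so that "position from the start of the number" coincides with digit significance (the property Figure 2 describes). The helper names `abacus_position_ids` and `add_position_embeddings` are hypothetical, and the embedding table here is a plain list rather than a learned layer.

```python
from typing import List, Sequence

DIGITS = set("0123456789")


def abacus_position_ids(tokens: Sequence[str]) -> List[int]:
    """Give every digit its 1-based index within the number it belongs to.

    Non-digit tokens (operators, '=', padding) get position 0, and the
    counter restarts at the first digit of each new number.
    """
    ids: List[int] = []
    pos = 0
    for tok in tokens:
        if tok in DIGITS:
            pos += 1          # next digit of the current number
            ids.append(pos)
        else:
            pos = 0           # a non-digit token ends the current number
            ids.append(0)
    return ids


def add_position_embeddings(
    token_vecs: List[List[float]],
    position_ids: List[int],
    table: List[List[float]],
) -> List[List[float]]:
    """Add row table[pid] elementwise to each token vector (hypothetical helper)."""
    return [
        [t + p for t, p in zip(vec, table[pid])]
        for vec, pid in zip(token_vecs, position_ids)
    ]


# Assumed format: 123 + 4567 written with reversed operands as "321+7654=".
tokens = list("321+7654=")
print(abacus_position_ids(tokens))  # [1, 2, 3, 0, 1, 2, 3, 4, 0]
```

Both ones digits ("3" and "7") receive index 1 and both tens digits receive index 2, so digits of the same significance share an embedding regardless of operand length.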

Paper Structure

This paper contains 40 sections, 3 equations, 22 figures, and 4 tables.

Figures (22)

  • Figure 1: Zero-shot exact match accuracy on addition using depth-sixteen, decoder-only transformer models trained on operands of up to 20 digits. Compared to state-of-the-art embeddings (left), our new Abacus Embeddings (right) dramatically improve generalization to unseen digit lengths. The interior of the red square denotes the training distribution. Accuracies are averaged over three trials.
  • Figure 2: Visualization of data formats and positional embeddings. Abacus Embeddings give the same positional embeddings to all digits of the same significance.
  • Figure 3: Left: Mean exact match accuracy of three models of depth sixteen on size $20$ data, varying the architecture and embeddings. Abacus Embeddings improve accuracy for addition over FIRE and NoPE Embeddings. Right: Mean exact match accuracy of three models of effective depth sixteen on size $40$ data, varying over NoPE or FIRE embeddings and architectures. Recurrent looped transformer models improve accuracy for addition with both FIRE and NoPE embeddings. Looped transformer (LT): weight-tied decoder layers, with input injection and progressive loss. Standard Transformer (ST): stacked decoder-only layers. Standard Transformer with Input Injection (ST w/ II): a Standard Transformer with the input features added to the hidden representation between each pair of decoder layers.
  • Figure 4: Varying the size of the recurrent block while maintaining an effective depth of $16$ and training on size $20$ data. A recurrent model with eight layers in the recurrent block and two recurrences is the most accurate of all effective-depth-$16$ models, halving the error rate of a standard model with input injection in the OOD evaluation. (Results with FIRE and NoPE embeddings are shown in a corresponding appendix figure.)
  • Figure 5: Models with 8 layers in the recurrent block and 2 recurrences, trained on size $20$ addition and subtraction data; each line is the average of 3 models. Extreme length generalization remains possible while learning multiple tasks.
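The architectures compared in Figures 3 and 4 differ mainly in weight tying and input injection. The following is a minimal sketch of a looped forward pass under stated assumptions: layers are arbitrary vector-to-vector functions standing in for decoder layers, and the embedded input is added to the hidden state before every layer, which is the rule the Figure 3 caption gives for ST w/ II (exactly where the looped model injects is assumed here, not taken from the paper). Progressive loss is a training-time procedure and is not modeled. The names `looped_forward` and `make_layer` are hypothetical.

```python
from typing import Callable, List

Vector = List[float]
Layer = Callable[[Vector], Vector]


def looped_forward(x: Vector, block: List[Layer], recurrences: int) -> Vector:
    """Run one weight-tied block of layers `recurrences` times.

    Before every layer, the embedded input `x` is added to the hidden state
    (input injection). Effective depth = len(block) * recurrences.
    """
    h = [0.0] * len(x)
    for _ in range(recurrences):
        for layer in block:  # the same layer objects are reused on every pass
            h = [hi + xi for hi, xi in zip(h, x)]  # input injection
            h = layer(h)
    return h


calls: List[str] = []


def make_layer(name: str, scale: float) -> Layer:
    """Toy stand-in for a decoder layer that records each time it is applied."""
    def layer(v: Vector) -> Vector:
        calls.append(name)
        return [scale * a for a in v]
    return layer


# 8 distinct layers looped twice: effective depth 16, as in Figure 4's best model.
block = [make_layer(f"layer{i}", 0.5) for i in range(8)]
looped_forward([1.0, 2.0], block, recurrences=2)
print(len(calls), len(set(calls)))  # 16 8
```

Passing 16 distinct layers with `recurrences=1` gives the ST w/ II baseline at the same effective depth; dropping the injection line gives a plain ST. The looped model reaches that depth with half the parameters because the block is reused.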