
Learning Modular Exponentiation with Transformers

David Demitri Africa, Sara M. Kapoor, Theo Simon Sorg, Challenger Mishra

TL;DR

The paper investigates how transformers learn modular exponentiation, treating interpretability as a core goal. By training a 4-layer encoder-decoder Transformer on $a^b \equiv d \pmod c$ and employing reciprocal operand sampling along with base-$B$ digit representations, the authors show robust generalization and grokking-like surges for related moduli. A key finding is that a minimal circuit composed of final-layer attention heads suffices for regular exponentiation, suggesting specialized high-level computation rather than distributed symbolic processing. These results advance mechanistic interpretability in neural arithmetic by demonstrating both concrete learning dynamics and identifiable circuits. However, the experiments are confined to a synthetic, small-scale setting, and the authors point to scaling toward cryptographic-strength inputs as a direction for future work.
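To make the task and the base-$B$ digit representation concrete, here is a minimal sketch of how one training example could be built. The sequence layout (each operand as a separate list of base-$B$ digit tokens, most significant digit first) and the helper names `to_base` and `make_example` are hypothetical assumptions for illustration; the paper's exact tokenization may differ. Base 1000 is used because it is one of the bases compared in Figure 1.

```python
def to_base(n: int, base: int) -> list[int]:
    """Return the base-`base` digits of a non-negative integer, most significant first."""
    if n < 0:
        raise ValueError("n must be non-negative")
    if n == 0:
        return [0]
    digits = []
    while n > 0:
        n, r = divmod(n, base)
        digits.append(r)
    return digits[::-1]


def make_example(a: int, b: int, c: int, base: int = 1000):
    """Build one (source, target) pair for a^b ≡ d (mod c).

    Hypothetical layout: each operand becomes a list of base-B digit tokens,
    and the target d = a^b mod c is encoded the same way.
    """
    d = pow(a, b, c)  # three-argument pow performs fast modular exponentiation
    source = [to_base(a, base), to_base(b, base), to_base(c, base)]
    target = to_base(d, base)
    return source, target


src, tgt = make_example(12345, 67, 999983, base=1000)
print(src)  # [[12, 345], [67], [999, 983]]
```

With a large base such as 1000, each operand spans only a few tokens, which keeps sequences short; the choice of $B$ also changes which digit-level patterns the model can exploit, consistent with the base sensitivity reported in Figure 1.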

Abstract

Modular exponentiation is crucial to number theory and cryptography, yet remains largely unexplored from a mechanistic interpretability standpoint. We train a 4-layer encoder-decoder Transformer model to perform this operation and investigate the emergence of numerical reasoning during training. Utilizing principled sampling strategies, PCA-based embedding analysis, and activation patching, we examine how number-theoretic properties are encoded within the model. We find that reciprocal operand training leads to strong performance gains, with sudden generalization across related moduli. These synchronized accuracy surges reflect grokking-like dynamics, suggesting the model internalizes shared arithmetic structure. We also find a subgraph consisting entirely of attention heads in the final layer sufficient to achieve full performance on the task of regular exponentiation. These results suggest that transformer models learn modular arithmetic through specialized computational circuits, paving the way for more interpretable and efficient neural approaches to modular exponentiation.
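Figure 1 describes reciprocal sampling as a log-uniform distribution over operands. A minimal sketch of such a sampler follows: drawing $\log k$ uniformly gives $P(k) \propto \log(1 + 1/k) \approx 1/k$, so small operands appear far more often than under uniform sampling while the full range stays covered. The operand ranges, the lower bound of 2 on the modulus, and the helper names `sample_reciprocal` and `sample_batch` are assumptions, not details confirmed by the paper.

```python
import math
import random


def sample_reciprocal(low: int, high: int, rng: random.Random) -> int:
    """Draw an integer in [low, high] (low >= 1) with P(k) roughly proportional to 1/k.

    log(k) is sampled uniformly, i.e. k follows a log-uniform distribution.
    """
    u = rng.uniform(math.log(low), math.log(high + 1))
    k = int(math.exp(u))  # floor of a log-uniform real in [low, high + 1)
    return min(max(k, low), high)  # clamp against floating-point edge cases


def sample_batch(n: int, max_operand: int, seed: int = 0):
    """Hypothetical batch of (a, b, c, a^b mod c) tuples with reciprocal-sampled operands."""
    rng = random.Random(seed)
    batch = []
    for _ in range(n):
        a = sample_reciprocal(1, max_operand, rng)
        b = sample_reciprocal(1, max_operand, rng)
        c = sample_reciprocal(2, max_operand, rng)  # assume modulus >= 2
        batch.append((a, b, c, pow(a, b, c)))
    return batch


batch = sample_batch(5, max_operand=10**6)
```

Under this sampler, roughly $\log 11 / \log(10^6 + 1) \approx 17\%$ of draws from $[1, 10^6]$ fall in $[1, 10]$, compared with $10^{-5}$ under uniform sampling.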

Paper Structure

This paper contains 21 sections, 6 equations, 5 figures, 1 table.

Figures (5)

  • Figure 1: Left: Validation and test accuracy over 3000 epochs for the reciprocal operands model. Reciprocal sampling (log-uniform distribution of operands) enables effective learning of modular exponentiation, with validation accuracy reaching $\sim$84% and test accuracy $\sim$80%. Middle: Test accuracy comparison across four numerical bases over 1000 epochs. Composite bases (999, 1000) substantially outperform prime bases (1013, 1279), with bases 999 and 1000 reaching $\sim$60% accuracy compared to $\sim$49% for prime bases. Right: Validation accuracy shows the same pattern, confirming that the composite base advantage generalizes across both evaluation sets. Base choice significantly impacts learning dynamics and final performance.
  • Figure 2: Synchronized grokking for multiples of 23 during epochs 1725--1750 (highlighted region). Moduli 23, 46 ($2\times23$), and 69 ($3\times23$) exhibit simultaneous accuracy jumps from $\sim$20% to near-perfect performance, demonstrating that the transformer discovers and exploits mathematical relationships between related moduli. Control moduli 47 and 83 (not multiples of 23) exhibit a different learning pattern with gradual improvement, while overall accuracy remains high ($\sim$83%) throughout training.
  • Figure 3: Left: KL divergence heatmap showing the causal importance of each attention head across the 4-layer decoder. Warmer colors indicate higher KL divergence between clean and patched activations, reflecting greater causal impact on model predictions. Final-layer heads (layer 3, zero-indexed) exhibit substantially higher KL divergence ($\sim$3--4.5) than earlier layers ($\sim$0.1--0.4), indicating that the circuit for regular exponentiation (when $c > a^b$, so that no modular reduction occurs) is concentrated in the final decoder layer. Right: Accuracy comparison between the full model and the minimal circuit consisting only of final-layer attention heads. Both achieve 70% accuracy on regular exponentiation tasks, demonstrating that earlier layers contribute negligibly to this computation and confirming functional specialization in the network architecture.
  • Figure 4: PCA 3D projections of token embeddings, colored by number-theoretic properties, before and after grokking. Bottom row shows multiples in 2D.
  • Figure 5: Test plots from training runs for 18 different moduli.
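The per-head scores in Figure 3 come from activation patching. The sketch below illustrates the idea on a toy stand-in rather than the paper's model: it assumes output logits are the element-wise sum of per-head contributions keyed by `(layer, head)`, patches one head at a time with its contribution from a corrupted run, and scores the head by $\mathrm{KL}(p_{\text{clean}} \,\|\, p_{\text{patched}})$ over the output distribution. The additive model, the KL direction, and all names (`head_importance`, `clean`, `corrupt`) are hypothetical assumptions.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) for two discrete distributions over the same vocabulary."""
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def head_importance(clean_heads, corrupt_heads):
    """Score each head by patching it with its corrupted-run activation.

    Toy assumption: output logits = element-wise sum of per-head contributions.
    """
    def logits_from(heads):
        return [sum(col) for col in zip(*heads.values())]

    p_clean = softmax(logits_from(clean_heads))
    scores = {}
    for key in clean_heads:
        patched = dict(clean_heads)
        patched[key] = corrupt_heads[key]  # swap in one head's corrupted activation
        scores[key] = kl_divergence(p_clean, softmax(logits_from(patched)))
    return scores


# Toy example: the layer-3 head dominates the output, so patching it shifts the prediction most.
clean = {(0, 0): [0.1, 0.0, 0.0], (3, 0): [5.0, 0.0, 0.0]}
corrupt = {(0, 0): [0.0, 0.1, 0.0], (3, 0): [0.0, 5.0, 0.0]}
for key, score in head_importance(clean, corrupt).items():
    print(key, round(score, 4))
```

In a real transformer, patching is done by caching activations from a corrupted input and overwriting one head's output during the clean forward pass; the additive toy only mirrors the scoring logic.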