Grokking Group Multiplication with Cosets

Dashiell Stander, Qinan Yu, Honglu Fan, Stella Biderman

TL;DR

This work tackles mechanistic interpretability by reverse-engineering a one-hidden-layer network trained to multiply permutations in the symmetric groups $S_5$ and $S_6$. It uncovers coset-based circuits that decompose group arithmetic using conjugate subgroups and validates the mechanism through ablations and targeted causal interventions. The authors critically compare their coset-circuit explanation with the Group Composition via Representations (GCR) hypothesis, arguing that coset-based decoding more faithfully explains the observed behavior while GCR alone cannot account for all findings. The study highlights the challenges of drawing mechanistic conclusions in neural networks and advocates rigorous causal testing to establish robust, interpretable mechanisms for structured tasks.

Abstract

The complex and unpredictable nature of deep neural networks prevents their safe use in many high-stakes applications. Many techniques have been developed to interpret deep neural networks, but all have substantial limitations. Algorithmic tasks have proven to be a fruitful test ground for interpreting a neural network end-to-end. Building on previous work, we completely reverse engineer fully connected one-hidden-layer networks that have "grokked" the arithmetic of the permutation groups $S_5$ and $S_6$. The models discover the true subgroup structure of the full group and converge on neural circuits that decompose the group arithmetic using the permutation group's subgroups. We relate how we reverse engineered the model's mechanisms and confirmed that our theory was a faithful description of the circuit's functionality. We also draw attention to current challenges in conducting interpretability research by comparing our work to Chughtai et al. [4], which claims to find a different algorithm for this same problem.

Paper Structure (47 sections, 16 theorems, 41 equations, 8 figures, 5 tables)

Key Result

Lemma 4.6

Two cosets $g_1H$ and $g_2H$ are either the same subset of $G$ or disjoint (i.e., $g_1H \cap g_2H = \emptyset$).
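
The lemma can be checked by brute force on a small group. The sketch below is illustrative only and is not taken from the paper: it assumes $G = S_4$ with permutations stored as tuples of images, and takes $H$ to be the stabilizer of the last point (a copy of $S_3$). The helper names `compose` and `left_coset` are hypothetical.

```python
from itertools import permutations

def compose(p, q):
    """Return the composition p∘q (apply q first, then p)."""
    return tuple(p[q[i]] for i in range(len(q)))

def left_coset(g, H):
    """Return the left coset gH as a frozenset of permutations."""
    return frozenset(compose(g, h) for h in H)

# Assumption for illustration: G = S_4, H = permutations fixing the point 3.
G = list(permutations(range(4)))
H = [p for p in G if p[3] == 3]

# One coset per group element; many of these coincide.
coset_of = {g: left_coset(g, H) for g in G}
distinct_cosets = set(coset_of.values())

# Lemma 4.6: any two cosets are either identical or disjoint.
same_or_disjoint = all(
    coset_of[g1] == coset_of[g2] or coset_of[g1].isdisjoint(coset_of[g2])
    for g1 in G
    for g2 in G
)

print(len(distinct_cosets), same_or_disjoint)
```

Under these assumptions the 24 elements of $S_4$ split into 4 distinct cosets of size 6, so the cosets partition the group.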

Figures (8)

  • Figure 1: Model Architecture: we follow the model architecture used by Chughtai et al. [4]. The one-hot vectors of left and right permutations pass through separate embeddings. We concatenate the embeddings and pass them through a single fully-connected hidden layer with ReLU activations. An unembedding matrix transforms the activations into logits.
  • Figure 2: A diagram showing the four possible paths through a single neuron (i.e. one row of $\mathbf{R}\mathbf{E}_{r}$) that implements part of a "sign circuit." The model stores whether a permutation is "even" or "odd" in the embeddings, represented in the left or right pre-activation values. The pre-activations are added together and then the ReLU activation is applied. The neuron only fires when the left permutation is even and the right is odd. If the neuron does not fire, then the product is odd in $1/3$ of the remaining cases and even in $2/3$ of them.
  • Figure 3: An illustration of the phenomenon of "concentration on cosets," depicting the 115th neuron from seed 11. We show the evolution over training, from 100k to 130k steps, of the left pre-activations (the pre-ReLU outputs of a layer) of an $F_{20}$ neuron. The seed of the neuron's functionality is already present at 100k steps, where it fires very strongly and negatively for permutations in the coset $F_{20}(1 \; 2 \; 3 \; 5 \; 4)$, but it takes time for the action of the neuron to "clean up" on the other cosets of $F_{20}$. The distribution found at 130k steps does not change very much afterwards. Noticing this common pattern of neurons taking on these discrete values was a striking piece of evidence that required further investigation.
  • Figure 4: The paired evolution of the validation loss and $\min_{H \in \mathop{\mathrm{Sub}}\nolimits(G)} C_H$, which encodes the formation of coset circuits. Displayed is the $S_5$ model with random seed 1. Different runs will form coset circuits at different times in training, but the effect is representative.
  • Figure 5: We perform ablations by re-calculating the accuracy after removing any neurons $N_i$ that have $\min_{H \in \mathop{\mathrm{Sub}}\nolimits(G)}C_{H}(N_i)$ greater than (top figure) or less than (bottom figure) the thresholds on the x-axis.
  • ...and 3 more figures
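
The $1/3$ versus $2/3$ split quoted in the Figure 2 caption follows from counting parity cases, and can be confirmed by enumeration. The sketch below is an idealization rather than the trained model: it assumes the neuron is a Boolean indicator that fires exactly when the left permutation is even and the right is odd, and it enumerates all pairs in $S_5$. The names `is_odd` and `neuron_fires` are hypothetical.

```python
from fractions import Fraction
from itertools import permutations

def compose(p, q):
    """Return the composition p∘q (apply q first, then p)."""
    return tuple(p[q[i]] for i in range(len(q)))

def is_odd(p):
    """A permutation is odd iff it has an odd number of inversions."""
    n = len(p)
    inversions = sum(1 for i in range(n) for j in range(i + 1, n) if p[i] > p[j])
    return inversions % 2 == 1

def neuron_fires(left, right):
    # Assumption: idealized sign-circuit neuron, not the learned weights.
    return (not is_odd(left)) and is_odd(right)

S5 = list(permutations(range(5)))
pairs = [(l, r) for l in S5 for r in S5]

firing = [(l, r) for l, r in pairs if neuron_fires(l, r)]
silent = [(l, r) for l, r in pairs if not neuron_fires(l, r)]

# Fraction of all input pairs on which the neuron fires.
frac_fires = Fraction(len(firing), len(pairs))
# Among pairs where the neuron is silent, how often is the product odd?
frac_odd_given_silent = Fraction(
    sum(1 for l, r in silent if is_odd(compose(l, r))), len(silent)
)

print(frac_fires, frac_odd_given_silent)
```

Under this idealization the neuron fires on $1/4$ of all pairs, always with an odd product. The silent cases are even-even, odd-even, and odd-odd in equal numbers, and only odd-even gives an odd product, which yields the $1/3$ figure.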

Theorems & Definitions (45)

  • Example 4.1
  • Example 4.2
  • Example 4.3
  • Definition 4.4
  • Definition 4.5
  • Lemma 4.6
  • Lemma 4.7
  • Definition 4.8
  • Lemma 4.9
  • Definition 4.10
  • ...and 35 more