Explaining Grokking in Transformers through the Lens of Inductive Bias

Jaisidh Singh, Diganta Misra, Antonio Orvieto

TL;DR

The paper investigates grokking in transformers through the lens of inductive bias, focusing on how architectural choices (notably Layer Normalization position) and optimization settings shape the rate and nature of generalization. It introduces a one-layer transformer trained on modular addition and analyzes how LN placement drives distinct biases, including shortcut learning, attention entropy, and the emergence of Fourier-like, periodic solutions. It further shows that optimization factors such as learning rate, weight decay, and readout scale interact with optimizer behavior (e.g., AdamW) in nuanced ways, sometimes confounding lazy-to-rich interpretations and revealing that feature evolution is continuous and compressibility-driven. Across LN configurations and optimization modulators, the results reveal a coherent link between inductive bias, feature compressibility, and generalization, suggesting that grokking in transformers is a nuanced phenomenon governed by continuous feature evolution under specific architectural and optimization biases.

Abstract

We investigate grokking in transformers through the lens of inductive bias: dispositions arising from architecture or optimization that let the network prefer one solution over another. We first show that architectural choices such as the position of Layer Normalization (LN) strongly modulate grokking speed. This modulation is explained by isolating how LN on specific pathways shapes shortcut-learning and attention entropy. Subsequently, we study how different optimization settings modulate grokking, inducing distinct interpretations of previously proposed controls such as readout scale. In particular, we find that using readout scale as a control for lazy training can be confounded by learning rate and weight decay in our setting. Accordingly, we show that features evolve continuously throughout training, suggesting grokking in transformers can be more nuanced than a lazy-to-rich transition of the learning regime. Finally, we show how generalization predictably emerges with feature compressibility in grokking, across different modulators of inductive bias. Our code is released at https://tinyurl.com/y52u3cad.
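
For readers unfamiliar with the modular addition task named in the TL;DR, the following is a minimal sketch of how such a dataset is commonly built: every pair (a, b) is labeled with (a + b) mod p, and a random subset is held out for testing. The modulus `P`, the split ratio `TRAIN_FRACTION`, and the function names are hypothetical choices for illustration; they are not taken from the paper or its released code.

```python
import random


def modular_addition_pairs(p):
    """Return every (a, b, (a + b) % p) triple for a, b in [0, p)."""
    return [(a, b, (a + b) % p) for a in range(p) for b in range(p)]


def train_test_split(examples, train_fraction, seed=0):
    """Shuffle deterministically, then split into train and held-out sets."""
    rng = random.Random(seed)
    shuffled = examples[:]
    rng.shuffle(shuffled)
    cut = int(train_fraction * len(shuffled))
    return shuffled[:cut], shuffled[cut:]


# Hypothetical values, chosen small for illustration only.
P = 7
TRAIN_FRACTION = 0.5

data = modular_addition_pairs(P)
train, test = train_test_split(data, TRAIN_FRACTION)
```

In grokking experiments the quantity of interest is how long the held-out loss lags behind the training loss, so the split has to be disjoint, as it is here.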

Paper Structure (42 sections, 7 equations, 9 figures, 2 tables)

This paper contains 42 sections, 7 equations, 9 figures, 2 tables.

Figures (9)

  • Figure 1: Different pathways express different biases in one-layer transformers. Applying layer normalization (LN) at specific positions changes the inductive bias of the network and consequently, grokking behavior.
  • Figure 2: (a) Loss curves for various LN positions, averaged across $5$ seeds. MLP Pre-LN and Pre exhibit slingshots, which we plot with lower opacity. (b) Evolution of $\mathcal{C}_{\text{Fourier}}$ averaged across $5$ seeds, along with the points where each configuration generalizes (test loss less than $0.01$). (c) LN positions that generalize earlier show faster emergence of periodic Fourier-based structure in $\widetilde{W}_E$.
  • Figure 3: (a) The No LN configuration is sensitive to the norm of MLP inputs in the train loss, shown by the train loss incurred when we perturb the MLP forward pass from $\operatorname{MLP}(\bullet)$ (dashed lines) to $\operatorname{MLP}(\bullet / \|\bullet\|)$ (dots). Notably, SR-train decreases when the network groks, showing the importance of feature scale for generalization. (b) LN on the inputs to the value channel of $\operatorname{MHSA}(\bullet)$ leads to fast grokking, whereas LN on query and key inputs actively hurts the network's ability to generalize. (c) LN on query and key inputs reduces the entropy of attention scores by removing embedding scale that acts as a degree of freedom for the network.
  • Figure 4: (a) Increasing learning rate $\eta$ while keeping weight decay constant leads to faster grokking. (b) Higher learning rates under fixed weight decay lead to significant differences in parameter update norm $\|\Delta \theta\|$ across training. (c) Similarly, increasing weight decay under constant learning rate leads to faster grokking. (d) Higher weight decay strength also leads to large variations in parameter updates.
  • Figure 5: (a) Effect of varying learning rate as $\eta=\eta_0/\alpha^2$ on loss: higher readout scale $\alpha$ leads to slower grokking. (b) Parameter updates are not comparable across $\alpha$ under $\eta=\eta_0/\alpha^2$, and in fact diverge significantly. (c) Fixing learning rate to $\eta_0$ now leads to slower grokking with lower values of readout scale $\alpha$. (d) Constant learning rate $\eta_0$ lets parameter updates remain comparable across $\alpha$.
  • ...and 4 more figures
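
The caption of Figure 3(c) describes embedding scale as a degree of freedom for attention entropy. The toy sketch below illustrates only that mechanism: rescaling the query and key vectors changes the entropy of a softmax attention row, while normalizing them first makes the entropy independent of that scale. Unit L2 normalization is used here as a simplified stand-in for LN (which also centers its input and applies a learned gain and bias), and all vectors and function names are hypothetical. It does not attempt to reproduce the direction or size of the effect reported in the paper.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def scaled(v, s):
    return [s * x for x in v]


def unit_norm(v):
    """Simplified stand-in for LN: rescale to unit L2 norm (assumption)."""
    n = math.sqrt(dot(v, v))
    return [x / n for x in v]


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(probs):
    """Shannon entropy in nats."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


def attention_row_entropy(query, keys):
    """Entropy of softmax(q . k / sqrt(d)) over the given keys."""
    d = len(query)
    logits = [dot(query, k) / math.sqrt(d) for k in keys]
    return entropy(softmax(logits))


# Hypothetical toy vectors.
query = [1.0, 0.5]
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.5]]
big_query = scaled(query, 3.0)
big_keys = [scaled(k, 3.0) for k in keys]

# Without normalization, entropy depends on embedding scale.
h_small = attention_row_entropy(query, keys)
h_large = attention_row_entropy(big_query, big_keys)

# With normalization, the scale no longer matters.
h_norm_small = attention_row_entropy(unit_norm(query), [unit_norm(k) for k in keys])
h_norm_large = attention_row_entropy(unit_norm(big_query), [unit_norm(k) for k in big_keys])
```

Scaling both query and keys by 3 multiplies every logit by 9, which sharpens the softmax and lowers `h_large` relative to `h_small`; after normalization the two entropies agree up to floating-point error.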