Clustering Head: A Visual Case Study of the Training Dynamics in Transformers
Ambroise Odonnat, Wassim Bouaziz, Vivien Cabannes
TL;DR
The paper investigates how transformers learn a controlled sparse modular addition task by visualizing training dynamics in a small model with $d=2$ embeddings. It introduces a visual sandbox and the concept of clustering heads, which realize invariances that map clustered sequence embeddings to outputs, revealing a two-phase learning process: first structure the sequence embeddings, then fit a classifier. Loss spikes are analyzed as consequences of high curvature in normalization and feed-forward layers, with gradient norms providing a diagnostic link to training stability. The work combines empirical observations with theoretical insights, showing clustering heads can arise as gradient-descent fixed points and highlighting the roles of initialization and curriculum learning in enabling robust invariants. These findings offer guidance for understanding and stabilizing training in larger transformers and motivate future formal results around two-stage learning and circuit-transferability.
Abstract
This paper introduces the sparse modular addition task and examines how transformers learn it. We focus on transformers with embeddings in $\mathbb{R}^2$ and introduce a visual sandbox that provides comprehensive visualizations of each layer throughout the training process. We reveal a type of circuit, called "clustering heads," which learns the problem's invariants. We analyze the training dynamics of these circuits, highlighting two-stage learning, loss spikes due to high curvature or normalization layers, and the effects of initialization and curriculum learning.
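
The abstract does not spell out the task or its invariants, so the following is only an illustrative sketch. It assumes that "sparse modular addition" means: given a sequence of tokens in $\{0, \dots, p-1\}$, the target is the sum modulo $p$ of the tokens at a small fixed set of positions. The names `sparse_mod_sum`, `support`, and `p`, as well as the concrete values, are hypothetical and not taken from the paper. Under this assumption, two invariances are easy to check: the target ignores tokens outside the support, and it is unchanged when tokens inside the support are permuted.

```python
import random


def sparse_mod_sum(tokens, support, p):
    """Sum, modulo p, of the tokens at the positions listed in `support`.

    Assumed task definition (not confirmed by the abstract): only the
    positions in `support` influence the target.
    """
    return sum(tokens[i] for i in support) % p


# Hypothetical problem sizes chosen for illustration only.
p = 5                 # vocabulary size / modulus
seq_len = 8           # sequence length
support = [0, 1, 2]   # positions that matter (assumed: the first three)

rng = random.Random(0)
tokens = [rng.randrange(p) for _ in range(seq_len)]
target = sparse_mod_sum(tokens, support, p)

# Invariance 1: resampling tokens outside the support leaves the target unchanged.
perturbed = list(tokens)
for i in range(seq_len):
    if i not in support:
        perturbed[i] = rng.randrange(p)
same_after_perturbation = sparse_mod_sum(perturbed, support, p) == target

# Invariance 2: permuting tokens inside the support leaves the target unchanged.
shuffled = list(tokens)
values = [tokens[i] for i in support]
rng.shuffle(values)
for i, v in zip(support, values):
    shuffled[i] = v
same_after_permutation = sparse_mod_sum(shuffled, support, p) == target

print(same_after_perturbation, same_after_permutation)
```

Sequences related by such transformations share a label, which is one plausible reading of the "invariants" that a clustering head would need to capture by mapping them to nearby sequence embeddings.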
