Clustering Head: A Visual Case Study of the Training Dynamics in Transformers

Ambroise Odonnat, Wassim Bouaziz, Vivien Cabannes

TL;DR

The paper investigates how transformers learn a controlled sparse modular addition task by visualizing training dynamics in a small model with $d=2$ embeddings. It introduces a visual sandbox and the concept of clustering heads, which realize invariances that map clustered sequence embeddings to outputs, revealing a two-phase learning process: first structure the sequence embeddings, then fit a classifier. Loss spikes are analyzed as consequences of high curvature in normalization and feed-forward layers, with gradient norms providing a diagnostic link to training stability. The work combines empirical observations with theoretical insights, showing clustering heads can arise as gradient-descent fixed points and highlighting the roles of initialization and curriculum learning in enabling robust invariants. These findings offer guidance for understanding and stabilizing training in larger transformers and motivate future formal results around two-stage learning and circuit-transferability.

Abstract

This paper introduces the sparse modular addition task and examines how transformers learn it. We focus on transformers with embeddings in $\mathbb{R}^2$ and introduce a visual sandbox that provides comprehensive visualizations of each layer throughout the training process. We reveal a type of circuit, called "clustering heads," which learns the problem's invariants. We analyze the training dynamics of these circuits, highlighting two-stage learning, loss spikes due to high curvature or normalization layers, and the effects of initialization and curriculum learning.
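
To make the notion of "the problem's invariants" concrete, here is a minimal sketch of what a sparse modular addition example might look like. It assumes the label is the sum, modulo `p`, of the tokens at a small fixed set of positions; the names `make_example`, `label`, and `support`, and the values of `seq_len` and `p`, are illustrative choices and are not taken from the paper.

```python
import random

def label(tokens, p, support):
    """Sum of the tokens at the positions in `support`, modulo p (assumed task)."""
    return sum(tokens[i] for i in support) % p

def make_example(rng, seq_len, p, support):
    """Draw a random token sequence in {0, ..., p-1} and compute its label."""
    tokens = [rng.randrange(p) for _ in range(seq_len)]
    return tokens, label(tokens, p, support)

rng = random.Random(0)
seq_len, p, support = 12, 3, (0, 1, 2, 3, 4)  # hypothetical sizes
tokens, target = make_example(rng, seq_len, p, support)

# Invariance 1: tokens outside the support do not affect the label.
perturbed = list(tokens)
perturbed[-1] = (perturbed[-1] + 1) % p
same_after_perturbation = label(perturbed, p, support) == target

# Invariance 2: permuting the tokens inside the support does not affect the label.
permuted = list(tokens)
permuted[0], permuted[1] = permuted[1], permuted[0]
same_after_permutation = label(permuted, p, support) == target
```

Under this assumed definition, many distinct sequences share a label, which is the kind of structure a circuit that clusters sequence embeddings could exploit.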

Paper Structure

This paper contains 68 sections, 159 equations, 17 figures.

Figures (17)

  • Figure 1: Clustering Head. Implementation of an idealized circuit that captures the invariants of the problem.
  • Figure 2: Two-phase learning. Right: Loss profile featuring two significant drops in loss, marked by four red dashed lines at key snapshots. From left to right: (1) During the first snapshot, the sequence embeddings lack any clear structure. (2) They suddenly become clustered after the first loss drop, as seen in the second snapshot. (3) At this point, the MLP already classifies some clusters correctly (third snapshot). (4) A second loss drop occurs as the MLP gets fitted (last snapshot).
  • Figure 3: Connection to saddle points. From left to right: Evolution of train and test losses, the corresponding accuracies, the evolution of gradient norms for each layer in log scale, and the same evolution in linear scale in the full-batch setting. We see that the learning phases described in the training dynamics section appear in tandem with high gradient norms. This can be seen in the last subfigure, where the three peaks correspond to the loss drops and their corresponding plateaus.
  • Figure 4: High-curvature feed-forward. Loss spikes (both left) resulting from a small change from one iteration to another in sequence embeddings that are close to the decision boundaries of the subsequent feed-forward layer (both right).
  • Figure 5: High-curvature normalization. Loss spikes are linked to the high curvature of internal network functions. A small update to an element $x$ can result in a substantial change to its normalized version $x/\left\| x \right\|$, significantly altering the network's subsequent behavior.
  • ...and 12 more figures
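
The effect described in the Figure 5 caption can be checked numerically. The sketch below applies the same small update to two 2-D vectors, one with norm 1 and one close to the origin, and measures how far the normalized vector $x/\left\| x \right\|$ moves in each case. The helper names and the specific vectors are illustrative assumptions, not values from the paper.

```python
import math

def normalize(x):
    """Project a 2-D vector onto the unit circle: x / ||x||."""
    norm = math.hypot(x[0], x[1])
    return (x[0] / norm, x[1] / norm)

def normalized_shift(x, delta):
    """Euclidean distance between normalize(x) and normalize(x + delta)."""
    before = normalize(x)
    after = normalize((x[0] + delta[0], x[1] + delta[1]))
    return math.hypot(after[0] - before[0], after[1] - before[1])

delta = (0.0, 0.02)  # the same small update in both cases

far = normalized_shift((1.0, 0.0), delta)    # ||x|| = 1: output barely moves
near = normalized_shift((0.01, 0.0), delta)  # ||x|| = 0.01: output swings widely

print(f"shift when ||x|| = 1:    {far:.4f}")
print(f"shift when ||x|| = 0.01: {near:.4f}")
```

The closer $x$ sits to the origin, the more a fixed-size update rotates its normalized version, which is consistent with the caption's point that a small step can substantially alter what later layers receive.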

Theorems & Definitions (12)

  • Remark 3.1: Faithful to practice
  • Remark 3.2: Beyond stationary points
  • Remark 3.3: Other pathways
  • proof
  • proof
  • proof
  • proof
  • proof
  • proof
  • proof
  • ...and 2 more