Hidden Breakthroughs in Language Model Training

Sara Kangaslahti, Elan Rosenfeld, Naomi Saphra

TL;DR

This paper introduces POLCA, a method for decomposing changes in loss along arbitrary bases of the low-rank training subspace. The method is used to identify clusters of samples that share similar changes in loss during training, disaggregating the overall loss into the losses of smaller groups of conceptually similar data.

Abstract

Loss curves are smooth during most of model training, so visible discontinuities stand out as possible conceptual breakthroughs. Studying these breakthroughs enables a deeper understanding of learning dynamics, but only when they are properly identified. This paper argues that similar breakthroughs occur frequently throughout training but they are obscured by a loss metric that collapses all variation into a single scalar. To find these hidden transitions, we introduce POLCA, a method for decomposing changes in loss along arbitrary bases of the low-rank training subspace. We use our method to identify clusters of samples that share similar changes in loss during training, disaggregating the overall loss into that of smaller groups of conceptually similar data. We validate our method on synthetic arithmetic and natural language tasks, showing that POLCA recovers clusters that represent interpretable breakthroughs in the model's capabilities. We demonstrate the promise of these hidden phase transitions as a tool for unsupervised interpretability.
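
To make the idea of "decomposing changes in loss along a basis" concrete, the following is a minimal sketch on a toy quadratic loss. It is not the paper's POLCA estimator: the quadratic loss, the full-rank random orthonormal basis, the midpoint-gradient rule, and the helper name `decompose_loss_change` are all assumptions chosen for illustration. For a quadratic, the midpoint gradient dotted with the parameter update equals the loss change exactly, so the per-direction terms sum to the total change.

```python
import numpy as np


def decompose_loss_change(grad, delta_theta, basis):
    """Split grad . delta_theta into one term per orthonormal basis column.

    Hypothetical helper for illustration; not taken from the paper.
    """
    return (basis.T @ grad) * (basis.T @ delta_theta)


rng = np.random.default_rng(0)
dim = 4

# Toy quadratic loss L(theta) = 0.5 * theta^T A theta with symmetric PSD A (assumption).
M = rng.normal(size=(dim, dim))
A = M @ M.T


def loss(theta):
    return 0.5 * theta @ A @ theta


# Random orthonormal basis (columns); the paper instead uses bases of a
# low-rank training subspace.
basis, _ = np.linalg.qr(rng.normal(size=(dim, dim)))

# One gradient-descent step on the toy loss.
theta_old = rng.normal(size=dim)
theta_new = theta_old - 0.05 * (A @ theta_old)

# Gradient at the midpoint of the step; exact for a quadratic loss.
midpoint_grad = A @ (0.5 * (theta_old + theta_new))

contributions = decompose_loss_change(midpoint_grad, theta_new - theta_old, basis)
total_change = loss(theta_new) - loss(theta_old)

print("per-direction contributions:", contributions)
print("sum of contributions:", contributions.sum())
print("actual loss change:", total_change)
```

Individual contributions can differ in size and sign even when their sum changes smoothly, which is the kind of per-direction variation that a single scalar loss hides.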

Paper Structure

This paper contains 41 sections, 10 equations, 11 figures, 12 tables, 1 algorithm.

Figures (11)

  • Figure 1: A smooth loss function may change sharply for a particular direction or data subset. POLCA works by decomposing and disaggregating the loss to discover these sharp changes. Left: Loss $L(x; \theta)$ changes as the parameter setting $\theta$ moves in a low-rank training subspace. The loss is sigmoidal on each axis, with differently timed inflections along basis vectors $b_1$ (lavender) and $b_2$ (yellow-orange). These breakthroughs disappear in the smooth sum of the sigmoids, which represents the exact loss. Right: The average of sigmoidal functions, including the loss along basis vectors $b_1$ and $b_2$, elides individual breakthroughs. The more differently timed breakthroughs underlie the loss, the more hidden each breakthrough is.
  • Figure 2: Diagram of arithmetic addition task. An example of 3-digit addition, labeled with the skills required for each of the output tokens.
  • Figure 3: Exact loss trajectory clustering on the arithmetic task. We use HDBSCAN to cluster the exact loss trajectories. This approach, unlike our POLCA clustering method, fails to recover clusters associated with the carrying skill (the maximum fraction of carries is 0.51).
  • Figure 4: Arithmetic data clusters with POLCA. We perform POLCA clustering on the top 2 basis vectors, and report the cluster medoid and quartiles (left), median exact loss (center), and cluster skill composition (right) for each basis vector in order. Vertical lines mark the timestep when the relevant basis vector was sampled; note that a vector's phase transitions are not directly associated with this timestep. We find that the first basis vector recovers the digit skill whereas the second basis vector recovers the carrying skill (cluster #1 has homogeneity 0.90). The clusters computed from the POLCA trajectories show changes in the decomposed loss that are obscured in the exact loss curves.
  • Figure 5: Examples of English LM data clusters with POLCA. After clustering on POLCA trajectories for two illustrative basis vectors, we report the average decomposed POLCA trajectory of each cluster, together with the average of the exact loss trajectories for each of the POLCA trajectory clusters. For each cluster, we provide a label based on the top POS tags and tokens in the cluster and the top 10 contexts closest to its medoid. We report the 3 contexts closest to the cluster medoid and color the corresponding token. Clustering on the decomposed POLCA trajectories reveals low-rank breakthroughs at times when the full-rank exact loss curve remains smooth.
  • ...and 6 more figures
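
The intuition in the Figure 1 caption can be checked numerically with a small sketch. The code below is an illustration only, under assumed settings: 20 sigmoidal component losses with evenly spaced midpoints and a shared sharpness of 0.05, none of which come from the paper. It compares the largest single-step drop of one component with the largest single-step drop of the average.

```python
import numpy as np


def sigmoid_loss(steps, midpoint, sharpness):
    """Toy loss that falls from about 1 to about 0 around `midpoint` (assumed form)."""
    return 1.0 / (1.0 + np.exp(sharpness * (steps - midpoint)))


def largest_drop(curve):
    """Largest decrease between two consecutive training steps."""
    return float(np.max(-np.diff(curve)))


steps = np.arange(0, 1000)

# 20 component losses with differently timed breakthroughs (illustrative values).
midpoints = np.linspace(100, 900, 20)
component_losses = np.stack(
    [sigmoid_loss(steps, m, sharpness=0.05) for m in midpoints]
)

# The aggregate loss is the average over components.
mean_loss = component_losses.mean(axis=0)

component_drop = largest_drop(component_losses[0])
mean_drop = largest_drop(mean_loss)

print("largest per-step drop, single component:", component_drop)
print("largest per-step drop, averaged loss:", mean_drop)
```

Each component has a clear inflection, but in the average every breakthrough is scaled down by the number of components, so the aggregate curve declines gradually and no single transition stands out.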