What Happens During the Loss Plateau? Understanding Abrupt Learning in Transformers

Pulkit Gopalani, Wei Hu

TL;DR

The paper investigates abrupt learning in Transformer training on algorithmic tasks using shallow models to uncover universal plateau dynamics. It identifies three co-occurring phenomena during the plateau—partial solutions, repetition bias, and representation collapse—rooted in a bottleneck of learning the attention map. Through targeted interventions that bias the attention pattern, the authors show accelerated convergence and reduced degeneracy, and demonstrate that these dynamics generalize to early pretraining of LLMs like Pythia and OLMo. The work provides a unified view of loss-plateau behavior, linking internal representations, output statistics, and attention learning, with implications for controlling training dynamics in practical deployments.

Abstract

Training Transformers on algorithmic tasks frequently demonstrates an intriguing abrupt learning phenomenon: an extended performance plateau followed by a sudden, sharp improvement. This work investigates the underlying mechanisms for such dynamics, primarily in shallow Transformers. We reveal that during the plateau, the model often develops an interpretable partial solution while simultaneously exhibiting a strong repetition bias in its outputs. This output degeneracy is accompanied by internal representation collapse, where hidden states across different tokens become nearly parallel. We further identify the slow learning of optimal attention maps as a key bottleneck. Hidden progress in attention configuration during the plateau precedes the eventual rapid convergence, and directly intervening on attention significantly alters plateau duration and the severity of repetition bias and representation collapse. We validate that these identified phenomena (repetition bias and representation collapse) are not artifacts of toy setups but also manifest in the early pre-training stage of large language models like Pythia and OLMo.
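
The two degeneracy signals named in the abstract can be made concrete with a small sketch. The excerpt does not give the paper's exact definitions, so the choices below are assumptions: repetition frequency is taken as the fraction of output positions that repeat the preceding token, and representation collapse is measured as the mean pairwise cosine similarity between per-token hidden states. The function names are hypothetical.

```python
import math
from itertools import combinations


def repetition_frequency(tokens):
    """Fraction of positions whose token equals the previous token.

    Assumed definition for illustration; the paper's may differ.
    """
    if len(tokens) < 2:
        return 0.0
    repeats = sum(1 for prev, cur in zip(tokens, tokens[1:]) if prev == cur)
    return repeats / (len(tokens) - 1)


def cosine(u, v):
    """Cosine similarity of two equal-length, non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def mean_pairwise_cosine(hidden_states):
    """Average cosine similarity over all distinct pairs of hidden states.

    Values near 1 mean the states are nearly parallel (collapsed).
    Requires at least two hidden states.
    """
    pairs = list(combinations(hidden_states, 2))
    return sum(cosine(u, v) for u, v in pairs) / len(pairs)


# A fully repetitive output and nearly parallel hidden states,
# the pattern the abstract describes for the plateau.
plateau_output = [7, 7, 7, 7, 7]
plateau_states = [[1.0, 0.1], [2.0, 0.2], [0.5, 0.05]]
print(repetition_frequency(plateau_output))
print(mean_pairwise_cosine(plateau_states))
```

Under these assumed definitions, both quantities would be high during the plateau and fall once the loss drops, matching the qualitative description above.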

Paper Structure

This paper contains 46 sections, 19 equations, 43 figures, 2 tables.

Figures (43)

  • Figure 1: Abrupt learning and related characteristics. Training a shallow Transformer on algorithmic tasks like moving-window-sum exhibits an abrupt learning curve: performance plateaus for an extended number of steps, before suddenly and sharply improving to optimum. Before the sudden drop in loss, the attention map cannot be interpreted easily, whereas the post-sudden-drop attention map is clearly interpretable w.r.t. the task. Furthermore, the model exhibits degenerate patterns before the sudden drop, including output repetitions and collapse of its hidden representations.
  • Figure 2: Abrupt learning dynamics for the MWS task. (a): Train/Test loss and Train/Test Accuracy (note that both train and test data metrics are near-identical in the online training setup, and thus we only report train metrics); (b): Attention Progress, Repetition Frequency, and Representation Cosine Similarity between hidden states. Increase in attention progress is gradual and happens before the sudden loss drop. Repetition frequency and representation cosine similarity rapidly increase at the beginning and decrease to low values later on.
  • Figure 3: Biasing the attention map by $c>1$. Multiplicatively biasing the attention map to put more weight on the optimal positions leads to faster convergence, accompanied by fewer repetitions and lower average cosine similarity.
  • Figure 4: Biasing the attention map by $c<1$. Biasing the attention map to put less weight on the optimal positions leads to slower convergence, and to more representation collapse and repetitions.
  • Figure 5: Different optimal initializations and their effect on training. Fixing the attention and embedding weights (i.e., the attention map) to their optimal values and training the other components leads to faster convergence and to less representation collapse and fewer repetitions. The same effect does not hold when the MLP or the embeddings are fixed to their optimal values instead. ($K, Q, O, V$ denote the parameters $W_K, W_Q, W_O, W_V$, respectively.)
  • ...and 38 more figures
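
The multiplicative attention biasing mentioned in the captions of Figures 3 and 4 can be illustrated with a minimal sketch. The captions do not specify the exact procedure, so this version makes assumptions: it scales the post-softmax weights at a given set of optimal positions by a factor `c` and then renormalizes the row so it sums to 1. The function name `bias_attention_row` and the renormalization step are hypothetical and may not match the authors' implementation.

```python
def bias_attention_row(weights, optimal_positions, c):
    """Scale attention weights at the optimal positions by c, then renormalize.

    weights: one row of an attention map (non-negative, sums to 1).
    optimal_positions: set of indices the task-optimal map attends to.
    c: c > 1 shifts mass toward the optimal positions, c < 1 away from them.
    Renormalizing afterwards is an assumption made for this sketch.
    """
    scaled = [w * c if j in optimal_positions else w
              for j, w in enumerate(weights)]
    total = sum(scaled)
    return [w / total for w in scaled]


# Start from a uniform row over four positions; positions 0 and 1 are
# treated as optimal purely for illustration.
uniform = [0.25, 0.25, 0.25, 0.25]
print(bias_attention_row(uniform, {0, 1}, c=3.0))  # more weight on 0 and 1
print(bias_attention_row(uniform, {0, 1}, c=0.5))  # less weight on 0 and 1
```

With `c = 1` the row is unchanged, so the factor interpolates between suppressing and emphasizing the optimal positions, which is the knob the two captions describe.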