Heavy-Tailed Class Imbalance and Why Adam Outperforms Gradient Descent on Language Models

Frederik Kunstner, Robin Yadav, Alan Milligan, Mark Schmidt, Alberto Bietti

TL;DR

Language-model optimization exhibits heavy-tailed class distributions that slow SGD's progress relative to Adam, particularly on rare tokens. The paper combines large-scale experiments with a softmax-linear toy model and a continuous-time analysis to isolate the effect of class imbalance on optimization dynamics. It shows that SGD makes slow progress on low-frequency classes (e.g., with $\pi_k \propto 1/k$), while Adam and sign-based methods make progress on all classes. As the model learns to assign samples to their correct classes, the gradient and Hessian split into imbalanced, correlated per-class blocks, a structure that favors Adam-like updates. In a simple imbalanced setting, the loss on class $k$ behaves as $\ell_k(t) = \Theta\left(\frac{1}{\pi_k t}\right)$ under gradient flow, versus $\ell_k(t) = \Theta\left(e^{-ct}\right)$ under continuous-time sign descent, independent of $\pi_k$. These results justify Adam-like optimizers for long-tailed language tasks and inform practical strategies for imbalance-aware training and optimizer design, including upweighting schemes and tokenization considerations.
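The contrast between the two rates can be reproduced in a deliberately simplified, decoupled toy that we assume here for illustration; it is not the paper's exact model. Each class $k$ has a single scalar weight $w_k$ and contributes the loss $\pi_k \log(1 + e^{-w_k})$. Gradient flow gives $\dot w_k = \pi_k / (1 + e^{w_k})$, which integrates to $w_k + e^{w_k} = \pi_k t + 1$, so $\ell_k(t) \approx e^{-w_k} \approx 1/(\pi_k t)$ once $\pi_k t$ is large. Sign descent gives $\dot w_k = 1$ for every class, so $\ell_k(t) = \log(1 + e^{-t})$ regardless of $\pi_k$. The sketch discretizes both flows with small Euler steps; the step size, horizon and function names are hypothetical choices.

```python
import math

# Assumed decoupled toy: class k has weight w_k and loss pi_k * log(1 + exp(-w_k)).
num_classes = 100
raw = [1.0 / k for k in range(1, num_classes + 1)]
total = sum(raw)
pis = [r / total for r in raw]  # pi_k proportional to 1/k


def per_class_losses(pis, steps, lr, direction):
    """Euler-discretize a flow on the toy loss and return log(1 + exp(-w_k)) per class.

    `direction` maps a gradient entry to the step taken against it.
    """
    w = [0.0] * len(pis)
    for _ in range(steps):
        # d/dw_k [pi_k * log(1 + exp(-w_k))] = -pi_k / (1 + exp(w_k))
        grads = [-pi / (1.0 + math.exp(wk)) for pi, wk in zip(pis, w)]
        w = [wk - lr * direction(g) for wk, g in zip(w, grads)]
    return [math.log1p(math.exp(-wk)) for wk in w]


def gradient_step(g):
    return g  # gradient descent: step proportional to the gradient


def sign_step(g):
    return math.copysign(1.0, g)  # sign descent: unit step in every coordinate


steps, lr = 5000, 0.1
t = steps * lr  # elapsed time of the discretized flow
gd_losses = per_class_losses(pis, steps, lr, gradient_step)
sign_losses = per_class_losses(pis, steps, lr, sign_step)

for k in (0, 9, 99):
    print(f"class {k + 1:3d}: GD loss {gd_losses[k]:.4f}, "
          f"pi_k*t*loss {pis[k] * t * gd_losses[k]:.2f}, "
          f"sign-descent loss {sign_losses[k]:.1e}")
```

Under gradient descent the per-class loss increases steadily with $k$, and for the most frequent classes $\pi_k t\,\ell_k(t)$ stays close to 1, in line with the $\Theta(1/(\pi_k t))$ rate. Sign descent drives every class to the same, much smaller loss. The rarest classes only reach $\pi_k t \approx 1$ in this run, so they are still far from the asymptotic regime.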

Abstract

Adam has been shown to outperform gradient descent on large language models by a larger margin than on other tasks, but it is unclear why. We show that a key factor in this performance gap is the heavy-tailed class imbalance found in language tasks. When trained with gradient descent, the loss of infrequent words decreases more slowly than the loss of frequent ones. This leads to a slow decrease in the average loss, as most samples come from infrequent words. On the other hand, Adam and sign-based methods are less sensitive to this problem. To establish that this behavior is caused by class imbalance, we show empirically that it can be reproduced across architectures and data types, on language transformers, vision CNNs, and linear models. On a linear model with cross-entropy loss, we show that class imbalance leads to imbalanced, correlated gradients and Hessians that have been hypothesized to benefit Adam. We also prove that, in continuous time, gradient descent converges slowly on low-frequency classes while sign descent does not.
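The claim that most samples come from infrequent words can be checked numerically under an assumed Zipf profile. The sketch below uses a hypothetical vocabulary of 50,000 classes with $\pi_k \propto 1/k$ (not the measured WikiText-103 distribution). It computes the share of samples outside the most frequent classes, then splits the classes, sorted by frequency, into consecutive groups that each hold roughly 10% of the samples, in the spirit of the grouping described for Figure 1. The helper names are illustrative.

```python
def zipf_frequencies(num_classes):
    """pi_k proportional to 1/k for k = 1..num_classes, normalized to sum to 1."""
    weights = [1.0 / k for k in range(1, num_classes + 1)]
    total = sum(weights)
    return [w / total for w in weights]


def tail_mass(pis, top):
    """Fraction of samples whose class is not among the `top` most frequent classes."""
    return 1.0 - sum(pis[:top])


def frequency_groups(pis, num_groups=10):
    """Split classes sorted by decreasing frequency into consecutive groups that each
    hold roughly 1/num_groups of the samples; return the number of classes per group."""
    sizes, count, cumulative = [], 0, 0.0
    for p in pis:
        count += 1
        cumulative += p
        # Close the current group once the cumulative mass reaches the next multiple
        # of 1/num_groups; the final group absorbs all remaining classes.
        if len(sizes) < num_groups - 1 and cumulative >= (len(sizes) + 1) / num_groups:
            sizes.append(count)
            count = 0
    sizes.append(count)
    return sizes


pis = zipf_frequencies(50_000)  # hypothetical vocabulary size
print(f"share of samples outside the 100 most frequent classes: {tail_mass(pis, 100):.2f}")
print("classes per group of ~10% of the samples:", frequency_groups(pis))
```

Under this profile, more than half of the samples fall outside the 100 most frequent classes. The group sizes grow roughly geometrically, so the last group contains tens of thousands of rare classes; if each of them is learned slowly, the average loss decreases slowly as well.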

Paper Structure (33 sections, 8 theorems, 45 equations, 26 figures, 1 table)

Key Result

Proposition 1

If initialized at ${\mathbf{W}}_0 = 0$, the gradient and Hessian of the loss $\mathcal{L}$ w.r.t. ${\mathbf{w}}_k$ are […]. During training, if the model correctly assigns samples to class $k$ with probability $p$ (Assumption ass:correct), then for classes whose frequency does not vanish too quickly, i.e. $\pi_k = \omega(1/c)$, […].
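For softmax regression with weights ${\mathbf{W}} = [{\mathbf{w}}_1, \dots, {\mathbf{w}}_c]$ and average cross-entropy loss, the standard identities are $\nabla_{{\mathbf{w}}_k}\mathcal{L} = \frac{1}{n}\sum_i \left(p_k({\mathbf{x}}_i) - \mathbb{1}[y_i = k]\right){\mathbf{x}}_i$ and $\nabla^2_{{\mathbf{w}}_k}\mathcal{L} = \frac{1}{n}\sum_i p_k({\mathbf{x}}_i)\left(1 - p_k({\mathbf{x}}_i)\right){\mathbf{x}}_i{\mathbf{x}}_i^\top$. At ${\mathbf{W}}_0 = 0$ every $p_k = 1/c$. The gradient block then reduces to $\frac{1}{c}\bar{{\mathbf{x}}} - \pi_k \bar{{\mathbf{x}}}_k$, where $\bar{{\mathbf{x}}}$ is the mean input and $\bar{{\mathbf{x}}}_k$ the mean input of class $k$, and the Hessian block $\frac{1}{c}(1 - \frac{1}{c})\frac{1}{n}\sum_i {\mathbf{x}}_i{\mathbf{x}}_i^\top$ is the same for every class. These expressions are our own derivation from the softmax identities, not a reconstruction of Proposition 1's missing equations. The pure-Python sketch checks them against finite differences on hypothetical data drawn like the synthetic setup of Figure 4: uniform inputs and labels with $\pi_k \propto 1/k$, independent of the inputs. The sizes and helper names are illustrative assumptions.

```python
import math
import random


def loss(W, X, y):
    """Average softmax cross-entropy; W[k] is the weight vector of class k."""
    total = 0.0
    for x, label in zip(X, y):
        logits = [sum(wj * xj for wj, xj in zip(w, x)) for w in W]
        m = max(logits)  # shift by the max logit for numerical stability
        log_z = m + math.log(sum(math.exp(z - m) for z in logits))
        total += log_z - logits[label]
    return total / len(X)


def grad_block_at_zero(X, y, k, c):
    """Closed form at W = 0: (1/c) * mean(x) - pi_k * mean(x | y = k)."""
    n, d = len(X), len(X[0])
    mean_x = [sum(x[j] for x in X) / n for j in range(d)]
    members = [x for x, label in zip(X, y) if label == k]
    pi_k = len(members) / n
    mean_x_k = [sum(x[j] for x in members) / len(members) if members else 0.0
                for j in range(d)]
    return [mean_x[j] / c - pi_k * mean_x_k[j] for j in range(d)]


def hessian_diag_at_zero(X, c):
    """Diagonal of the w_k Hessian block at W = 0; it does not depend on k."""
    n, d = len(X), len(X[0])
    scale = (1.0 / c) * (1.0 - 1.0 / c)
    return [scale * sum(x[j] ** 2 for x in X) / n for j in range(d)]


def perturbed_loss(X, y, c, k, j, delta):
    """Loss at W = 0, except that entry j of w_k is set to delta."""
    W = [[0.0] * len(X[0]) for _ in range(c)]
    W[k][j] = delta
    return loss(W, X, y)


random.seed(0)
n, d, c = 200, 5, 8
X = [[random.random() for _ in range(d)] for _ in range(n)]  # uniform inputs on [0, 1]^d
y = random.choices(range(c), weights=[1.0 / k for k in range(1, c + 1)], k=n)

for k in (0, c - 1):  # most and least frequent class
    grad = grad_block_at_zero(X, y, k, c)
    hess = hessian_diag_at_zero(X, c)
    for j in range(d):
        lp = perturbed_loss(X, y, c, k, j, 1e-3)
        lm = perturbed_loss(X, y, c, k, j, -1e-3)
        l0 = perturbed_loss(X, y, c, k, j, 0.0)
        assert abs((lp - lm) / 2e-3 - grad[j]) < 1e-5           # central difference
        assert abs((lp - 2 * l0 + lm) / 1e-6 - hess[j]) < 1e-4  # second difference
print("closed forms match finite differences")
```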

Figures (26)

  • Figure 1: Gradient descent does not make progress on low-frequency classes, while Adam does. Training GPT2-Small on WikiText-103. (a) Distribution of the classes sorted by class frequency, split into groups corresponding to ${\approx}10\%$ of the data. (b) Overall training loss. (c, d) Training loss for each group using SGD and Adam. SGD makes little to no progress on low-frequency classes while Adam makes progress on all groups. (b) is the average of (c, d) for the respective optimizer.
  • Figure 2: Adam outperforms SGD for training a CNN under heavy-tailed class labels. (a) Performance on the MNIST dataset. (b) Performance on a modified MNIST with two groups of classes. The first group consists of the 10 original classes with ${\approx}5$k samples each, while the second consists of ${\approx}10$k added classes with $5$ examples each. (c, d) Performance of GD and Adam on the two groups.
  • Figure 3: Adam outperforms SGD for training a ResNet under heavy-tailed class labels. (a) Performance on a subset of ImageNet and (b) an imbalanced subset of ImageNet with class frequencies $\pi_k \propto 1/k$. (c, d) Performance of GD and Adam on groups corresponding to ${\approx}10\%$ of the data.
  • Figure 4: The impact of heavy-tailed class imbalance is reproducible with linear models. Softmax regression on synthetic data. The inputs are drawn from a uniform distribution on $[0,1]^d$. The target classes are heavy-tailed (a) and independent of the inputs, but the model can still fit the data as it is overparameterized. (b, c, d) Overall training loss and performance of GD and Adam on each subset.
  • Figure 5: Sign descent, as a simplified form of Adam, performs well on low-frequency classes. Training the last layer of a simplified one-layer transformer with GD, Adam, normalized GD, and sign descent, with and without momentum ($\pm$m). Momentum and normalizing the magnitude help but have smaller effects than using sign descent, which recovers similar dynamics to Adam.
  • ...and 21 more figures
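Figure 5 compares GD, Adam, normalized GD and sign descent, each with and without momentum. The sketch below writes these update directions in pure Python under textbook definitions that we assume here: normalization by the Euclidean norm of the full gradient, a momentum buffer $b \leftarrow \beta b + g$, and Adam with its common defaults $\beta_1 = 0.9$, $\beta_2 = 0.999$, $\epsilon = 10^{-8}$. The paper's exact variants and hyperparameters may differ. It illustrates one mechanism consistent with the figure: shrinking one coordinate of the gradient, as a rare class does, shrinks that coordinate's GD and normalized-GD steps by the same factor, but leaves its sign-descent step, and essentially its first Adam step, unchanged.

```python
import math


def gd_direction(g):
    return list(g)  # step proportional to the raw gradient


def normalized_gd_direction(g):
    # Rescales the whole vector, so the relative size of coordinates is unchanged.
    norm = math.sqrt(sum(gi * gi for gi in g))
    return [gi / norm for gi in g] if norm > 0 else [0.0] * len(g)


def sign_direction(g):
    # Every nonzero coordinate moves by the same amount.
    return [math.copysign(1.0, gi) if gi != 0 else 0.0 for gi in g]


class Adam:
    """Textbook Adam with bias correction; returns the update direction per step."""

    def __init__(self, dim, beta1=0.9, beta2=0.999, eps=1e-8):
        self.m, self.v, self.t = [0.0] * dim, [0.0] * dim, 0
        self.beta1, self.beta2, self.eps = beta1, beta2, eps

    def direction(self, g):
        self.t += 1
        b1, b2 = self.beta1, self.beta2
        self.m = [b1 * mi + (1 - b1) * gi for mi, gi in zip(self.m, g)]
        self.v = [b2 * vi + (1 - b2) * gi * gi for vi, gi in zip(self.v, g)]
        m_hat = [mi / (1 - b1 ** self.t) for mi in self.m]
        v_hat = [vi / (1 - b2 ** self.t) for vi in self.v]
        return [mh / (math.sqrt(vh) + self.eps) for mh, vh in zip(m_hat, v_hat)]


def with_momentum(direction, beta=0.9):
    """Apply a direction rule to a momentum buffer b <- beta * b + g instead of g."""
    buf = []

    def step(g):
        nonlocal buf
        buf = list(g) if not buf else [beta * b + gi for b, gi in zip(buf, g)]
        return direction(buf)

    return step


g = [0.5, 0.5e-3]  # the second coordinate mimics a class with a 1000x smaller gradient
print("GD:        ", gd_direction(g))
print("normalized:", normalized_gd_direction(g))
print("sign:      ", sign_direction(g))
print("Adam (t=1):", Adam(len(g)).direction(g))
```

On its first step Adam's direction is $\hat m / (\sqrt{\hat v} + \epsilon) = g / (|g| + \epsilon)$, which is close to $\operatorname{sign}(g)$ whenever $|g| \gg \epsilon$. This is why sign descent is a natural simplified stand-in for Adam in this comparison.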

Theorems & Definitions (13)

  • Proposition 1
  • Proof: proof idea
  • Theorem 2
  • Proposition 2
  • Proof: proof of the assignment proposition (prop:assignment)
  • Theorem 2
  • Lemma 3: Separation of the dynamics
  • Proof
  • Lemma 4: Solution of the dynamics
  • Proof
  • ...and 3 more