
Lost in Backpropagation: The LM Head is a Gradient Bottleneck

Nathan Godey, Yoav Artzi

TL;DR

The softmax bottleneck is shown to be not only an expressivity bottleneck but also an optimization bottleneck. It contributes to training inefficiencies at scale independently of the model architecture, and calls for new LM head designs.

Abstract

The last layer of neural language models (LMs) projects output features of dimension $D$ to logits in dimension $V$, the size of the vocabulary, where usually $D \ll V$. This mismatch is known to raise risks of limited expressivity in neural LMs, creating a so-called softmax bottleneck. We show the softmax bottleneck is not only an expressivity bottleneck but also an optimization bottleneck. Backpropagating $V$-dimensional gradients through a rank-$D$ linear layer induces unavoidable compression, which alters the training feedback provided to the vast majority of the parameters. We present a theoretical analysis of this phenomenon and measure empirically that 95-99% of the gradient norm is suppressed by the output layer, resulting in vastly suboptimal update directions. We conduct controlled pretraining experiments showing that the gradient bottleneck makes trivial patterns unlearnable, and drastically affects the training dynamics of LLMs. We argue that this inherent flaw contributes to training inefficiencies at scale independently of the model architecture, and raises the need for new LM head designs.
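The compression argument can be made concrete with a toy sketch. This is not the authors' measurement protocol. Any change to a hidden state $h$ moves the logits $Wh$ only within the $D$-dimensional column space of the head $W \in \mathbb{R}^{V \times D}$. Backpropagation computes $G W$, which depends only on the part of each logit-gradient row lying in that subspace, so the orthogonal remainder is discarded. We assume isotropic Gaussian logit gradients and a random head, both hypothetical stand-ins for real training quantities. Under those assumptions, the surviving share of the squared gradient norm is roughly $D/V$.

```python
import numpy as np


def surviving_gradient_fraction(G: np.ndarray, W: np.ndarray) -> float:
    """Share of ||G||_F^2 lying in the column space of the LM head W.

    G: logit gradients, shape (N, V), one row per token.
    W: LM head weights, shape (V, D), so that logits = H @ W.T.
    Backprop computes G @ W, which depends only on the part of each row of G
    inside col(W); the orthogonal remainder is discarded.
    """
    Q, _ = np.linalg.qr(W)      # orthonormal basis of col(W), shape (V, D) when V >= D
    G_kept = (G @ Q) @ Q.T      # orthogonal projection of every row onto col(W)
    return float(np.linalg.norm(G_kept) ** 2 / np.linalg.norm(G) ** 2)


# Toy setting: random head and isotropic gradients (hypothetical stand-ins).
rng = np.random.default_rng(0)
V, D, N = 4096, 128, 512
W = rng.standard_normal((V, D))
G = rng.standard_normal((N, V))
print(f"kept: {surviving_gradient_fraction(G, W):.3f}  (D / V = {D / V:.3f})")
```

In this toy setting the kept share lands close to $128/4096 \approx 0.031$, so roughly 97% of the squared norm falls outside the reachable subspace. Real LM gradients are not isotropic, so this number should not be read as a reproduction of the paper's 95-99% measurement.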

Paper Structure (37 sections, 8 theorems, 34 equations, 14 figures, 4 tables)

Key Result

Proposition 2.1

Let $\tilde{N}$ denote the row-normalized version of $N$. Then $-\sum_{i,j} N_{ij} \log \sigma(HW^\top)_{ij} \geq -\sum_{i,j} N_{ij} \log \tilde{N}_{ij}$, with equality if and only if $\sigma(HW^\top)=\tilde{N}$.
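A quick numerical check of Proposition 2.1 (Gibbs inequality), under stated assumptions. We read $N$ as a hypothetical matrix of next-token counts, with one row per context and one column per vocabulary item. We take the loss to be the count-weighted cross-entropy, which gives the bound $-\sum_{i,j} N_{ij}\log\sigma(HW^\top)_{ij} \geq -\sum_{i,j} N_{ij}\log\tilde{N}_{ij}$. A random rank-$D$ model stays above this floor. The floor is reached when the softmax output equals $\tilde{N}$. Here we plug in the logits $\log\tilde{N}$ directly, which a head with $D < V$ generally cannot produce.

```python
import numpy as np


def log_softmax(Z: np.ndarray) -> np.ndarray:
    """Row-wise log-softmax, shifted by the row max for numerical stability."""
    Z = Z - Z.max(axis=1, keepdims=True)
    return Z - np.log(np.exp(Z).sum(axis=1, keepdims=True))


def count_weighted_ce(N: np.ndarray, log_probs: np.ndarray) -> float:
    """-sum_ij N_ij * log p_ij for a count matrix N and log-probabilities."""
    return float(-(N * log_probs).sum())


rng = np.random.default_rng(0)
n_ctx, V, D = 5, 7, 3
N = rng.integers(1, 20, size=(n_ctx, V)).astype(float)  # hypothetical counts, all > 0
N_tilde = N / N.sum(axis=1, keepdims=True)              # row-normalized counts
floor = count_weighted_ce(N, np.log(N_tilde))            # right-hand side of the bound

H = rng.standard_normal((n_ctx, D))   # hidden states
W = rng.standard_normal((V, D))       # LM head
model_loss = count_weighted_ce(N, log_softmax(H @ W.T))
# Logits log(N_tilde) make the softmax output exactly N_tilde (equality case).
optimal_loss = count_weighted_ce(N, log_softmax(np.log(N_tilde)))

print(f"model: {model_loss:.4f} >= floor: {floor:.4f}; optimum: {optimal_loss:.4f}")
```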

Figures (14)

  • Figure 1: Training dynamics for 2B models with different dimensionality constraints on the LM head. We control the rank of the output linear layer to mimic reduced hidden dimensions without changing the shape of the Transformer backbone. The results illustrate how the softmax bottleneck alone significantly reduces the convergence speed. More details are given in the section on LLM training experiments.
  • Figure 2: Empirical rank of the logits gradients as the number of batch tokens increases. The dotted blue lines delimit the highest possible rank.
  • Figure 3: Average zero-shot scores on a subset of downstream tasks for 2B architectures equipped with rank-$D$ LM heads. The average is weighted by benchmark sample size to reduce variance.
  • Figure 4: Training curves for our SpamLang experiments (learning rate 5e-4). As the vocabulary size increases, the same Transformer neural LM empirically struggles to learn this trivial language, even though it is theoretically expressive enough to learn it.
  • Figure 5: Final validation loss for a Transformer with 106M non-embedding parameters trained on the SpamLang synthetic language for various vocabulary sizes and learning rates. The hidden dimension is set to 576 for all models.
  • ...and 9 more figures
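Figure 2 concerns the empirical rank of the logit gradients. The sketch below shows where rank ceilings come from, assuming the standard softmax cross-entropy. Each token contributes a logit-gradient row $\mathrm{softmax}(z) - \mathrm{onehot}(y)$ whose entries sum to zero, so $N$ tokens give a gradient of rank at most $\min(N, V-1)$. After backpropagation through a $V \times D$ head, the rank is at most $D$. The random logits, targets, and head weights are hypothetical placeholders. We do not claim that $\min(N, V-1)$ is exactly the bound drawn in the figure.

```python
import numpy as np


def softmax(Z: np.ndarray) -> np.ndarray:
    Z = Z - Z.max(axis=1, keepdims=True)   # shift for numerical stability
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)


def logit_gradient(Z: np.ndarray, y: np.ndarray) -> np.ndarray:
    """Gradient of the summed cross-entropy w.r.t. logits: softmax(Z) - onehot(y)."""
    G = softmax(Z)
    G[np.arange(len(y)), y] -= 1.0
    return G


def gradient_ranks(n_tokens: int, V: int, D: int, rng: np.random.Generator):
    """Ranks of the logit gradient (n_tokens x V) and of its backprop through W."""
    Z = rng.standard_normal((n_tokens, V))   # placeholder logits
    y = rng.integers(0, V, size=n_tokens)    # placeholder next-token targets
    W = rng.standard_normal((V, D))          # placeholder rank-D LM head
    G = logit_gradient(Z, y)
    return np.linalg.matrix_rank(G), np.linalg.matrix_rank(G @ W)


rng = np.random.default_rng(0)
V, D = 256, 32
for n_tokens in (64, 128, 256, 512, 1024):
    r_logits, r_hidden = gradient_ranks(n_tokens, V, D, rng)
    print(n_tokens, r_logits, min(n_tokens, V - 1), r_hidden)
```

With random inputs, the logit-gradient rank generically tracks $\min(N, V-1)$. The backpropagated gradient stays capped at $D = 32$ for every batch size shown.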

Theorems & Definitions (13)

  • Proposition 2.1: Gibbs inequality
  • Proposition 2.2: Softmax bottleneck
  • Proof
  • Corollary 2.3: Softmax bottleneck
  • Proposition 2.4 (proof given separately)
  • Proposition 2.5 (proof given separately)
  • Proposition 2.6: Update direction residual
  • Proof
  • Proposition 2.7 (proof given separately)
  • Corollary 2.8
  • ...and 3 more
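The expressivity side of the softmax bottleneck follows from a standard rank argument. The exact statement of Proposition 2.2 is not reproduced here. Write $\log\sigma(HW^\top) = HW^\top - c\,\mathbf{1}^\top$, where $c$ holds the per-row log-normalizers. The log-probability matrix of a rank-$D$ head then has rank at most $D + 1$. A matrix of arbitrary target log-probabilities generally has much higher rank, so such a head cannot match it exactly. The shapes and random weights below are illustrative placeholders.

```python
import numpy as np


def log_softmax(Z: np.ndarray) -> np.ndarray:
    """Row-wise log-softmax; subtracting a per-row constant does not change the result."""
    Z = Z - Z.max(axis=1, keepdims=True)
    return Z - np.log(np.exp(Z).sum(axis=1, keepdims=True))


rng = np.random.default_rng(0)
n_contexts, V, D = 200, 100, 16
H = rng.standard_normal((n_contexts, D))   # hidden states
W = rng.standard_normal((V, D))            # rank-D LM head
model_rank = np.linalg.matrix_rank(log_softmax(H @ W.T))
target_rank = np.linalg.matrix_rank(log_softmax(3.0 * rng.standard_normal((n_contexts, V))))
print(f"model log-probs rank: {model_rank} (bound D + 1 = {D + 1})")
print(f"target log-probs rank: {target_rank}")
```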