Normalization Layer Per-Example Gradients are Sufficient to Predict Gradient Noise Scale in Transformers

Gavia Gray, Aman Tiwari, Shane Bergsma, Joel Hestness

Abstract

Per-example gradient norms are a vital ingredient for estimating gradient noise scale (GNS) with minimal variance. Observing the tensor contractions required to compute them, we propose a method with minimal FLOPs in 3D or greater tensor regimes by simultaneously computing the norms while computing the parameter gradients. Using this method we are able to observe the GNS of different layers at higher accuracy than previously possible. We find that the total GNS of contemporary transformer models is predicted well by the GNS of only the normalization layers. As a result, focusing only on the normalization layer, we develop a custom kernel to compute the per-example gradient norms while performing the LayerNorm backward pass with zero throughput overhead. Tracking GNS on only those layers, we are able to guide a practical batch size schedule that reduces training time by 18% on a Chinchilla-optimal language model.
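To make the role of per-example gradient norms concrete, the following is a minimal sketch of a GNS estimate in which the small batch is a single example. It assumes the standard unbiased estimators of the squared gradient norm and the gradient covariance trace built from a small-batch and a large-batch gradient norm; whether this matches the paper's own equations term by term is an assumption. The function names `gns_estimate` and `synthetic_demo`, and the synthetic Gaussian gradients, are illustrative and not taken from the paper.

```python
import numpy as np


def gns_estimate(per_example_sq_norms, big_batch_sq_norm):
    """Estimate the gradient noise scale with B_small = 1.

    per_example_sq_norms: squared gradient norm of each example, shape (B_big,).
    big_batch_sq_norm: squared norm of the gradient averaged over those examples.
    Returns (gns, s_est, g2_est), where s_est estimates tr(Sigma) and g2_est
    estimates |G|^2, the squared norm of the true gradient.
    """
    per_example_sq_norms = np.asarray(per_example_sq_norms, dtype=np.float64)
    b_big = per_example_sq_norms.shape[0]
    b_small = 1
    small_sq_norm = per_example_sq_norms.mean()

    # Unbiased estimate of |G|^2.
    g2_est = (b_big * big_batch_sq_norm - b_small * small_sq_norm) / (b_big - b_small)
    # Unbiased estimate of tr(Sigma).
    s_est = (small_sq_norm - big_batch_sq_norm) / (1.0 / b_small - 1.0 / b_big)
    return s_est / g2_est, s_est, g2_est


def synthetic_demo(num_examples=20000, dim=10, sigma=2.0, seed=0):
    """Hypothetical check: per-example gradients are a fixed G plus Gaussian noise."""
    rng = np.random.default_rng(seed)
    true_grad = np.ones(dim)  # |G|^2 = dim
    grads = true_grad + sigma * rng.standard_normal((num_examples, dim))
    per_example_sq_norms = np.sum(grads**2, axis=1)
    big_batch_sq_norm = np.sum(grads.mean(axis=0) ** 2)
    return gns_estimate(per_example_sq_norms, big_batch_sq_norm)
```

In this synthetic setting the true values are tr(Sigma) = dim * sigma^2 and |G|^2 = dim, so the estimate returned by `synthetic_demo()` should approach sigma^2 as the number of examples grows. In real training the two components are typically smoothed over steps before taking the ratio; that smoothing is omitted here.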

Paper Structure

This paper contains 28 sections, 10 equations, 16 figures, 2 tables, 3 algorithms.

Figures (16)

  • Figure 1: Gradient noise scale (GNS) is typically computed by comparing per-minibatch (aggregated-across-layers) gradients to gradients "Aggregated" across minibatches. We estimate GNS with lower variance by making each minibatch a single example, and maintain per-layer GNS estimates. We find the magnitude of gradients (visualized by the length of red arrows) to be consistent across layers, enabling overall GNS to be computed very cheaply using only gradient stats from LayerNorm layers.
  • Figure 2: The variance of the GNS estimator for different $B_{\textrm{big}}$ (left) and $B_{\textrm{small}}$ (right) sizes. In the legends, $B_{\textrm{big}} = l$ and $B_{\textrm{small}} = s$. The standard error is estimated using a jackknife resampling method for ratio estimators [choquet1999bootstrap]. For the same number of samples processed, a smaller $B_{\textrm{small}}$ always has a lower standard error, while the size of the large batch, $B_{\textrm{big}}$, does not affect the standard error.
  • Figure 3: FLOP cost of computing per-example gradient norms. (Left) Total FLOP cost. (Right) Proportional cost versus one model forward and backward pass. The FLOP cost of Simultaneous per-example gradient norms is strictly lower than that of the alternative methods (left), and the ratio of this additional cost to the FLOP cost of processing the entire model does not depend on context length (right).
  • Figure 4: Total I/O cost of computing per-example gradient norms, assuming gradients and parameters are stored with 4 bytes of precision. The relative I/O cost of Simultaneous per-example gradient norms is lower than that of [li2022large] for very long contexts at all model scales, approximately equivalent for models of 10B parameters and 4096 context length, and higher for shorter contexts with larger models. The I/O cost of LN (LayerNorm) per-example gradient norms alone is much lower than that of either method.
  • Figure 5: GNS phase plot: Linear/Embedding layers are separated from LayerNorm layers by row. The component estimators of Equations \ref{eq:g-est} and \ref{eq:s-est} are shown (left), with the GNS over the course of training (right).
  • ...and 11 more figures