
Understanding and Minimising Outlier Features in Neural Network Training

Bobby He, Lorenzo Noci, Daniele Paliotta, Imanol Schlag, Thomas Hofmann

TL;DR

A novel unnormalised transformer block, the Outlier Protected block, is introduced, and a previously unknown benefit of non-diagonal preconditioning optimisers is presented; both approaches are found to significantly reduce outlier features (OFs) and improve quantisation without compromising convergence speed, at scales of up to 7B parameters.

Abstract

Outlier Features (OFs) are neurons whose activation magnitudes significantly exceed the average over a neural network's (NN) width. They are well known to emerge during standard transformer training and have the undesirable effect of hindering quantisation in afflicted models. Despite their practical importance, little is known about why OFs emerge during training, or how one can minimise them. Our work focuses on these questions, first identifying several quantitative metrics, such as the kurtosis over neuron activation norms, to measure OFs. With these metrics, we study how architectural and optimisation choices influence OFs, and provide practical insights to minimise OFs during training. As highlights, we introduce a novel unnormalised transformer block, the Outlier Protected (OP) block, and present a previously unknown benefit of non-diagonal preconditioning optimisers, finding both approaches to significantly reduce OFs and improve quantisation without compromising convergence speed, at scales of up to 7B parameters. Notably, our combination of OP block and non-diagonal preconditioner (SOAP) achieves 14.87 int8 weight-and-activation perplexity (from 14.71 in standard precision), compared to 63.4 int8 perplexity (from 16.00) with the default, OF-prone combination of a Pre-Norm model and Adam, when quantising OPT-125m models post-training. Overall, our findings shed new light on our understanding of this important aspect of NN training dynamics, on our ability to prevent it, and on its complexity.
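Why OFs hinder quantisation can be seen in a minimal sketch, assuming the simplest possible scheme: symmetric int8 quantisation with a single absmax scale shared by the whole tensor. This is an illustrative assumption, not necessarily the quantiser behind the OPT-125m numbers above; the outlier column (neuron 7, scaled 100x) and the helper names are hypothetical.

```python
import numpy as np


def absmax_int8_roundtrip(x: np.ndarray) -> np.ndarray:
    """Symmetric per-tensor int8 quantisation (one absmax scale), then dequantise."""
    scale = np.max(np.abs(x)) / 127.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q.astype(np.float32) * scale


def relative_error(x: np.ndarray, x_hat: np.ndarray) -> float:
    """Relative L2 error ||x - x_hat|| / ||x||."""
    return float(np.linalg.norm(x - x_hat) / np.linalg.norm(x))


rng = np.random.default_rng(0)
acts = rng.standard_normal((64, 512)).astype(np.float32)  # (tokens, features)

# Hypothetical outlier feature: neuron 7 is scaled up 100x.
acts_of = acts.copy()
acts_of[:, 7] *= 100.0

# Measure error on the 511 ordinary features only, so both cases compare like with like.
keep = np.arange(acts.shape[1]) != 7
err_clean = relative_error(acts[:, keep], absmax_int8_roundtrip(acts)[:, keep])
err_of = relative_error(acts_of[:, keep], absmax_int8_roundtrip(acts_of)[:, keep])
print(f"relative int8 error, no outlier:   {err_clean:.4f}")
print(f"relative int8 error, with outlier: {err_of:.4f}")
```

Because one scale is shared by the whole tensor, a single large neuron stretches the quantisation step, so most ordinary activations round to zero or to a few coarse levels.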

Paper Structure (51 sections, 1 theorem, 17 equations, 53 figures, 4 tables)

Key Result

Proposition G.1

Suppose we have $\mathbf{X}\in\mathbb{R}^{n\times d}$ zero-mean Gaussian distributed with all inputs uniformly correlated with some $\rho>0$, and independent features (across columns). That is: $\mathbb{E}[\mathbf{X}]=\mathbf{0}$ and $\mathbb{E}[\mathbf{X}_{\alpha,j}\mathbf{X}_{\beta,k}] = \dots$
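The setting of Proposition G.1 can be explored numerically. The sketch below assumes unit-variance entries (so the covariance between two different inputs is exactly $\rho$) and takes one reading of "kurtosis over neuron activation norms": the uncentred kurtosis of the per-neuron RMS norms $r_j$, i.e. $\mathrm{mean}(r^4)/\mathrm{mean}(r^2)^2$. The paper's exact metric (eq:kurt) is not reproduced here, and the function names are hypothetical.

```python
import numpy as np


def activation_kurtosis(x: np.ndarray) -> float:
    """Uncentred kurtosis of per-neuron RMS norms for x of shape (n_inputs, d_features).

    r_j^2 is the mean square of column j over inputs; the result is
    mean(r^4) / mean(r^2)^2, which equals 1 when every neuron has the same
    norm and grows when a few neurons dominate.
    """
    r_sq = np.mean(x ** 2, axis=0)  # r_j^2 for each neuron j
    return float(np.mean(r_sq ** 2) / np.mean(r_sq) ** 2)


def uniformly_correlated_gaussian(n: int, d: int, rho: float,
                                  rng: np.random.Generator) -> np.ndarray:
    """Sample X with unit variance, correlation rho between any two inputs (rows)
    and independent features (columns). Unit variance is an assumption."""
    shared = rng.standard_normal((1, d))  # component common to all inputs
    noise = rng.standard_normal((n, d))   # input-specific component
    return np.sqrt(rho) * shared + np.sqrt(1.0 - rho) * noise


rng = np.random.default_rng(0)
k_low = activation_kurtosis(uniformly_correlated_gaussian(256, 2048, 0.0, rng))
k_high = activation_kurtosis(uniformly_correlated_gaussian(256, 2048, 0.99, rng))
print(f"kurtosis at rho=0.00: {k_low:.3f}")
print(f"kurtosis at rho=0.99: {k_high:.3f}")
```

As $\rho \to 1$ every row collapses onto the shared vector (the rank-collapse regime of bad signal propagation), so $r_j^2 \to z_j^2$ for a standard Gaussian $z_j$ and the kurtosis approaches $\mathbb{E}[z^4]/\mathbb{E}[z^2]^2 = 3$; at $\rho = 0$ the norms concentrate and the kurtosis stays close to 1.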

Figures (53)

  • Figure 1: Outlier Features appear during training in open-source transformers (biderman2023pythia), as measured by our kurtosis metric (\ref{eq:kurt}). Our work investigates the design choices that influence their emergence.
  • Figure 2: Kurtosis becomes large (i.e. Outlier Feature Emergence, OFE) with different Norms at 130M scale. We plot the residual stream entering the 2nd of 6 blocks. Other layers are shown in \ref{fig:kurt_all_layers_codeparrot}.
  • Figure 3: The Outlier Protected Transformer Block. We remove Pre-Norms and replace them with an Entropy Regulation mechanism to prevent entropy collapse, as well as downscaling residuals with $\beta<1$.
  • Figure 4: Our OP block mitigates OFE. We plot activation kurtosis of the residual stream across layers. Experiments are at 1.2B scale on Languini Books using a max AdamW learning rate of $0.001$ with linear warmup for the first 1.5% of steps and linear decay thereafter. Notice the shared log-scaled y-axis: activation kurtosis is consistently (by up to 4 orders of magnitude) lower in the OP block, particularly in earlier layers. Also, peak kurtosis during training is always higher in Pre-LN. The OP model also removes the final LN before unembedding; the effect of the final LN on OFE is shown in \ref{fig:op_sigprop}.
  • Figure 5: Adam-trained Pre-LN layers at 1.2B scale with extreme OFE (left) are those with bad Signal Prop close to rank collapse during training (centre left). (Right vs. left two plots) Downweighting residual branches improves signal propagation during training and results in smaller OFs, particularly in early layers. The corresponding plots for OP (with and without the final LN before unembedding) are in \ref{fig:op_sigprop}.
  • ...and 48 more figures
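The OP block of Figure 3 can be sketched in NumPy. This is a minimal single-head illustration under stated assumptions, not the authors' implementation: the Entropy Regulation mechanism is assumed to be QK-Norm (parameter-free RMS normalisation of queries and keys), both residual branches are scaled by the same $\beta$, the MLP uses ReLU, and `op_block`, `init_params` and the weight names are hypothetical.

```python
import numpy as np


def rms_norm(x: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """Parameter-free RMS normalisation over the last axis."""
    return x / np.sqrt(np.mean(x ** 2, axis=-1, keepdims=True) + eps)


def softmax(z: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def init_params(d: int, d_ff: int, rng: np.random.Generator) -> dict:
    """Hypothetical Gaussian initialisation with 1/sqrt(fan_in) scaling."""
    def dense(fan_in: int, fan_out: int) -> np.ndarray:
        return rng.standard_normal((fan_in, fan_out)) / np.sqrt(fan_in)
    return {
        "w_q": dense(d, d), "w_k": dense(d, d), "w_v": dense(d, d),
        "w_o": dense(d, d), "w_in": dense(d, d_ff), "w_out": dense(d_ff, d),
    }


def op_block(x: np.ndarray, params: dict, beta: float = 0.3) -> np.ndarray:
    """One OP-style block on x of shape (seq_len, d_model).

    - No Pre-Norm: the residual stream x feeds both branches unnormalised.
    - Entropy regulation (assumed QK-Norm): queries and keys are RMS-normalised.
    - Each residual branch is downscaled by beta < 1.
    """
    seq_len, d = x.shape
    q = rms_norm(x @ params["w_q"])
    k = rms_norm(x @ params["w_k"])
    v = x @ params["w_v"]
    logits = q @ k.T / np.sqrt(d)
    causal = np.tril(np.ones((seq_len, seq_len), dtype=bool))
    logits = np.where(causal, logits, -np.inf)  # causal mask
    attn_out = softmax(logits) @ v @ params["w_o"]
    x = x + beta * attn_out

    hidden = np.maximum(x @ params["w_in"], 0.0)  # ReLU MLP
    x = x + beta * (hidden @ params["w_out"])
    return x


rng = np.random.default_rng(0)
params = init_params(64, 256, rng)
x = rng.standard_normal((16, 64))
y = op_block(x, params, beta=0.3)
print(y.shape)
```

In this sketch, setting `beta=0.0` returns the input unchanged, which makes the role of $\beta$ explicit: the residual stream itself is never normalised, only the branch contributions are shrunk.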

Theorems & Definitions (2)

  • Proposition G.1: Bad Signal Propagation implies higher kurtosis for Gaussian features
  • proof