
Impact of Layer Norm on Memorization and Generalization in Transformers

Rishi Singhal, Jung-Eun Kim

TL;DR

This work uncovers a dual role for LayerNorm in transformers that depends on whether Pre-LayerNorm or Post-LayerNorm is used. By removing the learnable LN parameters, the authors show that LN is essential for stable learning in Pre-LN models, whereas LN governs memorization in Post-LN models, where LN removal reduces memorization and aids recovery of true labels. A gradient-based analysis reveals that learning gradients consistently dominate memorization gradients, with early LN layers exhibiting the strongest influence in both configurations. These findings hold across 13 models spanning vision and language tasks, providing mechanistic insights that can guide robust training under noisy labels and inform architectural choices in transformer design.

Abstract

Layer Normalization (LayerNorm) is one of the fundamental components of transformers; it stabilizes training and improves optimization. Recently, Pre-LayerNorm transformers have become the preferred choice over Post-LayerNorm transformers due to their stable gradient flow. However, the impact of LayerNorm on learning and memorization across these architectures remains unclear. In this work, we investigate how LayerNorm influences memorization and learning in Pre- and Post-LayerNorm transformers. We identify that LayerNorm is a key factor for stable learning in Pre-LayerNorm transformers, while in Post-LayerNorm transformers it impacts memorization. Our analysis reveals that eliminating LayerNorm parameters in Pre-LayerNorm models exacerbates memorization and destabilizes learning, while in Post-LayerNorm models it effectively mitigates memorization by restoring genuine labels. We further identify that LayerNorm in early layers is more critical than in middle and later layers, and that its influence differs between Pre- and Post-LayerNorm models. We validate these findings on 13 models across 6 vision and language datasets. These insights shed new light on the role of LayerNorm in shaping memorization and learning in transformers.
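The two LN placements, and what "removing the learnable LN parameters" changes, can be sketched in NumPy. This is a minimal illustration under stated assumptions, not the authors' code: we read "removing LN parameters" as dropping the gain γ and bias β so that LN only standardizes its input, and the names `layer_norm`, `pre_ln_block`, `post_ln_block` and the random linear `sublayer` (standing in for attention or an MLP) are hypothetical.

```python
import numpy as np

def layer_norm(x, gamma=None, beta=None, eps=1e-5):
    """Standardize the last axis of x; apply the learnable affine map only if gamma is given."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    x_hat = (x - mu) / np.sqrt(var + eps)
    if gamma is None:  # assumed "LN parameters removed" setting: no gain, no bias
        return x_hat
    return gamma * x_hat + beta

def pre_ln_block(x, sublayer, ln_params=None):
    """Pre-LN: normalize the sublayer input; the residual stream itself is never normalized."""
    gamma, beta = ln_params if ln_params is not None else (None, None)
    return x + sublayer(layer_norm(x, gamma, beta))

def post_ln_block(x, sublayer, ln_params=None):
    """Post-LN: add the residual first, then normalize the sum."""
    gamma, beta = ln_params if ln_params is not None else (None, None)
    return layer_norm(x + sublayer(x), gamma, beta)

# Toy usage: a random linear map is a hypothetical stand-in for attention/MLP.
rng = np.random.default_rng(0)
d = 16
x = rng.normal(size=(4, d))
W = rng.normal(scale=d ** -0.5, size=(d, d))

def sublayer(h):
    return h @ W

learned = (rng.normal(1.0, 0.1, size=d), rng.normal(0.0, 0.1, size=d))

pre_full = pre_ln_block(x, sublayer, learned)
pre_no_params = pre_ln_block(x, sublayer)    # gain/bias dropped
post_full = post_ln_block(x, sublayer, learned)
post_no_params = post_ln_block(x, sublayer)  # gain/bias dropped
```

The structural difference is visible in the return statements: in Pre-LN the block output is an un-normalized residual sum, so LN (and its parameters) only shapes what each sublayer sees, whereas in Post-LN every block output passes through LN, so without γ and β each token vector leaving the block is forced to zero mean and unit variance.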

Paper Structure

This paper contains 64 sections, 3 theorems, 113 equations, 35 figures, 6 tables.

Key Result

Theorem 1

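It is formally represented as follows: the learning gradient norm is greater than or equal to the memorization gradient norm across all layers, $\|\boldsymbol{g}_{\boldsymbol{x}}^{\boldsymbol{\text{learn}}}\|_2 \geq \|\boldsymbol{g}_{\boldsymbol{x}}^{\boldsymbol{\text{mem}}}\|_2$.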

Figures (35)

  • Figure 1: Impact of the LN layer on memorization and learning in Pre- and Post-LN models. (a) shows a clear impact of LN on learning in Pre-LN models, whereas (c) shows that removing LN parameters has no impact on learning in Post-LN models. (b) shows that, without LN layers, Pre-LN models suffer from high memorization and random predictions (red-color-family bars), while (d) shows that in Post-LN models, removing LN parameters recovers a significant portion of correct predictions (green bars).
  • Figure 2: LN removal destabilizes learning in Pre-LN models, while it mitigates memorization in Post-LN models (News Dataset): LN removal in Pre-LN models critically affects learning (accuracy gap in (a)), while Post-LN models remain robust (negligible gap in (d)); LN removal effectively mitigates memorization and yields high recovery in Post-LN models (green bars in (e)), while memorization/random predictions persist in Pre-LN models (red-color-family bars in (e)); LN removal in Pre-LN models exacerbates overfitting, as indicated by the increasing train-test accuracy gap in (c), whereas for Post-LN models the gap decreases due to memorization mitigation (see (f)).
  • Figure 3: Pivotal impact of early LNs on learning and memorization across Pre- and Post-LN models. (a) clearly shows that removing early LNs destabilizes learning in Pre-LN models, accompanied by a higher train-test accuracy gap, $\Delta_{\text{overfit}}^{\text{Pre, early}}$, than for later layers, whereas (b) shows that removing early LNs helps suppress memorization and improve recovery in Post-LN models, along with a lower train-test accuracy gap, $\Delta_{\text{overfit}}^{\text{Post, early}}$, than for later layers.
  • Figure 4: Learning vs. Memorization Gradients in Pre- and Post-LN Models (Emotions Dataset): Results clearly exhibit higher gradient norms for early-layer LNs than for later-layer LNs, for both learning and memorization, in Pre-LN (GPTNeo) and Post-LN (DeBERTa) models. Importantly, the learning gradient norm ($\left\lVert g_x^{\text{learn}} \right\rVert_2$) is consistently stronger than the memorization gradient norm ($\left\lVert g_x^{\text{mem}} \right\rVert_2$) across all layers. Furthermore, the ratio $\left\lVert g_x^{\text{learn}} \right\rVert_2 / \left\lVert g_x^{\text{mem}} \right\rVert_2$ is significantly higher in Pre-LN models than in Post-LN models.
  • Figure 5: Pre-LN vs. Post-LN architectures depicting LN placement and categorization of early, middle and later layers
  • ...and 30 more figures
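Figure 4 reports, per LN layer, a learning gradient norm, a memorization gradient norm, and their ratio. A minimal sketch of that bookkeeping follows. It rests on an assumption this summary does not confirm: that the learning gradient is measured on correctly labeled samples and the memorization gradient on mislabeled (noisy) ones, each summarized as the mean per-sample L2 norm of the gradient of the loss w.r.t. that LN layer's input. The function `layerwise_grad_ratio`, the layer names, and the toy gradient arrays are hypothetical; in practice the per-sample gradients would come from backpropagation through a trained model.

```python
import numpy as np

def layerwise_grad_ratio(grads_by_layer, is_noisy):
    """Return {layer: (learn_norm, mem_norm, learn_norm / mem_norm)}.

    grads_by_layer: dict mapping a layer name to an (n_samples, d) array of
        per-sample gradients of the loss w.r.t. that LN layer's input.
    is_noisy: boolean array of shape (n_samples,); True marks a mislabeled sample.
    """
    is_noisy = np.asarray(is_noisy, dtype=bool)
    result = {}
    for name, g in grads_by_layer.items():
        norms = np.linalg.norm(g, axis=1)   # per-sample L2 norm
        learn = norms[~is_noisy].mean()     # assumed "learning" group: clean samples
        mem = norms[is_noisy].mean()        # assumed "memorization" group: noisy samples
        result[name] = (learn, mem, learn / mem)
    return result

# Hypothetical toy gradients: rows 0-1 are clean samples, row 2 is mislabeled.
example = {
    "ln_early": np.array([[3.0, 4.0], [3.0, 4.0], [0.0, 2.0]]),
    "ln_late": np.array([[0.6, 0.8], [0.6, 0.8], [0.0, 0.5]]),
}
mask = np.array([False, False, True])
result = layerwise_grad_ratio(example, mask)
for name, (learn, mem, ratio) in result.items():
    print(f"{name}: learn={learn:.3f} mem={mem:.3f} ratio={ratio:.3f}")
```

Averaging per-sample norms, rather than taking the norm of an averaged gradient, is a design choice of this sketch: each sample has its own LN input, so per-sample norms are directly comparable across the two groups.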

Theorems & Definitions (3)

  • Theorem 1: Learning Gradient Norm, $\|\boldsymbol{g}_{\boldsymbol{x}}^{\boldsymbol{\text{learn}}}\|_2$ is greater than or equal to Memorization Gradient Norm, $\|\boldsymbol{g}_{\boldsymbol{x}}^{\boldsymbol{\text{mem}}}\|_2$ across all layers
  • Theorem 2: The gradient norm of the loss $\mathcal{L}$ w.r.t. the input of LN is upper bounded
  • Theorem 3: The upper bound on the gradient norm of early-layer LNs is higher than that of later-layer LNs
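Theorem 2 states only that the gradient of the loss w.r.t. the LN input is upper bounded; the bound itself is not given here. A standard bound of this kind (not necessarily the paper's exact form) follows from LN backpropagation. For $\boldsymbol{y} = \boldsymbol{\gamma} \odot \hat{\boldsymbol{x}} + \boldsymbol{\beta}$ with $\hat{\boldsymbol{x}} = (\boldsymbol{x} - \mu)/\sigma$ and $\sigma = \sqrt{\mathrm{Var}(\boldsymbol{x}) + \epsilon}$ over $d$ features,
$\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}} = \frac{1}{\sigma}\left(I - \tfrac{1}{d}\mathbf{1}\mathbf{1}^\top - \tfrac{1}{d}\hat{\boldsymbol{x}}\hat{\boldsymbol{x}}^\top\right)\left(\boldsymbol{\gamma} \odot \frac{\partial \mathcal{L}}{\partial \boldsymbol{y}}\right)$,
and because the matrix in parentheses is symmetric with spectral norm at most 1, $\left\|\frac{\partial \mathcal{L}}{\partial \boldsymbol{x}}\right\|_2 \leq \frac{\|\boldsymbol{\gamma}\|_\infty}{\sigma}\left\|\frac{\partial \mathcal{L}}{\partial \boldsymbol{y}}\right\|_2$. The sketch below checks the formula against finite differences and evaluates the bound on random data; `ln_forward`, `ln_input_grad`, and the random upstream gradient `g_out` are hypothetical.

```python
import numpy as np

def ln_forward(x, gamma, beta, eps=1e-5):
    """LayerNorm over a single feature vector x; returns output, x_hat and sigma."""
    mu = x.mean()
    sigma = np.sqrt(x.var() + eps)
    x_hat = (x - mu) / sigma
    return gamma * x_hat + beta, x_hat, sigma

def ln_input_grad(g_out, gamma, x_hat, sigma):
    """Backprop dL/dy through y = gamma * x_hat + beta to obtain dL/dx."""
    g_hat = gamma * g_out  # dL/dx_hat
    return (g_hat - g_hat.mean() - x_hat * (g_hat * x_hat).mean()) / sigma

rng = np.random.default_rng(0)
d = 8
x = rng.normal(scale=0.5, size=d)
gamma = rng.normal(size=d)
beta = rng.normal(size=d)
g_out = rng.normal(size=d)  # stands in for dL/dy coming from the layers above

y, x_hat, sigma = ln_forward(x, gamma, beta)
g_in = ln_input_grad(g_out, gamma, x_hat, sigma)

# Central finite differences on the scalar surrogate L(x) = g_out . y(x).
h = 1e-6
fd = np.array([
    (g_out @ ln_forward(x + h * e, gamma, beta)[0]
     - g_out @ ln_forward(x - h * e, gamma, beta)[0]) / (2 * h)
    for e in np.eye(d)
])

bound = np.abs(gamma).max() * np.linalg.norm(g_out) / sigma
print("||dL/dx|| =", np.linalg.norm(g_in), " bound =", bound)
```

In this form the bound depends on the input only through $\sigma$ and on the layer's own parameters only through $\|\boldsymbol{\gamma}\|_\infty$; with the LN parameters removed ($\boldsymbol{\gamma} = \mathbf{1}$), it reduces to $\|\partial \mathcal{L}/\partial \boldsymbol{y}\|_2 / \sigma$.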