Gated Removal of Normalization in Transformers Enables Stable Training and Efficient Inference

Andrei Kanavalau, Carmen Amo Alonso, Sanjay Lall

TL;DR

This work challenges the necessity of per-token normalization in pre-norm Transformers by introducing TaperNorm, a gated normalization layer that behaves like standard RMSNorm/LayerNorm early in training and gradually transitions to a fixed, sample-independent linear/affine map. A single global gate $g$ controls the transition. It stays at 1 during warmup while EMA statistics are accumulated, then cosine-decays to 0. At that point per-token statistics vanish, and the remaining fixed scalings can be folded into adjacent projections, so the tapered layers need no normalization at inference. The authors show that output-scale anchoring is the critical factor: a final normalization layer provides an implicit radial anchor that prevents logit chasing, whereas removing it requires an explicit anchor, which a fixed-target scale loss provides. Empirically, TaperNorm closely matches normalized baselines in pre-training and GPT-2 fine-tuning. Folding internal scalings into adjacent projections yields up to $1.22\times$ higher throughput on an inference microbenchmark (last-token logits mode), a step toward practical norm-free Transformers.
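The gate schedule described above can be sketched as a plain Python function. This is a minimal sketch under stated assumptions. The names `gate_value`, `warmup_steps`, and `decay_steps` are hypothetical. We assume the gate is exactly 1 for the first `warmup_steps` steps, follows a half-cosine from 1 to 0 over the next `decay_steps` steps, and stays at 0 afterwards. The paper's exact schedule boundaries may differ.

```python
import math


def gate_value(step: int, warmup_steps: int, decay_steps: int) -> float:
    """Hypothetical TaperNorm gate schedule (assumed form, not the authors' code).

    g = 1 while step < warmup_steps (the layer behaves like RMSNorm/LayerNorm
    and EMA statistics are collected). Then g follows a half-cosine decay from
    1 to 0 over decay_steps, and g = 0 for the rest of training.
    """
    if step < warmup_steps:
        return 1.0
    # Fraction of the decay phase completed, clipped to [0, 1].
    t = min((step - warmup_steps) / decay_steps, 1.0)
    return 0.5 * (1.0 + math.cos(math.pi * t))


if __name__ == "__main__":
    for s in (0, 999, 1000, 1500, 2000, 5000):
        print(s, round(gate_value(s, warmup_steps=1000, decay_steps=1000), 3))
```

A half-cosine is used here because it starts and ends with zero slope, so the transition away from full normalization begins gently. Whether the paper adds any hold period at $g{=}0$ is not stated in this summary.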

Abstract

Normalization is widely viewed as essential for stabilizing Transformer training. We revisit this assumption for pre-norm Transformers and ask to what extent sample-dependent normalization is needed inside Transformer blocks. We introduce TaperNorm, a drop-in replacement for RMSNorm/LayerNorm that behaves exactly like the standard normalizer early in training and then smoothly tapers to a learned sample-independent linear/affine map. A single global gate is held at $g{=}1$ during a gate-warmup phase, which is used to calibrate the scaling branch via EMAs. The gate is then cosine-decayed to $g{=}0$, at which point per-token statistics vanish and the resulting fixed scalings can be folded into adjacent linear projections. Our theoretical and empirical results isolate scale anchoring as the key role played by output normalization: as a (near) $0$-homogeneous map it removes radial gradients at the output, whereas without such an anchor cross-entropy encourages unbounded logit growth ("logit chasing"). We further show that a simple fixed-target auxiliary loss on the pre-logit residual-stream scale provides an explicit alternative anchor and can aid removal of the final normalization layer. Empirically, TaperNorm matches normalized baselines under identical setups while eliminating per-token statistics and enabling these layers to be folded into adjacent linear projections at inference. On an efficiency microbenchmark, folding internal scalings yields up to $1.22\times$ higher throughput in last-token logits mode. These results take a step toward norm-free Transformers while identifying the special role that output normalization plays.
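To make the taper-and-fold idea concrete, here is a minimal NumPy sketch of an RMSNorm-style TaperNorm. It rests on assumptions that are not stated in this summary:

* The layer output is taken to be $x \odot \gamma \cdot \bigl(g/\mathrm{rms}(x) + (1-g)\,c\bigr)$.
* $c$ is a single scalar EMA of $1/\mathrm{rms}(x)$, recorded while $g{=}1$.
* The class `TaperRMSNorm` and the helper `fold_into_projection` are hypothetical names.

The paper's scaling branch is learned and may be per-channel. Treat this only as an illustration of why reaching $g{=}0$ makes folding possible.

```python
import numpy as np


class TaperRMSNorm:
    """Hypothetical RMSNorm-style TaperNorm (illustrative, not the authors' code).

    Output: x * gamma * (g / rms(x) + (1 - g) * c), where c is a scalar EMA of
    1/rms(x) collected while g == 1. At g == 0 the map is x * (c * gamma),
    a fixed diagonal linear map with no per-token statistics.
    """

    def __init__(self, dim: int, eps: float = 1e-6, ema_decay: float = 0.99):
        self.gamma = np.ones(dim)  # learnable gain (kept fixed in this sketch)
        self.eps = eps
        self.ema_decay = ema_decay
        self.c = None  # EMA of 1/rms(x); None until the first calibration step

    def __call__(self, x: np.ndarray, g: float, update_ema: bool = False) -> np.ndarray:
        # Per-token inverse RMS over the feature (last) axis.
        inv_rms = 1.0 / np.sqrt(np.mean(x**2, axis=-1, keepdims=True) + self.eps)
        if update_ema:
            batch_stat = float(np.mean(inv_rms))
            if self.c is None:
                self.c = batch_stat
            else:
                self.c = self.ema_decay * self.c + (1.0 - self.ema_decay) * batch_stat
        c = 1.0 if self.c is None else self.c
        # Blend the sample-dependent and sample-independent scalings with gate g.
        scale = g * inv_rms + (1.0 - g) * c
        return x * scale * self.gamma


def fold_into_projection(norm: TaperRMSNorm, W: np.ndarray) -> np.ndarray:
    """Fold the g == 0 map diag(c * gamma) into the next projection W (dim x out)."""
    c = 1.0 if norm.c is None else norm.c
    return (c * norm.gamma)[:, None] * W


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    dim, out = 8, 4
    norm = TaperRMSNorm(dim)
    # Gate warmup: g = 1 behaves like RMSNorm and calibrates the EMA.
    for _ in range(50):
        _ = norm(rng.normal(size=(16, dim)), g=1.0, update_ema=True)
    x = rng.normal(size=(5, dim))
    W = rng.normal(size=(dim, out))
    y_unfolded = norm(x, g=0.0) @ W
    y_folded = x @ fold_into_projection(norm, W)
    print(np.allclose(y_unfolded, y_folded))
```

At $g{=}0$ the layer disappears at inference: its output is a fixed per-channel rescaling, so it can be absorbed into the rows of the following weight matrix with no extra runtime cost.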

Paper Structure (39 sections, 3 theorems, 34 equations, 4 figures, 6 tables)

Key Result

Proposition 4.1

Let the final map $\mathrm{Norm}_{\mathrm{final}}$ before the output projection be $0$-homogeneous and differentiable almost everywhere. This includes the idealized RMSNorm and LayerNorm maps with $\varepsilon=0$. For logits $z=\mathrm{Norm}_{\mathrm{final}}(h)\,W_{\mathrm{out}}$ and any differentia…
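The mechanism behind Proposition 4.1 can be seen from Euler's identity for homogeneous functions. The short derivation below is a standard argument written in this summary's notation. Here $J(h)$ denotes the Jacobian of $\mathrm{Norm}_{\mathrm{final}}$ at $h$, $u$ its output, and $\mathcal{L}$ an arbitrary differentiable loss of the logits. It assumes differentiability at $h$ and is not a reproduction of the authors' proof.

```latex
\begin{aligned}
&\mathrm{Norm}_{\mathrm{final}}(\alpha h) = \mathrm{Norm}_{\mathrm{final}}(h)\ \ \forall \alpha > 0
\;\Longrightarrow\;
\frac{d}{d\alpha}\,\mathrm{Norm}_{\mathrm{final}}(\alpha h)\Big|_{\alpha=1} = J(h)\,h = 0, \\
&u = \mathrm{Norm}_{\mathrm{final}}(h), \qquad \nabla_h \mathcal{L} = J(h)^{\top}\,\nabla_u \mathcal{L}, \\
&\langle \nabla_h \mathcal{L},\, h \rangle
 = \langle J(h)^{\top} \nabla_u \mathcal{L},\, h \rangle
 = \langle \nabla_u \mathcal{L},\, J(h)\,h \rangle = 0 .
\end{aligned}
```

Because the gradient is orthogonal to $h$, a gradient step on $h$ through the output path leaves $\|h\|$ unchanged to first order. This is the sense in which the final normalization removes the radial gradient.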

Figures (4)

  • Figure 1: TaperNorm layer and gate scheduling.
  • Figure 2: Training loss vs. step for the Baseline and Internal-Taper (+aux).
  • Figure 3: Mean logit $\ell_2$ norm vs. step. Without a scale anchor, models with the final norm removed (dashed) can exhibit logit chasing, consistent with Proposition 4.2. With the fixed-target scale loss, logit growth is strongly suppressed in the All-Taper setting.
  • Figure 4: Average gradient norms across all Transformer blocks by weight type. Without explicit scale anchoring, gradients cluster primarily by presence vs. absence of the final normalization. With the fixed-target scale loss enabled, the gradient-magnitude gap between Internal-Taper and All-Taper largely disappears.
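The logit-chasing effect in Figure 3 can be reproduced in a few lines. Without a final normalization, the logits $z = hW_{\mathrm{out}}$ are linear in $h$. The radial derivative of cross-entropy is then $\langle \nabla_h \mathcal{L}, h\rangle = \langle p - e_y, z\rangle = \mathbb{E}_p[z] - z_y$. This is negative whenever the target logit exceeds its softmax-weighted mean, so gradient descent increases $\|h\|$. The NumPy sketch below checks this identity against a finite difference. The function names and random data are illustrative only and do not come from the paper.

```python
import numpy as np


def cross_entropy(z: np.ndarray, y: int) -> float:
    """Cross-entropy of a single logit vector z against target class y."""
    z_shift = z - z.max()  # subtract the max for numerical stability
    return float(np.log(np.exp(z_shift).sum()) - z_shift[y])


def radial_derivative(h: np.ndarray, W: np.ndarray, y: int) -> float:
    """<grad_h CE, h> for logits z = h @ W with no final normalization.

    grad_z CE = softmax(z) - onehot(y) and grad_h CE = W @ grad_z CE, so
    <grad_h CE, h> = <grad_z CE, z> = E_p[z] - z_y.
    """
    z = h @ W
    p = np.exp(z - z.max())
    p /= p.sum()
    return float(p @ z - z[y])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, vocab = 16, 10
    h = rng.normal(size=d)
    W = rng.normal(size=(d, vocab))
    y = int(np.argmax(h @ W))  # target is the current top logit
    analytic = radial_derivative(h, W, y)
    # Central difference of CE(alpha * h) at alpha = 1 equals the radial derivative.
    eps = 1e-5
    numeric = (cross_entropy((1 + eps) * h @ W, y) - cross_entropy((1 - eps) * h @ W, y)) / (2 * eps)
    print(analytic, numeric)  # both negative: scaling h up lowers the loss
```

When the target is already the top logit, $z_y > \mathbb{E}_p[z]$ unless all logits are equal. The loss therefore keeps rewarding a larger $\|h\|$, which is the unbounded growth that a scale anchor is meant to prevent.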

Theorems & Definitions (6)

  • Proposition 4.1: Final normalization removes radial gradient
  • Proposition 4.2: Without the final norm, cross-entropy pushes norms up
  • Proposition 4.3: Fixed-target scale loss provides a radial restoring force
  • Proof
  • Proof
  • Proof
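To illustrate how a fixed-target scale loss can act as a radial restoring force (Proposition 4.3), the derivation below assumes a squared penalty on the RMS of the pre-logit residual stream, $r(h) = \|h\|_2/\sqrt{d}$, with hypothetical weight $\lambda > 0$ and target $\tau > 0$. The paper's exact loss form (for example, a mean-square target or a per-token average) may differ. Since $r$ is $1$-homogeneous, Euler's identity gives $\langle \nabla_h r, h\rangle = r(h)$.

```latex
\begin{aligned}
\mathcal{L}_{\mathrm{scale}}(h) &= \lambda\,\bigl(r(h) - \tau\bigr)^2,
\qquad r(h) = \frac{\|h\|_2}{\sqrt{d}}, \\
\langle \nabla_h \mathcal{L}_{\mathrm{scale}},\, h \rangle
&= 2\lambda\,\bigl(r(h) - \tau\bigr)\,\langle \nabla_h r,\, h \rangle
 = 2\lambda\,\bigl(r(h) - \tau\bigr)\,r(h).
\end{aligned}
```

Under this assumed form, the radial derivative is positive when $r(h) > \tau$ and negative when $r(h) < \tau$, so gradient descent pushes the scale back toward $\tau$. With linear logits, the cross-entropy radial term $\mathbb{E}_p[z] - z_y$ is bounded in magnitude by $\max_i z_i - \min_i z_i$, which grows at most linearly in $\|h\|$. The penalty term grows quadratically, so the combined radial derivative becomes positive once $\|h\|$ is large enough.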