Gated Removal of Normalization in Transformers Enables Stable Training and Efficient Inference
Andrei Kanavalau, Carmen Amo Alonso, Sanjay Lall
TL;DR
This work challenges the necessity of per-token normalization in pre-norm Transformers by introducing TaperNorm, a gated normalization layer that behaves like standard RMSNorm/LayerNorm early in training and gradually transitions to a fixed, sample-independent linear/affine map. A single global gate $g$ controls the transition. It is held at 1 during a warmup phase in which EMA statistics are accumulated, then cosine-decayed to 0. At that point per-token statistics vanish, and the remaining fixed scalings can be folded into adjacent projections, so these normalization layers disappear at inference. The authors identify output-scale anchoring as the critical factor. A final normalization layer provides an implicit radial anchor that prevents logit chasing, whereas removing it requires an explicit anchor, supplied by a fixed-target scale loss. Empirically, TaperNorm closely matches normalized baselines in pre-training and GPT-2 fine-tuning. On an efficiency microbenchmark, folding the internal scalings into adjacent projections yields up to $1.22\times$ higher inference throughput, a step toward practical norm-free Transformers.
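The gating can be illustrated with a minimal pure-Python sketch. It assumes a parameterization that the summary does not spell out: the layer outputs $\gamma \odot x \,(g/\mathrm{rms}(x) + (1-g)\,c)$, where $c$ is a scalar calibrated as an EMA of $1/\mathrm{rms}(x)$ while $g=1$. The names TaperRMSNorm, gate_schedule, inv_scale, warmup_steps and decay_steps are hypothetical. The authors' actual layer may parameterize the scaling branch differently.

```python
import math
from typing import List

Vector = List[float]


def gate_schedule(step: int, warmup_steps: int, decay_steps: int) -> float:
    """Hypothetical global gate: 1 during gate warmup, then cosine decay to 0."""
    if step < warmup_steps:
        return 1.0
    t = step - warmup_steps
    if t >= decay_steps:
        return 0.0
    return 0.5 * (1.0 + math.cos(math.pi * t / decay_steps))


class TaperRMSNorm:
    """Illustrative TaperNorm-style RMSNorm (assumed form, not the authors' code).

    y = weight * x * (g / rms(x) + (1 - g) * inv_scale)
    g = 1 -> standard RMSNorm; g = 0 -> fixed, token-independent scaling.
    """

    def __init__(self, dim: int, eps: float = 1e-6, ema_decay: float = 0.99):
        self.weight = [1.0] * dim      # learnable gain in a real model
        self.eps = eps
        self.ema_decay = ema_decay
        self.inv_scale = 1.0           # scaling branch, calibrated by EMA (assumption)

    def _inv_rms(self, x: Vector) -> float:
        return 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + self.eps)

    def __call__(self, tokens: List[Vector], g: float) -> List[Vector]:
        inv_rms = [self._inv_rms(x) for x in tokens]
        if g == 1.0:
            # Gate warmup: track an EMA of the per-token inverse RMS.
            batch_mean = sum(inv_rms) / len(inv_rms)
            self.inv_scale = (self.ema_decay * self.inv_scale
                              + (1.0 - self.ema_decay) * batch_mean)
        out = []
        for x, r in zip(tokens, inv_rms):
            # Blend the per-token factor r with the fixed factor inv_scale.
            factor = g * r + (1.0 - g) * self.inv_scale
            out.append([w * v * factor for w, v in zip(self.weight, x)])
        return out


if __name__ == "__main__":
    norm = TaperRMSNorm(dim=4)
    batch = [[1.0, 2.0, 3.0, 4.0], [0.5, -0.5, 0.5, -0.5]]
    for step in range(300):
        g = gate_schedule(step, warmup_steps=100, decay_steps=100)
        y = norm(batch, g)
    print("final gate:", g, "calibrated inv_scale:", norm.inv_scale)
```

With $g=1$ the layer reduces to RMSNorm. With $g=0$ it is a fixed per-channel rescaling that no longer depends on the token's own statistics, which is what makes folding into the next projection possible.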
Abstract
Normalization is widely viewed as essential for stabilizing Transformer training. We revisit this assumption for pre-norm Transformers and ask to what extent sample-dependent normalization is needed inside Transformer blocks. We introduce TaperNorm, a drop-in replacement for RMSNorm/LayerNorm that behaves exactly like the standard normalizer early in training and then smoothly tapers to a learned sample-independent linear/affine map. A single global gate is held at $g{=}1$ during a gate-warmup phase, in which EMAs calibrate the scaling branch, and is then cosine-decayed to $g{=}0$. At that point per-token statistics vanish, and the resulting fixed scalings can be folded into adjacent linear projections. Our theoretical and empirical results isolate scale anchoring as the key role played by output normalization. As a (near) $0$-homogeneous map, output normalization removes radial gradients at the output. Without such an anchor, cross-entropy encourages unbounded logit growth ("logit chasing"). We further show that a simple fixed-target auxiliary loss on the pre-logit residual-stream scale provides an explicit alternative anchor and can aid removal of the final normalization layer. Empirically, TaperNorm matches normalized baselines under identical setups while eliminating per-token statistics and enabling these layers to be folded into adjacent linear projections at inference. On an efficiency microbenchmark, folding internal scalings yields up to $1.22\times$ higher throughput in last-token logits mode. These results take a step towards norm-free Transformers while identifying the special role output normalization plays.
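Once $g=0$, folding is plain linear algebra. Suppose the tapered layer computes $y = \mathrm{diag}(\gamma c)\,x$ and feeds a projection $W$. Then $W y = (W\,\mathrm{diag}(\gamma c))\,x$, so the scaling can be absorbed into the columns of $W$ once, offline. The sketch below checks this numerically for the bias-free RMSNorm case; a LayerNorm-style layer would fold into an affine map instead. It also shows one possible form of the fixed-target scale loss. The helper names (fold_scaling_into_linear, scale_anchor_loss, target_rms) and the squared-error form of the loss are illustrative assumptions, not the authors' definitions.

```python
import math
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(W: Matrix, x: Vector) -> Vector:
    """Dense matrix-vector product W @ x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def fold_scaling_into_linear(W: Matrix, weight: Vector, inv_scale: float) -> Matrix:
    """Absorb the fixed g = 0 map x -> weight * inv_scale * x into W's columns.

    W @ diag(weight * inv_scale) scales column j of W by weight[j] * inv_scale.
    """
    return [[w_ij * weight[j] * inv_scale for j, w_ij in enumerate(row)] for row in W]


def scale_anchor_loss(hidden: List[Vector], target_rms: float) -> float:
    """Hypothetical fixed-target auxiliary loss on the pre-logit residual scale.

    Penalizes the squared gap between the batch-mean per-token RMS and a
    constant target, discouraging logit growth via an inflated residual stream.
    """
    rms = [math.sqrt(sum(v * v for v in h) / len(h)) for h in hidden]
    return (sum(rms) / len(rms) - target_rms) ** 2


if __name__ == "__main__":
    W = [[1.0, -2.0, 0.5], [0.0, 3.0, 1.0]]    # next projection (out x in)
    weight, inv_scale = [0.5, 2.0, 1.5], 0.25  # frozen tapered-norm parameters
    x = [1.0, 2.0, 3.0]
    unfolded = matvec(W, [w * inv_scale * v for w, v in zip(weight, x)])
    folded = matvec(fold_scaling_into_linear(W, weight, inv_scale), x)
    print(unfolded, folded)  # equal up to floating-point rounding
```

Folding is done once after training, so the inference graph contains only the modified projection. In this sketch the scale loss would be added to the cross-entropy during training with some weighting coefficient, which is a further assumption.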
