MXNorm: Reusing MXFP block scales for efficient tensor normalisation

Callum McLean, Luke Y. Prince, Alexandre Payot, Paul Balança, Carlo Luschi

Abstract

Matrix multiplication performance has long been the major bottleneck to scaling deep learning workloads, which has stimulated the design of new accelerators that use increasingly low-precision number formats. However, improvements in matrix multiplication performance have far outstripped improvements in performance on reductions and elementwise computations, which are still being performed in higher precision. In this work, we propose MXNorm, a drop-in replacement for RMSNorm that estimates the RMS using only the block scales calculated as part of the MXFP8 cast and enables a 32x decrease in the size of reduction needed for normalization. We validate our approximation method on pre-training of Llama 3 models of 125M, 1B and 8B parameters, finding minimal loss of training accuracy compared to a baseline using RMSNorm with MXFP8 matmuls. We also show practical kernel speedups using only torch.compile of up to 2.4x for MXNorm over RMSNorm, corresponding to a 1.3% speedup in Llama 3 8B transformer layers in MXFP8 and a 2.6% speedup in NVFP4.
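The idea in the abstract can be sketched numerically. The snippet below is a minimal illustration, not the authors' kernel: it compares the exact RMS of each row with an estimate built only from per-block absolute maxima (one value per 32 elements, hence the 32x smaller reduction). Several things here are assumptions: the helper names, the use of the unquantized block absmax in place of the power-of-two MX scale that a real MXFP8 cast would produce, the choice `p = 2`, and a calibration constant measured on Gaussian data.

```python
import numpy as np

BLOCK = 32  # MX block size; the source of the 32x smaller reduction


def exact_rms(x):
    """Reference RMS over the last axis (reduces over all D elements)."""
    return np.sqrt(np.mean(x ** 2, axis=-1))


def block_absmax(x, block=BLOCK):
    """Per-block absolute maximum, shape (..., D // block)."""
    *lead, d = x.shape
    assert d % block == 0, "last dimension must be a multiple of the block size"
    return np.abs(x).reshape(*lead, d // block, block).max(axis=-1)


def p_mean(m, p=2.0):
    """Generalized p-mean over the last axis (reduces over D // block values)."""
    return np.mean(m ** p, axis=-1) ** (1.0 / p)


rng = np.random.default_rng(0)

# Assumption: calibrate the ratio (p-mean of block absmax) / RMS once,
# on unit-variance Gaussian data. Other distributions need another constant.
reference = rng.standard_normal((256, 4096))
calib = float(np.mean(p_mean(block_absmax(reference))))

# Estimate the RMS of new rows with a different scale.
x = 3.7 * rng.standard_normal((8, 4096))
rms_true = exact_rms(x)
rms_est = p_mean(block_absmax(x)) / calib

rel_err = np.abs(rms_est / rms_true - 1.0)
print("max relative error:", rel_err.max())
```

Because the estimate depends on the data only through the block maxima, it could in principle reuse quantities the MX cast already computes. How well this holds for real activations, which are not Gaussian, is what the paper's experiments address.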

Paper Structure

This paper contains 33 sections, 4 theorems, 41 equations, 13 figures, 5 tables, 2 algorithms.

Key Result

Theorem 1

Fix a block size $B$. Let $(X_i)_{i=1}^D$ be $D = KB$ i.i.d. samples from a scale family distribution such that $X = \sigma Z$, where $\sigma > 0$ and $Z$ satisfies $\mathbb{E}[Z^2] = 1$ and $\mathbb{P}(Z=0) = 0$. Partition the indices $\{1,\dots,D\}$ into $K$ disjoint blocks of size $B$, and define the block maxima. For $p > 0$, define the generalized $p$-mean of these block maxima. Then, as $K \to \infty$, ...
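As a rough numerical illustration of the setting in Theorem 1, assume the block maxima are taken over absolute values, $M_k = \max_{i \in \text{block } k} |X_i|$, and the generalized $p$-mean is $\left(\frac{1}{K}\sum_{k=1}^{K} M_k^p\right)^{1/p}$. This notation, the function names, and the Gaussian choice of $Z$ are assumptions made for the sketch, not the paper's statement. Under them, the law of large numbers suggests that the ratio of the $p$-mean to $\sigma$ should settle to a constant that does not depend on $\sigma$ as $K$ grows.

```python
import numpy as np


def p_mean_of_block_maxima(x, block, p):
    """Generalized p-mean of per-block absolute maxima of a 1-D array."""
    m = np.abs(x).reshape(-1, block).max(axis=1)
    return np.mean(m ** p) ** (1.0 / p)


def ratio_to_sigma(sigma, k, block=32, p=2.0, seed=0):
    """Draw K*B samples X = sigma * Z with Z ~ N(0, 1); return p-mean / sigma."""
    rng = np.random.default_rng(seed)
    x = sigma * rng.standard_normal(k * block)
    return p_mean_of_block_maxima(x, block, p) / sigma


# Spread of the ratio across independent draws, for few and for many blocks.
few_blocks = [ratio_to_sigma(1.0, k=8, seed=s) for s in range(200)]
many_blocks = [ratio_to_sigma(1.0, k=2048, seed=s) for s in range(200)]
spread_few = float(np.std(few_blocks))
spread_many = float(np.std(many_blocks))
print("std of ratio, K=8:   ", spread_few)
print("std of ratio, K=2048:", spread_many)

# Scale invariance: with the same underlying Z, the ratio does not depend on sigma.
r_small = ratio_to_sigma(0.01, k=512, seed=1)
r_large = ratio_to_sigma(100.0, k=512, seed=1)
print("ratio at sigma=0.01:", r_small, " ratio at sigma=100:", r_large)
```

The shrinking spread with larger $K$ is consistent with the caption of Figure 2, which reports that the goodness of fit approaches 1 with more blocks.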

Figures (13)

  • Figure 1: Computational graphs for RMSNorm, MXCast, and MXNorm in the context of the Norm + Linear layer pattern. Top left: RMSNorm + Linear graph for high-precision training. Top middle: RMSNorm + Linear graph with linear inputs cast to MX (MXLinear). Top right: RMSNorm approximated with MXNorm and RMSNorm weight fused with Linear weight. Bottom left: RMSNorm graph. Bottom middle: MXCast graph for MXFP8 activations. Note that MXCast is applied to weights (E4M3 values) and gradients (E4M3 or E5M2 values) as well. Bottom right: MXNorm graph.
  • Figure 2: MXNorm as an approximation of RMSNorm. Left: MX scale distribution of normalized tensors. Middle: MX value distribution of normalized tensors. Right: MXNorm $r^2$ goodness-of-fit approaches 1 with more blocks.
  • Figure 3: Learning rate sensitivity of MXNorm compared to RMSNorm. Left: 125M parameter model (depth=8, width=1024). Right: 1B parameter model (depth=16, width=2048).
  • Figure 4: Training loss convergence of 8B parameter models trained on 300B tokens with MXNorm and RMSNorm.
  • Figure 5: Example outlier feature that appears at the same step as the first loss spike in an 8B parameter training run with MXNorm using mean over block absmax to estimate RMS.
  • ...and 8 more figures
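
The Figure 1 caption mentions fusing the RMSNorm weight into the Linear weight. The identity behind such a fusion can be checked with a small sketch. The shapes and variable names below are hypothetical, and the check runs in float64, so it says nothing about rounding behaviour once the fused weight is cast to MXFP8.

```python
import numpy as np

rng = np.random.default_rng(0)

x_hat = rng.standard_normal((4, 16))  # already-normalized activations (batch, in)
gain = rng.standard_normal(16)        # elementwise RMSNorm weight, one per input feature
W = rng.standard_normal((8, 16))      # Linear weight, laid out as (out, in)

# Unfused: scale the activations by the norm weight, then apply the Linear layer.
unfused = (x_hat * gain) @ W.T

# Fused: fold the norm weight into the Linear weight's input columns once, offline.
W_fused = W * gain                    # broadcasts gain across the in-feature axis
fused = x_hat @ W_fused.T

print("max abs difference:", np.abs(unfused - fused).max())
```

With the gain folded into the weight, the normalization step reduces to dividing by the RMS estimate, which removes one elementwise multiply from the activation path.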

Theorems & Definitions (8)

  • Theorem 1
  • proof
  • Lemma 1: Upper bound for RMS normalization
  • proof
  • Lemma 2: Upper bound for MXNorm normalization
  • proof
  • Theorem 2
  • proof