SWAN: SGD with Normalization and Whitening Enables Stateless LLM Training

Chao Ma, Wenbo Gong, Meyer Scetbon, Edward Meeds

TL;DR

SWAN introduces SGD with Whitening And Normalization, a stateless gradient pre-processing pipeline for LLM training that replaces traditional stateful optimizers like Adam. By applying GradNorm to stabilize gradient distributions and GradWhitening to orthogonalize gradient directions, SWAN eliminates optimizer-state storage while preserving or improving learning efficiency. Empirical results on memory-efficient LLaMA pre-training across multiple sizes show SWAN matching or surpassing Adam in perplexity, with substantial memory savings (≈$50\%$ total memory) and up to $2\times$ speedups in tokens seen; a fast NSDS variant further mirrors Adam throughput without distributed gradient pre-processing. These findings demonstrate the viability of stateless optimization for scalable, memory-constrained LLM training and motivate broader exploration of gradient-normalization pipelines. The work also provides theoretical insights linking GradNorm to gradient-covariance stabilization and GradWhitening to non-diagonal second-order updates under plausible Hessian structures, supporting the practical performance observed.

Abstract

Adaptive optimizers such as Adam (Kingma & Ba, 2015) have been central to the success of large language models. However, they often require maintaining optimizer states throughout training, which can result in memory requirements several times greater than the model footprint. This overhead imposes constraints on scalability and computational efficiency. Stochastic Gradient Descent (SGD), in contrast, is a stateless optimizer, as it does not track state variables during training. Consequently, it achieves optimal memory efficiency. However, its capability in LLM training is limited (Zhao et al., 2024b). In this work, we show that pre-processing the SGD gradients in a stateless manner can achieve the same performance as the Adam optimizer for LLM training, while drastically reducing the memory cost. Specifically, we propose to pre-process the instantaneous stochastic gradients using normalization and whitening. We show that normalization stabilizes gradient distributions, and whitening counteracts the local curvature of the loss landscape. This results in SWAN (SGD with Whitening And Normalization), a stochastic optimizer that eliminates the need to store any optimizer states. Empirically, SWAN has the same memory footprint as SGD, achieving an $\approx 50\%$ reduction in total end-to-end memory compared to Adam. In language modeling tasks, SWAN demonstrates comparable or even better performance than Adam: when pre-training LLaMA models with 350M and 1.3B parameters, SWAN achieves a $2\times$ speedup by reaching the same evaluation perplexity using half as many tokens.
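The memory claim above can be illustrated with a back-of-the-envelope calculation. The sketch below is not the paper's measurement protocol: it assumes that weights, gradients, and optimizer states are all stored at the same precision (2 bytes per value here), it ignores activations and framework overhead, and the helper name `training_memory_gb` is hypothetical. It relies only on the fact that Adam keeps two moment estimates per parameter, while a stateless optimizer keeps none.

```python
def training_memory_gb(n_params, bytes_per_value=2, optimizer_state_copies=0):
    """Rough memory for weights + gradients + optimizer states, in GB.

    Assumption: every tensor uses the same precision; activations are ignored.
    """
    copies = 2 + optimizer_state_copies  # one copy of weights, one of gradients
    return n_params * bytes_per_value * copies / 1e9


n_params = 1.3e9  # the 1.3B model size mentioned in the abstract

adam_gb = training_memory_gb(n_params, optimizer_state_copies=2)       # first and second moments
stateless_gb = training_memory_gb(n_params, optimizer_state_copies=0)  # SGD-like footprint
saving = 1 - stateless_gb / adam_gb

print(f"Adam-like: {adam_gb:.1f} GB, stateless: {stateless_gb:.1f} GB, saving: {saving:.0%}")
```

Under these simplifying assumptions the two Adam moment buffers account for half of the total, which is in line with the roughly 50% figure quoted in the abstract; real end-to-end savings depend on precision choices and activation memory.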

Paper Structure

This paper contains 84 sections, 9 theorems, 71 equations, 14 figures, 6 tables, 2 algorithms.

Key Result

Theorem 1

Consider the STB (Definition 1), under the assumptions of Theorem 1 of Tian et al. (2023), as described in the appendix discussion. Consider ${\mathbf{U}}_{C}^\top{\mathbf{W}}$, the composition of the MLP project-up matrix and the embedding matrix, as a whole. Then its standardized stochastic gradient … In other words, the covariance structure of $\tilde{\mathbf{G}}$ is identical across all time s…

Figures (14)

  • Figure 2: Illustration of the $\mathtt{GradNorm}$ and $\mathtt{GradWhitening}$ operators. In the $\mathtt{GradNorm}$ operator, we perform standardization across the output dimensions (columns), using statistics computed row-wise. In the $\mathtt{GradWhitening}$ operator (illustration adapted from Huang et al., 2019), we treat each column of the gradient matrix ${\mathbf{G}}$ as a separate data sample. $\mathtt{GradWhitening}$ can then be seen as stretching or squeezing the data so that the covariance matrix is the identity across all eigen-directions.
  • Figure 3: SWAN Optimizer
  • Figure 4: Comparison of the convergence rates of different methods on LLM pre-training tasks. The training curves of Adam, GaLore and Apollo-mini are reproduced using the open-source code of GaLore (Zhao et al., 2024) and Apollo (Zhu et al., 2024). We further compare to their official results in the LLM results table.
  • Figure 5: Comparative analysis of SWAN and Adam optimizers: speedup ratios and perplexity metrics across various model sizes. (a) shows how SWAN reduces the number of training steps needed to achieve the same evaluation perplexity as Adam for models ranging from 60M to 1.3B parameters. A speedup ratio greater than one indicates that SWAN reaches target PPL values faster than Adam. (b) presents a direct comparison of perplexity scores between SWAN and Adam. In both plots, we also provide counterfactual additive curves (dashed lines) modeling baselines corresponding to constant step advantages. Together, these plots highlight the nature of SWAN's speedup over Adam across different model scales.
  • Figure 6: Ablation studies on the 130M model. (a) Ablation on the contribution of each component in SWAN and Adam. (b) Ablation on removing $\mathtt{GradNorm}$ and compensating with larger learning rates. (c) Ablation on the effect of learning rate warm-ups.
  • ...and 9 more figures
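The two operators described in the Figure 2 caption can be sketched in a few lines of NumPy. This is an illustrative sketch under stated assumptions, not the authors' implementation: the function names `grad_norm`, `grad_whitening`, and `swan_like_step` are hypothetical, the whitening matrix $({\mathbf{G}}{\mathbf{G}}^\top)^{-1/2}$ is computed with an exact eigendecomposition rather than whatever faster scheme the paper uses, the `eps` terms are added only for numerical safety, and any extra rescaling of the update is omitted.

```python
import numpy as np


def grad_norm(G, eps=1e-8):
    """Standardize each row of G using its own mean and std taken across columns."""
    mean = G.mean(axis=1, keepdims=True)
    std = G.std(axis=1, keepdims=True)
    return (G - mean) / (std + eps)


def grad_whitening(G, eps=1e-8):
    """Whiten G by treating each column as one data sample.

    Assumption: uses (G G^T)^{-1/2} G computed via eigendecomposition.
    """
    C = G @ G.T                                  # (rows x rows) second-moment matrix
    evals, evecs = np.linalg.eigh(C)             # C is symmetric positive semi-definite
    inv_sqrt = (evecs / np.sqrt(evals + eps)) @ evecs.T
    return inv_sqrt @ G


def swan_like_step(W, G, lr=1e-3):
    """One stateless update: normalize, then whiten, then take a plain SGD step."""
    return W - lr * grad_whitening(grad_norm(G))


rng = np.random.default_rng(0)
# Toy gradient whose rows have very different scales.
G = rng.normal(size=(4, 16)) * np.array([[1.0], [10.0], [0.1], [5.0]])
W = rng.normal(size=(4, 16))

normalized = grad_norm(G)
whitened = grad_whitening(normalized)
W_new = swan_like_step(W, G)

print(np.round(normalized.std(axis=1), 3))     # per-row std after normalization
print(np.round(whitened @ whitened.T, 3))      # close to the identity matrix
```

Nothing is carried over between calls to `swan_like_step`, which is what makes the update stateless: each step depends only on the current weights and the current gradient.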

Theorems & Definitions (18)

  • Definition 1: Simplified Transformer Block (STB)
  • Theorem 1: $\mathtt{GradNorm}$ stabilizes gradient distributions across time for the STB
  • Proposition 1: Shared structures in the block-diagonal of Hessians at transformer equilibrium
  • Theorem 2: Contraction factor of $\mathtt{GradWhitening}$
  • Proposition 2: Robustness of $\mathtt{GradWhitening}$ update convergence rate against the condition number of local Hessian
  • Proposition 3: $\mathtt{GradWhitening}$ with single lr vs Adam with tuned group lr
  • proof
  • proof
  • proof
  • Theorem 3: Contraction factor lower bound for gradient descent, generalized from Zhang et al. (2024)
  • ...and 8 more