
A Stable Whitening Optimizer for Efficient Neural Network Training

Kevin Frans, Sergey Levine, Pieter Abbeel

TL;DR

A central problem in neural network optimization is achieving fast, stable progress with second-order-like updates at scale. The paper presents SPlus, a stable whitening optimizer that builds on Shampoo by incorporating instant-sign normalization, shape-aware symmetric scaling, and iterate averaging, which respectively address divergence, poor learning-rate transfer across network width, and parameter noise. On Transformer-based language modeling, diffusion, and image classification tasks, SPlus consistently matches or outperforms strong baselines, including Adam, in both gradient steps and wallclock time, and its practical implementation minimizes extra passes. The work also contributes a reproducible benchmark protocol and a code release, offering a robust alternative for fast, stable training of large neural networks. By converging faster without extensive hyperparameter tuning, SPlus could accelerate the training of large-scale models.
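To make the whitening idea behind Shampoo concrete (the contrast drawn in Figure 1), the sketch below whitens a single gradient matrix $G$ in two ways. Elementwise whitening, which is what Adam reduces to with both moment decays set to zero and $\epsilon$ ignored, divides each entry by its own magnitude and yields $\mathrm{sign}(G)$. Full two-sided whitening, $(GG^\top)^{-1/4} G (G^\top G)^{-1/4}$, maps $G = USV^\top$ to $UV^\top$, so every singular value becomes 1. This is an illustrative NumPy sketch assuming one full-rank gradient and no accumulated statistics; the function names are hypothetical.

```python
import numpy as np


def elementwise_whiten(grad: np.ndarray) -> np.ndarray:
    """Per-coordinate whitening of one gradient: g / sqrt(g**2) == sign(g)."""
    return np.sign(grad)


def _inverse_fourth_root(mat: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """mat^(-1/4) for a symmetric positive semi-definite matrix via eigh."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.maximum(vals, eps)  # clamp round-off negatives / exact zeros
    return (vecs * vals ** -0.25) @ vecs.T  # V diag(vals^-1/4) V^T


def full_whiten(grad: np.ndarray) -> np.ndarray:
    """Two-sided (Shampoo-style) whitening of one gradient:
    (G G^T)^(-1/4) G (G^T G)^(-1/4). Assumes G has full rank."""
    return _inverse_fourth_root(grad @ grad.T) @ grad @ _inverse_fourth_root(grad.T @ grad)


rng = np.random.default_rng(0)
g = rng.standard_normal((4, 4))
print(np.linalg.svd(full_whiten(g), compute_uv=False))  # all singular values ~1.0
print(elementwise_whiten(g))                             # all entries are +-1
```

Both variants discard the gradient's overall scale; they differ in which axes get normalized: the coordinate axes for the elementwise version, and the gradient's singular directions for the full version.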

Abstract

In this work, we take an experimentally grounded look at neural network optimization. Building on the Shampoo family of algorithms, we identify and alleviate three key issues, resulting in the proposed SPlus method. First, we find that naive Shampoo is prone to divergence when matrix-inverses are cached for long periods. We introduce an alternate bounded update combining a historical eigenbasis with instantaneous normalization, resulting in across-the-board stability and significantly lower computational requirements. Second, we adapt a shape-aware scaling to enable learning rate transfer across network width. Third, we find that high learning rates result in large parameter noise, and propose a simple iterate-averaging scheme which unblocks faster learning. To properly confirm these findings, we introduce a pointed Transformer training benchmark, considering three objectives (language modelling, image classification, and diffusion modelling) across different stages of training. On average, SPlus is able to reach the validation performance of Adam within 44-58% of the gradient steps and 62-83% of the wallclock time.

Paper Structure

This paper contains 16 sections, 25 equations, 7 figures, 4 tables, and 2 algorithms.

Figures (7)

  • Figure 1: Whitening normalizes gradients to have uniform magnitude along each axis of descent. This decouples the updates from gradient magnitude. Elementwise whitening imposes an independent axis per dimension, whereas full whitening uses the axes that maximally explain gradient covariance.
  • Figure 2: Shampoo is prone to divergence, but SPlus remains stable under the same settings. Plotted above are loss curves on language modelling, sweeping the learning rate over $(0.0001, 0.000215, 0.000464, 0.001)$ and the cache duration over $(5, 10, 25, 100, 500)$. SPlus is significantly more robust to hyperparameters than Shampoo. This robustness is crucial for practical training speed: in our setting, Shampoo diverges when caching for $>100$ gradient steps, while SPlus remains stable, enabling faster wall-clock performance than Adam.
  • Figure 3: Optimal learning rates for SPlus transfer across network widths. This is achieved by normalizing per-layer update magnitudes by a constant shape-dependent factor. Notably, this learning rate transfer does not hold by default for Adam or Shampoo.
  • Figure 4: Iterate averaging enables the use of higher learning rates without degradation. Training with a higher learning rate creates a tradeoff between faster learning progress and increased parameter noise. Averaging previous iterates reduces parameter noise, which lets us escape this tradeoff and reach a higher, better-performing optimal learning rate.
  • Figure 5: Optimizers are evaluated over 10k gradient steps, starting from three distinct checkpoints per objective. We design this setting to test robustness across objectives and across stages of training. As shown above for the LLM case, SPlus consistently reaches the same validation performance as Adam within a smaller fraction of gradient steps (dotted line).
  • ...and 2 more figures
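The iterate-averaging tradeoff described for Figure 4 can be illustrated on a toy problem. This sketch assumes the averaging is an exponential moving average (EMA) of the parameters that is used only for evaluation, while gradients are still taken at the raw iterate; the quadratic objective, the Gaussian noise model, and the names `noisy_sgd_with_ema` and `ema_decay` are hypothetical choices for illustration, not the paper's setup.

```python
import numpy as np


def noisy_sgd_with_ema(lr: float = 0.5, ema_decay: float = 0.99,
                       steps: int = 2000, dim: int = 1000, seed: int = 0):
    """Noisy gradient descent on f(x) = 0.5 * ||x||^2 (minimum at x = 0),
    keeping an exponential moving average (EMA) of the iterates."""
    rng = np.random.default_rng(seed)
    x = np.full(dim, 5.0)   # raw iterate: gradients are always taken here
    x_avg = x.copy()        # averaged iterate: used only for evaluation
    for _ in range(steps):
        grad = x + rng.standard_normal(dim)  # exact gradient plus unit Gaussian noise
        x = x - lr * grad
        x_avg = ema_decay * x_avg + (1.0 - ema_decay) * x
    return x, x_avg


x, x_avg = noisy_sgd_with_ema()
print("raw iterate error:     ", float(np.sum(x ** 2)))
print("averaged iterate error:", float(np.sum(x_avg ** 2)))
```

With a step size of 0.5 the raw iterate converges quickly but keeps bouncing around the minimum with a per-coordinate variance of about 1/3, while the averaged iterate's squared error is well over an order of magnitude smaller. The large learning rate buys fast progress, and averaging removes much of the parameter noise it leaves behind.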