A Stable Whitening Optimizer for Efficient Neural Network Training
Kevin Frans, Sergey Levine, Pieter Abbeel
TL;DR
A central problem in neural network optimization is making fast, stable progress with second-order-like updates at scale. The paper presents SPlus, a stable whitening optimizer that builds on Shampoo by adding instantaneous sign normalization, shape-aware symmetric scaling, and iterate averaging. These address, respectively, divergence, poor learning-rate transfer across network width, and parameter noise. Empirical results on Transformer-based tasks spanning language modeling, diffusion, and image classification show that SPlus matches or outperforms strong baselines, including Adam, in both gradient steps and wallclock time, and its implementation is designed to keep computational overhead low. The work contributes a reproducible Transformer training benchmark and released code, offering a robust alternative for fast, stable training of large neural networks. In practice, these improvements could enable faster convergence with less hyperparameter tuning, potentially accelerating the training of large-scale models.
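As a rough illustration of how a historical eigenbasis can be combined with instantaneous sign normalization, the NumPy sketch below handles a single 2-D weight matrix. It is not the authors' implementation: the class `SPlusSketch`, its methods `direction` and `apply`, the hyperparameters `b1`, `b2`, `eig_every`, and especially the shape-aware factor `2 / (m + n)` are hypothetical choices made for illustration. The gradient statistics are accumulated every step, but the eigenbases are refreshed only occasionally, so the update mostly uses a stale basis.

```python
import numpy as np


class SPlusSketch:
    """Hypothetical SPlus-style optimizer state for one 2-D weight matrix."""

    def __init__(self, shape, lr=1e-3, b1=0.9, b2=0.999, eig_every=100):
        m, n = shape
        self.lr, self.b1, self.b2, self.eig_every = lr, b1, b2, eig_every
        self.t = 0                    # number of updates taken so far
        self.mom = np.zeros((m, n))   # momentum (EMA of gradients)
        self.L = np.zeros((m, m))     # EMA of G @ G.T (left statistics)
        self.R = np.zeros((n, n))     # EMA of G.T @ G (right statistics)
        self.QL = np.eye(m)           # cached ("historical") left eigenbasis
        self.QR = np.eye(n)           # cached ("historical") right eigenbasis
        # ASSUMPTION: illustrative shape-aware factor, not the paper's exact rule.
        self.shape_scale = 2.0 / (m + n)

    def direction(self, G):
        """Update the statistics and return the bounded update direction U."""
        self.t += 1
        self.mom = self.b1 * self.mom + (1 - self.b1) * G
        self.L = self.b2 * self.L + (1 - self.b2) * (G @ G.T)
        self.R = self.b2 * self.R + (1 - self.b2) * (G.T @ G)
        # Recompute the eigenbases only every `eig_every` steps; in between,
        # the stale (historical) bases are reused.
        if (self.t - 1) % self.eig_every == 0:
            _, self.QL = np.linalg.eigh(self.L)
            _, self.QR = np.linalg.eigh(self.R)
        # Rotate momentum into the eigenbasis, normalize each entry with sign(),
        # then rotate back. Orthogonal rotations preserve the Frobenius norm, so
        # ||U||_F <= sqrt(m * n) no matter how stale QL and QR are.
        S = np.sign(self.QL.T @ self.mom @ self.QR)
        return self.QL @ S @ self.QR.T

    def apply(self, W, G):
        """Return the updated weights W - lr * shape_scale * U."""
        return W - self.lr * self.shape_scale * self.direction(G)
```

One way to see why such an update stays stable: its magnitude is fixed by the `sign()` entries and the orthogonal rotations, whereas a cached inverse matrix root can amplify gradient directions whose eigenvalues have drifted since the inverse was computed.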
Abstract
In this work, we take an experimentally grounded look at neural network optimization. Building on the Shampoo family of algorithms, we identify and alleviate three key issues, resulting in the proposed SPlus method. First, we find that naive Shampoo is prone to divergence when matrix inverses are cached for long periods. We introduce an alternative bounded update combining a historical eigenbasis with instantaneous normalization, resulting in across-the-board stability and significantly lower computational requirements. Second, we adapt a shape-aware scaling to enable learning rate transfer across network width. Third, we find that high learning rates result in large parameter noise, and propose a simple iterate-averaging scheme which unblocks faster learning. To confirm these findings, we introduce a targeted Transformer training benchmark, considering three objectives (language modelling, image classification, and diffusion modelling) across different stages of training. On average, SPlus reaches the validation performance of Adam within 44-58% of the gradient steps and 62-83% of the wallclock time.
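The link between high learning rates, parameter noise, and iterate averaging can be illustrated on a toy problem. The sketch below assumes that iterate averaging means an exponential moving average of the weights, kept alongside the raw iterate and used only for evaluation; the quadratic objective, the functions `ema_update` and `noisy_training`, and the values of `lr`, `noise`, and `decay` are hypothetical, and the paper's exact averaging scheme may differ.

```python
import numpy as np


def ema_update(avg, params, decay):
    """One step of an exponential moving average over parameter iterates."""
    return decay * avg + (1.0 - decay) * params


def noisy_training(steps=2000, lr=0.5, noise=1.0, decay=0.99, seed=0):
    """Toy SGD on f(w) = 0.5 * ||w||^2 with noisy gradients (hypothetical setup).

    Returns the final raw iterate and the final averaged iterate; the optimum is w = 0.
    """
    rng = np.random.default_rng(seed)
    w = np.full(10, 5.0)
    avg = w.copy()
    for _ in range(steps):
        grad = w + noise * rng.standard_normal(w.shape)  # true gradient of f is w
        w = w - lr * grad                                # large lr -> noisy iterates
        avg = ema_update(avg, w, decay)                  # average used for evaluation
    return w, avg
```

On this toy problem the raw iterate keeps fluctuating around the optimum with a spread set by the learning rate, while the averaged iterate sits much closer to it. Training continues on the raw weights, so the high learning rate still drives fast progress.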
