The Implicit Bias of Steepest Descent with Mini-batch Stochastic Gradient
Jichu Li, Xuan Tang, Difan Zou
TL;DR
This work develops a unified theory for the implicit bias of mini-batch stochastic steepest descent in multi-class classification, revealing how batch size, momentum, and variance reduction interact with norm-induced geometries to determine the limiting max-margin solution. By formulating steepest descent under general norms, the framework covers Normalized-SGD, SignSGD, Muon, and their variants, and provides fully explicit, dimension-free convergence rates. The key findings are as follows: without momentum, large batches are required to align with the full-batch max-margin solution; momentum enables small-batch convergence through a batch-momentum trade-off; variance reduction can recover the exact full-batch bias for any batch size at the cost of slower rates; and a carefully constructed batch-size-one setting can yield a fundamentally different implicit bias. Collectively, these results clarify when stochastic optimization mimics full-batch behavior and guide practical design choices for SGD-like optimizers in large-scale training.
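To make the correspondence between norms and optimizers concrete, the sketch below computes a unit-norm steepest-descent direction $D = \arg\max_{\|D\| \le 1} \langle G, D \rangle$ for a gradient matrix $G$ under three geometries: the entry-wise $\ell_2$ norm (Normalized-SGD), the entry-wise $\ell_\infty$ norm (SignSGD), and the spectral norm, i.e. Schatten-$\infty$ (a Muon-style orthogonalized update). It is a minimal NumPy sketch, not the authors' code: the function name `steepest_direction` and its string keys are hypothetical, and the spectral case uses an exact SVD, whereas Muon itself approximates the orthogonal factor with a Newton-Schulz iteration.

```python
import numpy as np


def steepest_direction(grad: np.ndarray, norm: str) -> np.ndarray:
    """Unit-norm direction D maximizing <grad, D> over the unit ball of `norm`.

    An optimizer step is then W <- W - lr * D. The `norm` keys ("l2", "linf",
    "spectral") are hypothetical labels chosen for this sketch.
    """
    if not np.any(grad):
        # Zero gradient: every feasible D is optimal; return no movement.
        return np.zeros_like(grad)
    if norm == "l2":
        # Entry-wise l2 (Frobenius) ball: rescale the gradient (Normalized-SGD).
        return grad / np.linalg.norm(grad)
    if norm == "linf":
        # Entry-wise l_inf ball: take the elementwise sign (SignSGD).
        return np.sign(grad)
    if norm == "spectral":
        # Spectral (Schatten-inf) ball: orthogonal polar factor U V^T of the
        # gradient (Muon-style). Exact SVD here; Muon approximates this factor
        # with a Newton-Schulz iteration instead.
        U, _, Vt = np.linalg.svd(grad, full_matrices=False)
        return U @ Vt
    raise ValueError(f"unknown norm: {norm!r}")


if __name__ == "__main__":
    G = np.array([[3.0, -4.0], [0.0, 0.0]])
    print(steepest_direction(G, "l2"))    # rescaled gradient with Frobenius norm 1
    print(steepest_direction(G, "linf"))  # elementwise signs of the gradient
```

In each case $\langle G, D \rangle$ equals the corresponding dual norm of $G$: the $\ell_2$ norm, the entry-wise $\ell_1$ norm, and the nuclear norm, respectively.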
Abstract
A variety of widely used optimization methods, such as SignSGD and Muon, can be interpreted as instances of steepest descent under different norm-induced geometries. In this work, we study the implicit bias of mini-batch stochastic steepest descent in multi-class classification, characterizing how batch size, momentum, and variance reduction shape the limiting max-margin behavior and convergence rates under general entry-wise and Schatten-$p$ norms. We show that without momentum, convergence occurs only with large batches, yielding a batch-dependent margin gap while retaining the full-batch convergence rate. In contrast, momentum enables small-batch convergence through a batch-momentum trade-off, though it slows convergence. This analysis provides fully explicit, dimension-free rates that improve upon prior results. Moreover, we prove that variance reduction can recover the exact full-batch implicit bias for any batch size, albeit at a slower convergence rate. Finally, we investigate batch-size-one steepest descent without momentum and, through a concrete data example, show that it converges to a fundamentally different bias, which exposes a key limitation of purely stochastic updates. Overall, our unified analysis clarifies when stochastic optimization aligns with full-batch behavior and paves the way for deeper exploration of the training behavior of stochastic gradient steepest descent algorithms.
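The mini-batch setting studied here can be illustrated with a small simulation. The sketch below assumes a linear multi-class model with logits $XW^\top$, the mean cross-entropy loss, an exponential-moving-average momentum buffer ($\beta = 0$ recovers the momentum-free method), a constant step size, and a fresh mini-batch sampled without replacement at every step. These are illustrative choices rather than the paper's exact algorithm or step-size schedule, variance reduction is not shown, and all function names are hypothetical. The default direction is the elementwise sign, so the loop is $\ell_\infty$ steepest descent (SignSGD with momentum), and `normalized_margin` reports $\min_i \min_{c \ne y_i} (w_{y_i} - w_c)^\top x_i / \|W\|$ for a user-supplied norm.

```python
import numpy as np


def softmax_ce_grad(W, X, y):
    """Gradient of the mean cross-entropy loss for logits X @ W.T (W has shape (k, d))."""
    logits = X @ W.T
    logits -= logits.max(axis=1, keepdims=True)  # stabilize the softmax
    P = np.exp(logits)
    P /= P.sum(axis=1, keepdims=True)
    P[np.arange(len(y)), y] -= 1.0               # softmax probabilities minus one-hot labels
    return P.T @ X / len(y)


def stochastic_steepest_descent(X, y, k, batch_size, lr, beta, steps,
                                direction=np.sign, seed=0):
    """Mini-batch steepest descent with EMA momentum (assumed form); beta=0 disables momentum."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    W = np.zeros((k, d))
    M = np.zeros_like(W)
    for _ in range(steps):
        idx = rng.choice(n, size=batch_size, replace=False)  # fresh mini-batch each step
        G = softmax_ce_grad(W, X[idx], y[idx])
        M = beta * M + (1.0 - beta) * G                       # momentum buffer
        W -= lr * direction(M)                                # constant step size (assumption)
    return W


def normalized_margin(W, X, y, norm_fn):
    """min_i min_{c != y_i} (w_{y_i} - w_c)^T x_i, divided by norm_fn(W)."""
    scores = X @ W.T
    rows = np.arange(len(y))
    correct = scores[rows, y]
    scores[rows, y] = -np.inf                    # exclude the true class from the max
    return np.min(correct - scores.max(axis=1)) / norm_fn(W)


if __name__ == "__main__":
    # Toy separable data: class 0 lies on the first axis, class 1 on the second.
    X = np.array([[1.0, 0.0], [2.0, 0.0], [0.0, 1.0], [0.0, 2.0]])
    y = np.array([0, 0, 1, 1])
    W = stochastic_steepest_descent(X, y, k=2, batch_size=4, lr=0.1, beta=0.9, steps=50)
    # Sign descent is steepest descent for the entry-wise l_inf norm,
    # so the margin is measured in that norm (max absolute entry of W).
    print(normalized_margin(W, X, y, lambda V: np.abs(V).max()))
```

Varying `batch_size` and `beta` in this toy loop is one way to probe the batch-momentum interaction empirically; replacing `direction` with another norm's steepest-descent map changes the geometry, in which case the margin is naturally measured in that norm instead.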
