The Implicit Bias of Steepest Descent with Mini-batch Stochastic Gradient

Jichu Li, Xuan Tang, Difan Zou

TL;DR

This work develops a unified theory for the implicit bias of mini-batch stochastic steepest descent in multi-class classification, revealing how batch size, momentum, and variance reduction interact with norm-induced geometries to determine the limiting max-margin solution. By framing steepest descent under general norms, it recovers Normalized-SGD, SignSGD, Muon, and their variants, and provides fully explicit, dimension-free convergence rates. The key findings are: without momentum, large batches are required to align with the full-batch max-margin solution; momentum enables small-batch convergence through a batch-momentum trade-off; variance reduction can recover the exact full-batch bias for any batch size at the cost of slower rates; and a carefully constructed batch-size-one setting can yield a fundamentally different implicit bias. Collectively, these results clarify when stochastic optimization mimics full-batch behavior and guide practical design choices for SGD-like optimizers in large-scale training.

Abstract

A variety of widely used optimization methods, such as SignSGD and Muon, can be interpreted as instances of steepest descent under different norm-induced geometries. In this work, we study the implicit bias of mini-batch stochastic steepest descent in multi-class classification, characterizing how batch size, momentum, and variance reduction shape the limiting max-margin behavior and convergence rates under general entry-wise and Schatten-$p$ norms. We show that without momentum, convergence occurs only with large batches, yielding a batch-dependent margin gap but the full-batch convergence rate. In contrast, momentum enables small-batch convergence through a batch-momentum trade-off, though it slows convergence. This approach provides fully explicit, dimension-free rates that improve upon prior results. Moreover, we prove that variance reduction can recover the exact full-batch implicit bias for any batch size, albeit at a slower convergence rate. Finally, we investigate batch-size-one steepest descent without momentum and show, via a concrete data example, that it converges to a fundamentally different bias, exposing a key limitation of purely stochastic updates. Overall, our unified analysis clarifies when stochastic optimization aligns with full-batch behavior and paves the way for deeper exploration of the training behavior of stochastic gradient steepest descent algorithms.
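
To illustrate the steepest-descent view described in the abstract, the sketch below computes the unit-norm update direction under three geometries: the $\ell_2$ (Frobenius) norm, the entry-wise max norm, and the spectral norm. These correspond to Normalized-SGD, SignSGD, and a momentum-free Muon-style update, respectively. This is a minimal NumPy sketch, not the authors' code; the function names `steepest_direction` and `steepest_descent_step`, the geometry labels, and the default learning rate are assumptions introduced here.

```python
import numpy as np


def steepest_direction(grad, geometry):
    """Return a maximizer of <grad, D> over the unit ball of the chosen norm.

    The geometry labels are hypothetical names used only in this sketch.
    """
    if geometry == "l2":
        # Frobenius unit ball: normalize the gradient (Normalized-SGD).
        norm = np.linalg.norm(grad)
        return grad / norm if norm > 0 else np.zeros_like(grad)
    if geometry == "linf":
        # Entry-wise max-norm unit ball: take the sign pattern (SignSGD).
        return np.sign(grad)
    if geometry == "spectral":
        # Spectral-norm unit ball: orthogonal factor U V^T of the reduced SVD.
        u, _, vt = np.linalg.svd(grad, full_matrices=False)
        return u @ vt
    raise ValueError(f"unknown geometry: {geometry}")


def steepest_descent_step(weights, grad, geometry, lr=0.1):
    """One steepest-descent step using a (mini-batch) gradient estimate."""
    return weights - lr * steepest_direction(grad, geometry)


# Example with a random matrix standing in for a mini-batch gradient.
rng = np.random.default_rng(0)
G = rng.standard_normal((4, 3))
W = np.zeros((4, 3))
W_l2 = steepest_descent_step(W, G, "l2")
W_sign = steepest_descent_step(W, G, "linf")
W_spec = steepest_descent_step(W, G, "spectral")
```

In each case the inner product between the gradient and the returned direction equals the dual norm of the gradient (Frobenius, entry-wise $\ell_1$, and nuclear norm, respectively), which is what makes the direction "steepest" in that geometry.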

Paper Structure (49 sections, 49 theorems, 335 equations, 4 figures, 1 algorithm)

Key Result

Theorem 4.1

Suppose Assumptions ass:sep, ass:data_bound, and ass:learning_rate_1 hold. Assume the batch size $b$ satisfies the large-batch condition $\rho \coloneqq \gamma - 4\left(\frac{n}{b}-1\right)R > 0$, i.e., $b>\frac{4Rn}{\gamma+4R}$. Then there exists $t_2=t_2(n,b,\gamma,\mathbf{W}_0,R)$ such that for all $t > t_2$, the margin gap ...
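
The two forms of the large-batch condition in the statement are equivalent by a short rearrangement. The derivation below uses only the symbols of the theorem and assumes $b>0$ and $\gamma+4R>0$; the numerical values are hypothetical and chosen only to make the arithmetic concrete.

$$
\begin{aligned}
\rho = \gamma - 4\left(\frac{n}{b}-1\right)R > 0
&\iff \gamma + 4R > \frac{4Rn}{b} \\
&\iff b > \frac{4Rn}{\gamma+4R}.
\end{aligned}
$$

For example, with the illustrative values $n=200$, $R=1$, and $\gamma=0.5$, the threshold is $\frac{4\cdot 1\cdot 200}{0.5+4} = \frac{800}{4.5}\approx 177.8$, so the condition holds for $b\ge 178$. At full batch ($b=n$) the correction term vanishes and $\rho=\gamma$.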

Figures (4)

  • Figure 1: Empirical validation of the implicit bias of steepest descent under the $\ell_2$ norm. The experiments show that mini-batch sampling breaks the full-batch $\ell_2$ implicit bias while large momentum or variance reduction restores convergence to the max-margin solution. (a) N-SGD with full-batch size $b=200$. (b) N-SGD with mini-batch size $b=20$. (c) N-MSGD with $\beta_1=0.5$ and $b=200$. (d) N-MSGD with $\beta_1=0.5$ and $b=20$. (e) N-MSGD with $\beta_1=0.99$ and $b=200$. (f) N-MSGD with $\beta_1=0.99$ and $b=20$. (g) VR-N-SGD with $b=20$. (h) VR-N-MSGD with $\beta_1=0.5$ and $b=20$. (i) VR-N-MSGD with $\beta_1=0.99$ and $b=20$.
  • Figure 2:
  • Figure 3: Empirical validation of the implicit bias of steepest descent under the $\ell_\infty$ norm. (a) SignSGD with full-batch size $b=200$. (b) SignSGD with mini-batch size $b=20$. (c) Signum with momentum $\beta_1=0.5$ and full-batch size $b=200$. (d) Signum with momentum $\beta_1=0.5$ and mini-batch size $b=20$. (e) Signum with momentum $\beta_1=0.99$ and full-batch size $b=200$. (f) Signum with momentum $\beta_1=0.99$ and mini-batch size $b=20$. (g) VR-SignSGD with mini-batch size $b=20$. (h) VR-Signum with momentum $\beta_1=0.5$ and mini-batch size $b=20$. (i) VR-Signum with momentum $\beta_1=0.99$ and mini-batch size $b=20$.
  • Figure 4: Empirical validation of the implicit bias of steepest descent under the spectral norm. (a) Spectral-SGD with full-batch size $b=200$. (b) Spectral-SGD with mini-batch size $b=20$. (c) Muon with momentum $\beta_1=0.5$ and full-batch size $b=200$. (d) Muon with momentum $\beta_1=0.5$ and mini-batch size $b=20$. (e) Muon with momentum $\beta_1=0.99$ and full-batch size $b=200$. (f) Muon with momentum $\beta_1=0.99$ and mini-batch size $b=20$. (g) VR-Spectral-SGD with mini-batch size $b=20$. (h) VR-Muon with momentum $\beta_1=0.5$ and mini-batch size $b=20$. (i) VR-Muon with momentum $\beta_1=0.99$ and mini-batch size $b=20$.
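
The momentum variants named in the captions (N-MSGD, Signum, Muon) apply the steepest-descent map to a running average of mini-batch gradients rather than to the raw gradient. A minimal sketch for the sign geometry follows. The exponential-moving-average form of the buffer, the function name `signum_step`, and all numeric values are assumptions made for illustration and may differ from the paper's exact parameterization.

```python
import numpy as np


def signum_step(weights, buffer, grad, lr=0.1, beta1=0.9):
    """One Signum-style step (assumed EMA form of the momentum buffer)."""
    # Blend the previous buffer with the new mini-batch gradient.
    new_buffer = beta1 * buffer + (1.0 - beta1) * grad
    # Apply the sign map to the buffer, not to the raw gradient.
    new_weights = weights - lr * np.sign(new_buffer)
    return new_weights, new_buffer


# Two hypothetical mini-batch gradients; the first coordinate flips sign.
W = np.zeros(2)
m = np.zeros(2)
for g in (np.array([1.0, -2.0]), np.array([-0.2, -2.0])):
    W, m = signum_step(W, m, g, lr=0.1, beta1=0.5)
```

In this toy run the second gradient has a negative first coordinate, but the buffer stays positive there, so the update direction does not flip. This is the smoothing of mini-batch noise that momentum provides; with `beta1=0` the step reduces to plain SignSGD.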

Theorems & Definitions (93)

  • Theorem 4.1: Margin Convergence of Stochastic Steepest Descent without Momentum
  • Corollary 4.2
  • Theorem 4.3: Margin Convergence of Stochastic Steepest Descent with Momentum
  • Remark 4.4
  • Corollary 4.5
  • Theorem 4.6: Margin Convergence of VR-Stochastic Steepest Descent
  • Corollary 4.7
  • Definition 4.8: Bias Directions
  • Theorem 4.9: Implicit Bias of Per-sample SignSGD and Per-sample Normalized-SGD
  • Lemma C.1
  • ...and 83 more