
The Implicit Bias of Adam and Muon on Smooth Homogeneous Neural Networks

Eitan Gronich, Gal Vardi

TL;DR

This paper analyzes the implicit bias of momentum-based optimizers on smooth homogeneous models, showing that under decaying learning rates satisfying $\int_0^\infty \eta(t)\,dt = \infty$, the direction of the parameter iterates, $\frac{\bm{\theta}_t}{\|\bm{\theta}_t\|}$, converges to a KKT point of a max-margin problem. It formalizes a unifying framework of approximate steepest descent that covers normalized steepest descent, momentum-based methods (Muon, Muon-Signum, Muon-Adam), and Adam without the stability constant, and derives the corresponding margin-maximization biases with respect to spectral-norm, composite max-norm, and $\ell_\infty$ margins. The results extend prior analyses of linear models and homogeneous networks to broader smooth homogeneous models and provide corollaries for composite optimizers; MNIST experiments corroborate them, showing that the identity of the maximized margin depends on the optimizer. The work also discusses non-smooth extensions, directional convergence assumptions, and potential implications for generalization, robustness, and training dynamics of large-scale models. Overall, the paper offers a cohesive theory linking optimizer dynamics, margin maximization, and implicit bias across a spectrum of first-order methods.
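To make "normalized steepest descent with respect to a norm" concrete, here is a minimal NumPy sketch (not the authors' code). Each step moves along the direction $d$ with $\|d\| \le 1$ that maximizes $\langle g, d \rangle$: $g/\|g\|_2$ for $\ell_2$ (normalized GD), $\operatorname{sign}(g)$ for $\ell_\infty$ (sign descent), and the polar factor $UV^\top$ from the SVD for the spectral norm (the update Muon approximates with Newton-Schulz iterations). Applying the same map to a momentum buffer gives momentum variants in the spirit of MomentumGD, Signum, and Muon. The function names, the momentum form, and the schedule $\eta_t = \eta_0/(t+1)^{p}$ are illustrative assumptions; for $p \le 1$ this schedule decays while $\sum_t \eta_t$ diverges, a discrete analogue of $\int_0^\infty \eta(t)\,dt = \infty$.

```python
import numpy as np


def steepest_direction(g, norm):
    """Direction d with ||d|| <= 1 that maximizes <g, d> for the chosen norm."""
    if norm == "l2":
        n = np.linalg.norm(g)
        return g / n if n > 0 else np.zeros_like(g)
    if norm == "linf":
        # Entries of g that are exactly zero map to 0, which still attains the maximum.
        return np.sign(g)
    if norm == "spectral":
        # Polar factor U V^T of a 2-D g (exact SVD; Muon uses a Newton-Schulz approximation).
        u, _, vt = np.linalg.svd(g, full_matrices=False)
        return u @ vt
    raise ValueError(f"unknown norm: {norm}")


def momentum_steepest_descent(grad_fn, theta0, norm, steps=100, eta0=0.1, beta=0.9, power=0.5):
    """Hypothetical helper: steepest descent applied to an exponential moving average of gradients."""
    theta = np.array(theta0, dtype=float)
    m = np.zeros_like(theta)
    for t in range(steps):
        eta = eta0 / (t + 1) ** power  # decays to 0; the sum diverges when power <= 1
        m = beta * m + (1 - beta) * grad_fn(theta)
        theta = theta - eta * steepest_direction(m, norm)
    return theta
```

Because each direction map is invariant to positive rescaling of its input, using `beta * m + g` instead of `beta * m + (1 - beta) * g` for the buffer would not change the trajectory in this sketch.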

Abstract

We study the implicit bias of momentum-based optimizers on homogeneous models. We first extend existing results on the implicit bias of steepest descent in homogeneous models to normalized steepest descent with an optional learning rate schedule. We then show that for smooth homogeneous models, momentum steepest descent algorithms like Muon (spectral norm), MomentumGD ($\ell_2$ norm), and Signum ($\ell_\infty$ norm) are approximate steepest descent trajectories under a decaying learning rate schedule, proving that these algorithms too have a bias towards KKT points of the corresponding margin maximization problem. We extend the analysis to Adam (without the stability constant), which maximizes the $\ell_\infty$ margin, and to Muon-Signum and Muon-Adam, which maximize a hybrid norm. Our experiments corroborate the theory and show that the identity of the margin maximized depends on the choice of optimizer. Overall, our results extend earlier lines of work on steepest descent in homogeneous models and momentum-based optimizers in linear models.
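The link between Adam without the stability constant and the $\ell_\infty$ geometry has a simple sanity check: with $\epsilon = 0$, the bias-corrected update $\hat m_t / \sqrt{\hat v_t}$ equals $\operatorname{sign}(g_t)$ exactly when $\beta_1 = \beta_2 = 0$, and on the first step for any $\beta_1, \beta_2$, which is normalized steepest descent under $\ell_\infty$. The NumPy sketch below only illustrates that reduction; it is not the paper's analysis, and the function names, defaults, and decaying schedule are assumptions.

```python
import numpy as np


def adam_no_eps(grad_fn, theta0, steps=100, eta0=0.01, beta1=0.9, beta2=0.999, power=0.5):
    """Hypothetical Adam variant with the stability constant set to zero and a decaying step size."""
    theta = np.array(theta0, dtype=float)
    m = np.zeros_like(theta)
    v = np.zeros_like(theta)
    for t in range(1, steps + 1):
        g = grad_fn(theta)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)  # bias corrections
        v_hat = v / (1 - beta2 ** t)
        # No epsilon. With 0 < beta2 < 1, v_hat == 0 means every past gradient of that
        # coordinate was 0, so m_hat == 0 too; this sketch sets the step to 0 there.
        step = np.divide(m_hat, np.sqrt(v_hat), out=np.zeros_like(m_hat), where=v_hat > 0)
        theta = theta - eta0 / t ** power * step
    return theta


def sign_descent(grad_fn, theta0, steps=100, eta0=0.01, power=0.5):
    """Normalized steepest descent w.r.t. the l_inf norm, with the same schedule."""
    theta = np.array(theta0, dtype=float)
    for t in range(1, steps + 1):
        theta = theta - eta0 / t ** power * np.sign(grad_fn(theta))
    return theta
```

The two functions agree for $\beta_1 = \beta_2 = 0$. With nonzero $\beta$'s the ratio $\hat m_t/\sqrt{\hat v_t}$ departs from $\operatorname{sign}(g_t)$; the paper's framework treats Adam as an approximate steepest descent method, and the abstract states that it maximizes the $\ell_\infty$ margin.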

Paper Structure (34 sections, 42 theorems, 197 equations, 2 figures, 1 table)

Key Result

Theorem 3.1

Let $\bm{\theta}_t$ be a trajectory of normalized steepest descent with respect to a norm $\left \lVert \cdot \right \rVert$ (Equation eq:norm_steepest_descent). Under Assumptions model_ass:lipschitz_C1, model_ass:homogeneous_0, real_ass, and lr_ass:lr_nsd, the soft margin $\widetilde{\gamma}(\bm{\theta}$ …
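Theorem 3.1 is stated in terms of a soft margin $\widetilde{\gamma}$. As an illustration only, the sketch below uses a common smoothing for the exponential loss, $\widetilde{\gamma}(\bm{\theta}) = -\log\big(\sum_i e^{-q_i(\bm{\theta})}\big) / \|\bm{\theta}\|^L$ with $q_i(\bm{\theta}) = y_i f(\bm{\theta}; \bm{x}_i)$ for an $L$-homogeneous $f$; the paper's exact definition may differ. Since $e^{-q_{\min}} \le \sum_i e^{-q_i} \le n e^{-q_{\min}}$, it satisfies $\gamma(\bm{\theta}) - \frac{\log n}{\|\bm{\theta}\|^L} \le \widetilde{\gamma}(\bm{\theta}) \le \gamma(\bm{\theta})$, where $\gamma(\bm{\theta}) = \min_i q_i(\bm{\theta}) / \|\bm{\theta}\|^L$ is the normalized margin, so the two coincide as $\|\bm{\theta}\| \to \infty$. The norm is an input: $\ell_2$, $\ell_\infty$, or a composite max-norm, which is sketched here, as an assumption, as the maximum of the spectral norms of weight matrices and the $\ell_\infty$ norms of the remaining parameters. All function names are hypothetical.

```python
import numpy as np


def normalized_margin(q, theta_norm, degree):
    """gamma = min_i q_i / ||theta||^L, where q_i = y_i * f(theta; x_i)."""
    return np.min(q) / theta_norm ** degree


def soft_margin(q, theta_norm, degree):
    """-log(sum_i exp(-q_i)) / ||theta||^L, a smooth lower bound on the normalized margin."""
    return -np.logaddexp.reduce(-np.asarray(q, dtype=float)) / theta_norm ** degree


def composite_max_norm(params):
    """Assumed composite norm: max of spectral norms (2-D arrays) and l_inf norms (other arrays)."""
    return max(
        np.linalg.norm(p, 2) if p.ndim == 2 else np.max(np.abs(p))
        for p in params
    )


# Linear model f(w; x) = <w, x>, which is 1-homogeneous in w.
X = np.array([[1.0, 2.0], [2.0, -1.0], [-1.0, -1.0]])
y = np.array([1.0, 1.0, -1.0])
w = np.array([1.0, 0.5])
q = y * (X @ w)  # [2.0, 1.5, 1.5]

# The same weights have different margins under different norms.
for name, norm_w in [("l2", np.linalg.norm(w)), ("linf", np.max(np.abs(w)))]:
    print(name, normalized_margin(q, norm_w, 1), soft_margin(q, norm_w, 1))
```

Scaling $w$ by $c > 0$ scales every $q_i$ and $\|w\|$ by $c$, so $\gamma$ is unchanged while the gap $\gamma - \widetilde{\gamma}$ shrinks at rate at most $\log n / c$.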

Figures (2)

  • Figure 1: (a) Margin values vs. loss for different optimizers. A lighter/darker color signifies the squared-ReLU/ReLU activations, respectively. Dotted lines represent optimizers with momentum disabled. Lines are mean values over 10 random seeds, while filled areas are 95% confidence intervals. (b) Cosine similarity to the last iterate, $\left\langle \frac{\bm{\theta}_t}{\left \lVert \bm{\theta}_t \right \rVert_2},\frac{\bm{\theta}_{\text{last}}}{\left \lVert \bm{\theta}_{\text{last}} \right \rVert_2} \right\rangle$, plotted on a normalized linear time scale.
  • Figure 2: Margin values vs. loss for different optimizers. A lighter/darker color signifies the squared-ReLU/ReLU activations, respectively. Lines are mean values over 10 random seeds, while filled areas are 95% confidence intervals.
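The quantity in Figure 1(b) is straightforward to compute from saved checkpoints. A minimal sketch, assuming the iterates are flattened into the rows of an array: normalize each row in $\ell_2$ and take its inner product with the normalized last row. The band helper uses a normal approximation ($\pm 1.96$ standard errors over seeds); the captions do not say how the 95% intervals were computed, so this choice is an assumption, and both function names are hypothetical.

```python
import numpy as np


def cosine_to_last(iterates):
    """<theta_t/||theta_t||_2, theta_last/||theta_last||_2> for each row theta_t of a (T, p) array."""
    thetas = np.asarray(iterates, dtype=float)
    unit = thetas / np.linalg.norm(thetas, axis=1, keepdims=True)
    return unit @ unit[-1]


def mean_and_ci95(runs):
    """Per-time-step mean and normal-approximation 95% half-width over seeds (rows of `runs`)."""
    runs = np.asarray(runs, dtype=float)
    mean = runs.mean(axis=0)
    half_width = 1.96 * runs.std(axis=0, ddof=1) / np.sqrt(runs.shape[0])
    return mean, half_width
```

With only 10 seeds, a Student-t quantile (about 2.26 for 9 degrees of freedom) in place of 1.96 would give somewhat wider bands.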

Theorems & Definitions (86)

  • Definition 2.1
  • Theorem 3.1
  • Theorem 3.2
  • Theorem 3.3
  • Corollary 3.4
  • Corollary 3.5
  • Theorem 3.6
  • Theorem 3.7
  • Definition 5.1: Approximate Steepest Descent
  • Theorem A.1: Theorems 2.3.9 and 2.3.10 in Cla90
  • ...and 76 more