On the Implicit Bias of Adam

Matias D. Cattaneo, Jason M. Klusowski, Boris Shigida

TL;DR

This paper uses backward error analysis to derive a global ODE approximation, second-order in the step size $h$, for Adam and RMSProp, covering both their mini-batch and full-batch variants. It shows that Adam typically anti-penalizes the perturbed gradient one-norm $\|\nabla E(\boldsymbol{\theta})\|_{1,\varepsilon}$ when $\sqrt{\varepsilon}$ is small and the second-moment averaging parameter $\rho$ exceeds the momentum parameter $\beta$. This implicit anti-regularization can worsen generalization, while other hyperparameter regimes recover GD-like implicit regularization. The theory is complemented by numerical experiments on vision architectures (ResNets, CNNs, ViTs) and standard datasets, which corroborate the predicted anti-regularization effect and link it to generalization performance. Overall, the work provides a principled framework for understanding the implicit bias of adaptive optimizers and motivates further study of their generalization behavior across architectures.
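The perturbed one-norm can be evaluated coordinate-wise. The sketch below assumes the definition $\|\boldsymbol{g}\|_{1,\varepsilon} = \sum_i \sqrt{\varepsilon + g_i^2}$, which is consistent with the $\sqrt{\varepsilon + y^2}$ appearing in Figure 1 but is an assumption here rather than a quotation. It also shows why "$\sqrt{\varepsilon}$ small" is the relevant regime: each coordinate exceeds $|g_i|$ by at most $\sqrt{\varepsilon}$, since $\sqrt{\varepsilon + g_i^2} \le \sqrt{\varepsilon} + |g_i|$. The function names are hypothetical.

```python
import math


def perturbed_one_norm(grad, eps):
    """Assumed definition: ||g||_{1,eps} = sum_i sqrt(eps + g_i^2)."""
    return sum(math.sqrt(eps + g * g) for g in grad)


def one_norm(grad):
    """Ordinary one-norm sum_i |g_i|."""
    return sum(abs(g) for g in grad)


grad = [0.3, -1.2, 0.0, 2.5]  # an arbitrary example gradient
for eps in (1e-2, 1e-4, 1e-8):
    gap = perturbed_one_norm(grad, eps) - one_norm(grad)
    # Each coordinate adds at most sqrt(eps), so 0 <= gap <= d * sqrt(eps).
    print(f"eps={eps:g}  gap={gap:.3e}  bound={len(grad) * math.sqrt(eps):.3e}")
```

As $\varepsilon \to 0$ the gap vanishes and the perturbed norm reduces to the ordinary one-norm of the gradient.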

Abstract

In previous literature, backward error analysis was used to find ordinary differential equations (ODEs) approximating the gradient descent trajectory. It was found that finite step sizes implicitly regularize solutions because terms appearing in the ODEs penalize the two-norm of the loss gradients. We prove that the existence of similar implicit regularization in RMSProp and Adam depends on their hyperparameters and the training stage, but with a different "norm" involved: the corresponding ODE terms either penalize the (perturbed) one-norm of the loss gradients or, conversely, impede its reduction (the latter case being typical). We also conduct numerical experiments and discuss how the proven facts can influence generalization.
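For reference, the GD result mentioned above is that gradient descent with step size $h$ stays $O(h^2)$-close, over a fixed time horizon, to the gradient flow of the modified loss $E + \frac{h}{4}\|\nabla E\|_2^2$, whereas plain gradient flow is only $O(h)$-close. The sketch below checks this on a one-dimensional quadratic $E(\theta) = a\theta^2/2$, where all three trajectories have closed forms; the values of $a$, $h$, and $\theta_0$ are arbitrary illustrative choices.

```python
import math

# Quadratic loss E(theta) = a * theta**2 / 2, so E'(theta) = a * theta.
a, h, theta0, n_steps = 2.0, 0.05, 1.0, 40


def gd_iterate(n):
    # GD: theta_{k+1} = theta_k - h * a * theta_k, solved in closed form.
    return (1.0 - h * a) ** n * theta0


def gradient_flow(t):
    # Plain gradient flow: d theta / dt = -a * theta.
    return math.exp(-a * t) * theta0


def modified_flow(t):
    # Gradient flow on E + (h/4) * E'^2, whose gradient is (a + h * a^2 / 2) * theta.
    return math.exp(-(a + h * a * a / 2.0) * t) * theta0


err_plain = max(abs(gd_iterate(n) - gradient_flow(n * h)) for n in range(n_steps + 1))
err_mod = max(abs(gd_iterate(n) - modified_flow(n * h)) for n in range(n_steps + 1))
print(f"max |GD - gradient flow| = {err_plain:.2e}")
print(f"max |GD - modified flow| = {err_mod:.2e}")
```

The extra $\frac{h}{4}\|\nabla E\|_2^2$ term is exactly the two-norm penalty that makes finite-step GD implicitly regularize; the per-step decay $1 - ha$ agrees with $e^{-(ha + h^2a^2/2)}$ up to $O(h^3)$.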

Paper Structure (11 sections, 1 theorem, 20 equations, 6 figures, 1 table)

Key Result

Theorem 3.1

Assume that \eqref{eq:bounded-smoothness-assumption} holds. Let $\{\boldsymbol{\theta}^{(n)}\}$ be the iterates of Adam as defined in \ref{def:adam}, and let $\tilde{\boldsymbol{\theta}}(t)$ be the continuous solution to the piecewise ODE for $t \in [n h, (n + 1) h]$ with the initial condition $\tilde{\boldsymbol{\theta}}(0) = \boldsymbol{\theta}^{(0)}$, where …. Then, for any fixed positive time horizon $T > 0$, there exists a …
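For reference when reading Theorem 3.1, here is a minimal sketch of full-batch Adam producing the iterates $\boldsymbol{\theta}^{(n)}$. The definition def:adam is not reproduced in this summary, so the update below is an assumption: bias-corrected moving averages with momentum parameter $\beta$, second-moment parameter $\rho$, and $\varepsilon$ placed inside the square root (matching the $\sqrt{\varepsilon + y^2}$ in Figure 1). The function `adam`, its signature, and the example loss are hypothetical.

```python
import math


def adam(grad_fn, theta0, h, beta, rho, eps, n_steps):
    """Full-batch Adam (assumed form); returns [theta^(0), ..., theta^(n_steps)].

    Per coordinate i at step n (an assumption, not quoted from def:adam):
        m_i <- beta * m_i + (1 - beta) * g_i
        v_i <- rho * v_i + (1 - rho) * g_i**2
        theta_i <- theta_i - h * m_hat_i / sqrt(v_hat_i + eps)
    with m_hat = m / (1 - beta**(n+1)) and v_hat = v / (1 - rho**(n+1)).
    """
    theta = list(theta0)
    m = [0.0] * len(theta)
    v = [0.0] * len(theta)
    iterates = [list(theta)]
    for n in range(n_steps):
        g = list(grad_fn(theta))  # full-batch gradient at the current iterate
        for i, gi in enumerate(g):
            m[i] = beta * m[i] + (1.0 - beta) * gi
            v[i] = rho * v[i] + (1.0 - rho) * gi * gi
            m_hat = m[i] / (1.0 - beta ** (n + 1))
            v_hat = v[i] / (1.0 - rho ** (n + 1))
            theta[i] -= h * m_hat / math.sqrt(v_hat + eps)
        iterates.append(list(theta))
    return iterates


# Hypothetical loss E(theta) = theta_1^2 / 2 + 2 * theta_2^2, so grad E = (theta_1, 4 * theta_2).
def grad_E(theta):
    return [theta[0], 4.0 * theta[1]]


its = adam(grad_E, [1.0, -0.5], h=0.1, beta=0.9, rho=0.999, eps=1e-8, n_steps=3)
# With bias correction the first step is -h * g / sqrt(g^2 + eps) per coordinate,
# i.e. one step of signGD perturbed by eps; here it lands at about (0.9, -0.4).
print(its[1])
```

Under this assumed form, the first step coincides with the first-order (perturbed signGD) dynamics; the iterates it returns are the quantities that a closeness measure such as $\max_n \|\boldsymbol{\theta}^{(n)} - \tilde{\boldsymbol{\theta}}(t_n)\|_{\infty}$ in Figure 3 compares against an ODE solution.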

Figures (6)

  • Figure 1: To illustrate which term is implicitly penalized in the simple case $p = 1$, we plot the graphs of $x \mapsto F(x) := \frac{h}{2} \int_0^x \bigl\{ \frac{1 + \beta}{1 - \beta} - \frac{1 + \rho}{1 - \rho} + \frac{1 + \rho}{1 - \rho} \cdot \frac{\varepsilon}{y^2 + \varepsilon} \bigr\}\,\mathrm{d}\sqrt{\varepsilon + y^2}$ with $\beta = 0.95$. In this case, the correction term in \ref{eq:bias} is itself the gradient of the function $F(E'(\theta))$, where $E'$ is the derivative (i.e., the gradient) of the loss: specifically, $\text{correction} = \frac{\mathrm{d}}{\mathrm{d} \theta} F(E'(\theta))$. Hence, Adam's iteration penalizes $F(E'(\theta))$. If $\varepsilon$ is small and $\rho > \beta$, the negative one-norm of the gradient is penalized (leftmost picture, highest values of $\rho$); in other words, the one-norm is anti-penalized.
  • Figure 2: Left: increasing $\beta$ moves the trajectory of Adam towards the regions with smaller one-norm of the gradient (if $\varepsilon$ is sufficiently small); increasing $\rho$ does the opposite. Right: increasing the learning rate moves the Adam trajectory towards the regions with smaller one-norm of the gradient if $\beta$ is significantly larger than $\rho$ and does the opposite if $\rho$ is larger than $\beta$. The cross denotes the limit point of gradient one-norm minimizers on the level sets $4 \theta_1 \theta_2 - 3 = c$. The minimizers are drawn with a dashed line. All Adam trajectories start at $(2.8, 3.5)$.
  • Figure 3: $\|\boldsymbol{\theta}^{(n)} - \tilde{\boldsymbol{\theta}}(t_n)\|_{\infty}$ for an MLP trained with full-batch Adam on truncated MNIST, where $\tilde{\boldsymbol{\theta}}(t_n)$ is either the first-order approximation to Adam (signGD perturbed by $\varepsilon$) or the second-order approximation; $\beta = 0.9$, $\rho = 0.95$, $\varepsilon = 10^{-6}$. Precise definitions are provided in \ref{sec:numerical-experiments}, specifically \eqref{eq:bea-closeness-precise-iterations}.
  • Figure 4: ResNet-50 on CIFAR-10 trained with full-batch Adam, $\varepsilon = 10^{-8}$, $\beta = 0.99$. As $\rho$ increases, the perturbed one-norm rises and the test accuracy falls. We train longer than necessary for near-perfect classification on the training set (at least 2 to 3 thousand epochs), and the plotted test accuracies are maximal. The perturbed norms are also maximal after excluding the initial training period (i.e., the plotted "norms" are at the peaks of the "hills" described in \ref{sec:numerical}). All results are averaged across five runs with different initialization seeds. Additional evidence and more details are provided in \ref{sec:numerical-experiments}.
  • Figure 5: ResNet-50 on CIFAR-10 trained with full-batch Adam, $\rho = 0.999$, $\varepsilon = 10^{-8}$. As $\beta$ increases, the perturbed one-norm falls and the test accuracy rises. Both metrics are calculated as in Figure 4. All results are averaged across three runs with different initialization seeds.
  • ...and 1 more figure
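The integral defining $F$ in Figure 1 can be evaluated in closed form by substituting $s = \sqrt{\varepsilon + y^2}$, so that $\varepsilon/(y^2 + \varepsilon) = \varepsilon/s^2$. Writing $A = \frac{1+\beta}{1-\beta}$ and $B = \frac{1+\rho}{1-\rho}$, this substitution (derived here, not quoted from the paper) gives $F(x) = \frac{h}{2}\bigl[(A - B)\bigl(\sqrt{\varepsilon + x^2} - \sqrt{\varepsilon}\bigr) + B\bigl(\sqrt{\varepsilon} - \varepsilon/\sqrt{\varepsilon + x^2}\bigr)\bigr]$. The sketch below implements it, cross-checks it against a midpoint-rule quadrature, and prints the sign of $F$ for several $\rho$ with $\beta = 0.95$ as in the figure; the function names and the test values of $h$ and $\varepsilon$ are illustrative choices.

```python
import math


def F_closed(x, h, beta, rho, eps):
    """Closed form of F(x) from Figure 1 via the substitution s = sqrt(eps + y^2)."""
    A = (1.0 + beta) / (1.0 - beta)
    B = (1.0 + rho) / (1.0 - rho)
    s0, s1 = math.sqrt(eps), math.sqrt(eps + x * x)
    return 0.5 * h * ((A - B) * (s1 - s0) + B * (s0 - eps / s1))


def F_quadrature(x, h, beta, rho, eps, n=20_000):
    """Midpoint-rule evaluation of the same integral, rewritten as an integral in y."""
    A = (1.0 + beta) / (1.0 - beta)
    B = (1.0 + rho) / (1.0 - rho)
    dy = x / n
    total = 0.0
    for k in range(n):
        y = (k + 0.5) * dy
        s = math.sqrt(eps + y * y)
        bracket = A - B + B * eps / (y * y + eps)
        total += bracket * (y / s) * dy  # d sqrt(eps + y^2) = y / sqrt(eps + y^2) dy
    return 0.5 * h * total


beta = 0.95
for rho in (0.9, 0.99, 0.999):
    print(rho, F_closed(1.0, h=1.0, beta=beta, rho=rho, eps=1e-8))
```

For $|x| \gg \sqrt{\varepsilon}$ the first term dominates, so $F(x) \approx \frac{h}{2}(A - B)|x|$: a multiple of the one-norm whose coefficient is negative exactly when $\rho > \beta$. For $|x| \ll \sqrt{\varepsilon}$, expanding both terms gives $F(x) \approx \frac{hA}{4\sqrt{\varepsilon}}\,x^2$, a GD-like squared-gradient penalty.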

Theorems & Definitions (7)

  • Definition 1.1
  • Remark 1.2
  • Example 2.1: Backward Error Analysis for GD with Heavy-ball Momentum
  • Theorem 3.1
  • Remark 3.2
  • Proof: Derivation sketch
  • Definition 2.1
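The list includes an example of backward error analysis for GD with heavy-ball momentum. As a hedged illustration (not a reproduction of Example 2.1, whose statement is not shown here), take the moving-average form $m \leftarrow \beta m + (1-\beta)E'(\theta)$, $\theta \leftarrow \theta - h m$ on a one-dimensional quadratic $E(\theta) = a\theta^2/2$. The slow mode of this iteration contracts per step by the dominant root $\lambda$ of $\lambda^2 - \bigl(1 + \beta - h(1-\beta)a\bigr)\lambda + \beta = 0$, and expanding $\log\lambda$ to second order gives $-ha - \frac{h^2a^2}{2}\cdot\frac{1+\beta}{1-\beta}$, the per-step decay of gradient flow on $E + \frac{h}{4}\cdot\frac{1+\beta}{1-\beta}(E')^2$. The factor $\frac{1+\beta}{1-\beta}$ is the same one that appears in $F$ in Figure 1. The function names and parameter values below are illustrative.

```python
import math


def dominant_root(h, a, beta):
    """Largest root of lambda^2 - (1 + beta - h (1 - beta) a) lambda + beta = 0."""
    b = 1.0 + beta - h * (1.0 - beta) * a
    disc = b * b - 4.0 * beta  # nonnegative for the small h used here
    return 0.5 * (b + math.sqrt(disc))


def simulated_ratio(h, a, beta, n_steps=500):
    """Run m <- beta m + (1 - beta) E'(theta), theta <- theta - h m on E = a theta^2 / 2."""
    theta, m, prev = 1.0, 0.0, 1.0
    for _ in range(n_steps):
        m = beta * m + (1.0 - beta) * a * theta
        prev, theta = theta, theta - h * m
    return theta / prev  # tends to the dominant root once the fast mode has decayed


h, a, beta = 0.01, 1.0, 0.9
lam = dominant_root(h, a, beta)
plain = math.exp(-h * a)  # one step of plain gradient flow
modified = math.exp(-h * a - 0.5 * (h * a) ** 2 * (1.0 + beta) / (1.0 - beta))
print(f"lambda={lam:.6f}  plain flow={plain:.6f}  modified flow={modified:.6f}")
```

With $\beta = 0.9$ the effective small parameter is $h/(1-\beta)$, so the remaining third-order mismatch is still visible, but the modified flow tracks the momentum iteration far more closely than plain gradient flow.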