Implicit Bias of AdamW: $\ell_\infty$ Norm Constrained Optimization

Shuo Xie, Zhiyuan Li

TL;DR

The paper investigates why AdamW often generalizes better than Adam with ℓ2 regularization by proposing that AdamW implicitly performs optimization under an ℓ∞-norm constraint. It establishes that normalized steepest descent with weight decay converges to KKT points of a norm-ball constrained objective, and shows that AdamW exhibits the same implicit bias under suitable learning-rate schedules and momentum conditions. A key technical result is a tight upper bound on the average size of Adam-like updates, which underpins the convergence to constrained optima; the theory is complemented by experiments on a PTB language-modeling task and a synthetic problem, illustrating the practical relevance of ℓ∞ geometry. Overall, the work provides a first principled explanation for AdamW’s empirical advantages and highlights the role of geometry in adaptive optimization.

Abstract

Adam with decoupled weight decay, also known as AdamW, is widely acclaimed for its superior performance in language modeling tasks, surpassing Adam with $\ell_2$ regularization in terms of generalization and optimization. However, this advantage is not theoretically well-understood. One challenge here is that though intuitively Adam with $\ell_2$ regularization optimizes the $\ell_2$ regularized loss, it is not clear if AdamW optimizes a specific objective. In this work, we make progress toward understanding the benefit of AdamW by showing that it implicitly performs constrained optimization. More concretely, we show in the full-batch setting, if AdamW converges with any non-increasing learning rate schedule whose partial sum diverges, it must converge to a KKT point of the original loss under the constraint that the $\ell_\infty$ norm of the parameter is bounded by the inverse of the weight decay factor. This result is built on the observation that Adam can be viewed as a smoothed version of SignGD, which is the normalized steepest descent with respect to $\ell_\infty$ norm, and a surprising connection between normalized steepest descent with weight decay and Frank-Wolfe.
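
The connection between normalized steepest descent with weight decay and Frank-Wolfe can be illustrated with a small sketch. For the $\ell_\infty$ norm, one step of sign descent with decoupled weight decay, $x_{t+1} = (1-\eta_t\lambda)x_t - \eta_t\,\mathrm{sign}(\nabla L(x_t))$, is a convex combination of $x_t$ and the vertex $-\frac{1}{\lambda}\mathrm{sign}(\nabla L(x_t))$ of the $\ell_\infty$ ball of radius $\frac{1}{\lambda}$, with mixing weight $\eta_t\lambda$. Iterates that start inside the ball therefore stay inside whenever $\eta_t\lambda \le 1$. The toy quadratic loss, the target vector, the function names, and the schedule $\eta_t = 1/(t+1)$ below are assumptions chosen for illustration, not taken from the paper.

```python
def sign(v):
    """Return -1, 0, or +1 according to the sign of v."""
    return (v > 0) - (v < 0)


def sign_descent_with_weight_decay(grad, x0, lam, lr_schedule, steps):
    """Sign descent (l_inf normalized steepest descent) with decoupled weight decay.

    Returns the final iterate and the largest l_inf norm seen along the run.
    """
    x = list(x0)
    max_linf = max(abs(v) for v in x)
    for t in range(1, steps + 1):
        eta = lr_schedule(t)
        g = grad(x)
        # Same as a Frank-Wolfe step of weight eta * lam toward the
        # vertex -sign(g) / lam of the l_inf ball of radius 1 / lam.
        x = [(1.0 - eta * lam) * xi - eta * sign(gi) for xi, gi in zip(x, g)]
        max_linf = max(max_linf, max(abs(v) for v in x))
    return x, max_linf


# Hypothetical toy loss L(x) = 0.5 * sum_i (x_i - TARGET_i)^2.
# The first target coordinate lies outside the ball ||x||_inf <= 1.
TARGET = [3.0, 0.2]


def toy_grad(x):
    return [xi - ci for xi, ci in zip(x, TARGET)]


def harmonic_lr(t):
    # Non-increasing schedule whose partial sums diverge.
    return 1.0 / (t + 1)


x_final, max_linf = sign_descent_with_weight_decay(
    toy_grad, [0.0, 0.0], lam=1.0, lr_schedule=harmonic_lr, steps=2000
)
print("final iterate:", x_final)
print("largest l_inf norm along the run:", max_linf)
```

With $\lambda = 1$ the constraint set is $\|x\|_\infty \le 1$. The first coordinate of the unconstrained minimizer lies outside it, so the iterates approach the boundary value 1 in that coordinate while the second coordinate settles near 0.2, which is the KKT point of this constrained toy problem.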

Paper Structure (30 sections, 15 theorems, 33 equations, 6 figures, 2 algorithms)

Key Result

Theorem 1.1

For any continuously differentiable function $L:\mathbb{R}^d\to\mathbb{R}$, $\beta_1 \leq \beta_2<1$, initialization ${\bm{x}}_0$ and non-increasing learning rate $\{\eta_t\}_{t=1}^\infty$ such that $\sum_{t=1}^\infty \eta_t = \infty$, if the iterates of $\mathtt{AdamW}$ $\{{\bm{x}}_t\}_{t=0}^\infty$

Figures (6)

  • Figure 1: The $\ell_\infty$ norm of the parameters during training on the PTB language modeling task. The complete results for $\mathtt{Adam}$ are in \ref{fig:ptb_seed_0_full}. As predicted by \ref{lem:iterate_norm_bound}, the $\ell_\infty$ norm can be bounded by $\frac{1}{\lambda}$ when $\beta_1=\beta_2$ or $\lambda\eta \ll 1-\beta_2<1-\beta_1$. However, for the default setting $\beta_1=0.9$ and $\beta_2=0.999$, the $\ell_\infty$ norm of $\mathtt{AdamW}$ may not be bounded by $\frac{1}{\lambda}$ because $1-\beta_2<\lambda\eta<1-\beta_1$.
  • Figure 2: The $\ell_\infty$ norm of the parameters when the batch size is half of the entire training set. The norm upper bound in \ref{lem:iterate_norm_bound} still holds in this case.
  • Figure 3: For both the $\ell_2$ and $\ell_\infty$ norms, we plot the training loss of normalized steepest descent with and without weight decay, and of unnormalized steepest descent, on the quadratic loss $g({\bm{x}}) = \sum_{i=1}^{100} \frac{({\bm{x}}_i-{\bm{x}}^*_i)^2}{i^2}$. When weight decay is turned on, it is set to $\frac{1}{\left\|{\bm{x}}^*\right\|}$ so that the optimal value is preserved even under the norm constraint (\ref{thm:any_rate_convergence}). We find that the $\ell_\infty$ norm always outperforms the $\ell_2$ norm, regardless of whether weight decay is used and whether the steepest descent method is normalized. Weight decay accelerates optimization for both the $\ell_\infty$ norm and the $\ell_2$ norm.
  • Figure 4: The $\ell_\infty$ norm of the parameters for $\mathtt{Adam}$ and $\mathtt{AdamW}$ with different $\beta_1, \beta_2$ for seed $1$.
  • Figure 5: The $\ell_\infty$ norm of the parameters for $\mathtt{Adam}$ and $\mathtt{AdamW}$ with different $\beta_1, \beta_2$ for seed $0$. The range of the y-axis is extended to show the full result of $\mathtt{Adam}$.
  • ...and 1 more figure
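
The norm bound described in the Figure 1 caption for the case $\beta_1=\beta_2$ can be checked numerically on a toy problem. The sketch below assumes the commonly used bias-corrected AdamW update with decoupled weight decay, $x_{t+1} = (1-\eta\lambda)x_t - \eta\,\hat m_t/(\sqrt{\hat v_t}+\epsilon)$, run full-batch on a hypothetical quadratic loss. The function names, the target vector, and all hyperparameter values are illustrative assumptions, not the paper's experimental setup. When $\beta_1=\beta_2$, each coordinate of $\hat m_t/\sqrt{\hat v_t}$ has magnitude at most 1, because the same weights average $g$ and $g^2$ and Jensen's inequality applies. This keeps the iterates inside the $\ell_\infty$ ball of radius $\frac{1}{\lambda}$ when they start there and $\eta\lambda \le 1$.

```python
import math


def adamw_linf_trace(grad, x0, lr, lam, beta1, beta2, steps, eps=1e-12):
    """Full-batch AdamW with bias correction (assumed form).

    Returns the final iterate and the l_inf norm of the iterate after each step.
    """
    x = list(x0)
    m = [0.0] * len(x)  # first-moment estimate
    v = [0.0] * len(x)  # second-moment estimate
    trace = []
    for t in range(1, steps + 1):
        g = grad(x)
        m = [beta1 * mi + (1.0 - beta1) * gi for mi, gi in zip(m, g)]
        v = [beta2 * vi + (1.0 - beta2) * gi * gi for vi, gi in zip(v, g)]
        m_hat = [mi / (1.0 - beta1 ** t) for mi in m]
        v_hat = [vi / (1.0 - beta2 ** t) for vi in v]
        # Decoupled weight decay: shrink x by (1 - lr * lam), then apply the Adam step.
        x = [
            (1.0 - lr * lam) * xi - lr * mh / (math.sqrt(vh) + eps)
            for xi, mh, vh in zip(x, m_hat, v_hat)
        ]
        trace.append(max(abs(xi) for xi in x))
    return x, trace


# Hypothetical toy loss L(x) = 0.5 * sum_i (x_i - CENTER_i)^2 whose
# unconstrained minimizer lies outside the ball ||x||_inf <= 1 / lam.
CENTER = [3.0, -2.0, 0.2]


def quad_grad(x):
    return [xi - ci for xi, ci in zip(x, CENTER)]


lam = 1.0
x_last, linf_trace = adamw_linf_trace(
    quad_grad, [0.0, 0.0, 0.0], lr=0.01, lam=lam, beta1=0.9, beta2=0.9, steps=2000
)
print("largest l_inf norm:", max(linf_trace), "bound 1/lam:", 1.0 / lam)
```

The sketch only exercises the $\beta_1=\beta_2$ case, where the bound follows from the argument in the lead-in. It makes no claim about the default setting $\beta_1=0.9$, $\beta_2=0.999$, for which the caption notes that the bound may fail.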

Theorems & Definitions (25)

  • Theorem 1.1
  • Lemma 2.1
  • Lemma 2.2
  • Lemma 3.1
  • Theorem 3.2
  • Lemma 3.3: Descent Lemma for Smooth Convex Loss
  • Theorem 3.4
  • Theorem 3.5
  • Definition 3.6: KKT points
  • Theorem 3.7: Non-convex, KKT
  • ...and 15 more