Understanding the Generalization of Stochastic Gradient Adam in Learning Neural Networks

Xuan Tang, Han Zhang, Yuan Cao, Difan Zou

TL;DR

This paper analyzes how batch size affects the generalization of stochastic gradient Adam and AdamW in two-layer over-parameterized CNNs. It shows that large-batch training memorizes noise and generalizes poorly, while mini-batch training with suitable weight decay achieves near-zero test error. The results are proved rigorously for both Adam and AdamW and reveal that Adam’s adaptive normalization imposes a tighter upper bound on the effective weight decay than AdamW does, which makes tuning more delicate. The authors connect stochastic Adam dynamics to SignSGD in the small-learning-rate regime and describe training in two stages: initial feature learning followed by regularization-dominated convergence. Experiments on CIFAR-10 and ImageNet-1K subsets support the theory. The findings highlight the roles of batch size and explicit weight decay in reaching generalizable solutions with adaptive optimizers and offer practical guidance on tuning both.
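The difference between Adam with weight decay and AdamW that this summary relies on can be sketched in a few lines. The following is a minimal single-step sketch using the standard textbook forms of the two updates (an L2 penalty folded into the gradient for Adam, a decoupled decay term for AdamW). It is not the paper's exact update rule, and the function names and hyperparameter values are illustrative assumptions. With `beta1 = beta2 = 0` and `eps = 0`, the Adam step reduces to `lr * sign(g)`, which is one common way to see the link to sign-based methods.

```python
import math


def adam_l2_step(w, grad, m, v, t, lr, lam, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam step with an L2 penalty (standard form, assumed here).

    The decay term lam * w is added to the gradient BEFORE the adaptive
    normalization, so it is rescaled together with the loss gradient.
    """
    g = grad + lam * w
    m = beta1 * m + (1.0 - beta1) * g
    v = beta2 * v + (1.0 - beta2) * g * g
    m_hat = m / (1.0 - beta1 ** t)  # bias correction
    v_hat = v / (1.0 - beta2 ** t)
    w_new = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w_new, m, v


def adamw_step(w, grad, m, v, t, lr, lam, beta1=0.9, beta2=0.999, eps=1e-8):
    """One AdamW step (standard decoupled form, assumed here).

    Only the loss gradient is normalized; the decay lam * w is applied
    directly to the weight, outside the adaptive normalization.
    """
    m = beta1 * m + (1.0 - beta1) * grad
    v = beta2 * v + (1.0 - beta2) * grad * grad
    m_hat = m / (1.0 - beta1 ** t)
    v_hat = v / (1.0 - beta2 ** t)
    w_new = w - lr * (m_hat / (math.sqrt(v_hat) + eps) + lam * w)
    return w_new, m, v


# Hypothetical scalar example in the sign-like limit (beta1 = beta2 = 0, eps = 0).
w0, grad, lr, lam = 2.0, 0.5, 0.01, 0.1

w_adam, _, _ = adam_l2_step(w0, grad, 0.0, 0.0, 1, lr, lam,
                            beta1=0.0, beta2=0.0, eps=0.0)
w_adamw, _, _ = adamw_step(w0, grad, 0.0, 0.0, 1, lr, lam,
                           beta1=0.0, beta2=0.0, eps=0.0)

print(w_adam)   # about 1.99: the step is lr * sign(grad + lam * w)
print(w_adamw)  # about 1.988: sign step plus a decay of lr * lam * w
```

In this limit the Adam step has magnitude `lr` regardless of `lam`, because the decay term passes through the normalization, whereas the AdamW decay scales with the weight itself. This is only meant to show where the weight decay enters each update, not to reproduce the paper's bounds.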

Abstract

Adam is a popular and widely used adaptive gradient method in deep learning, which has also received tremendous focus in theoretical research. However, most existing theoretical work primarily analyzes its full-batch version, which differs fundamentally from the stochastic variant used in practice. Unlike SGD, stochastic Adam does not converge to its full-batch counterpart even with infinitesimal learning rates. We present the first theoretical characterization of how batch size affects Adam's generalization, analyzing two-layer over-parameterized CNNs on image data. Our results reveal that while both Adam and AdamW with proper weight decay $\lambda$ converge to poor test error solutions, their mini-batch variants can achieve near-zero test error. We further prove Adam has a strictly smaller effective weight decay bound than AdamW, theoretically explaining why Adam requires more sensitive $\lambda$ tuning. Extensive experiments validate our findings, demonstrating the critical role of batch size and weight decay in Adam's generalization performance.
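The claim that stochastic Adam does not track its full-batch counterpart, unlike SGD, can be illustrated with a toy calculation. The sketch below assumes an idealized sign-based update (the limit in which Adam's step is `lr * sign(g)`) and uses made-up per-example gradients for a single coordinate; it is an illustration of the general phenomenon, not the paper's construction.

```python
def sign(x):
    """Return +1, 0, or -1 according to the sign of x."""
    return (x > 0) - (x < 0)


def mean(xs):
    return sum(xs) / len(xs)


# Hypothetical per-example gradients for one coordinate (n = 3 examples).
per_example_grads = [3.0, -1.0, -1.0]

# SGD: averaging the batch-size-1 gradients over an epoch recovers the
# full-batch gradient, so small-step SGD follows the full-batch direction.
sgd_epoch_drift = mean(per_example_grads)

# Sign-based update, full batch: take the sign of the averaged gradient.
full_batch_direction = sign(mean(per_example_grads))

# Sign-based update, batch size 1: each step uses the sign of a single
# example's gradient, so the epoch-averaged drift is the mean of the signs.
mini_batch_drift = mean([sign(g) for g in per_example_grads])

print(sgd_epoch_drift)       # positive, equal to the full-batch gradient
print(full_batch_direction)  # 1
print(mini_batch_drift)      # negative: opposite to the full-batch direction
```

Because the sign is taken before averaging in the mini-batch case, one large gradient is outvoted by two small ones, and shrinking the learning rate does not remove the discrepancy.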

Paper Structure

This paper contains 50 sections, 48 theorems, 246 equations, 12 figures, 2 tables.

Key Result

Theorem 4.1

Suppose $\eta=\frac{1}{\mathrm{poly}(n)}$, $\lambda$ satisfies $0 < \lambda = o\left(\frac{\sigma_0^{q-2}\sigma_p}{n}\right)$, and we train our CNN model (Definition def:model) on the loss function (eq:adam_loss) for $T=\frac{\mathrm{poly}(n)}{\eta}$ epochs using Adam (eq:adam_upd) with batch size $B$ satisfying $\frac{n

Figures (12)

  • Figure 1: Test error vs. batch size for VGG16 and ResNet18 on CIFAR-10.
  • Figure 2: Test error vs. weight decay (batch size = 16), comparing Adam and AdamW on each model.
  • Figure 3: Feature learning and noise memorization of Adam in the training.
  • Figure 4: Feature learning and noise memorization of AdamW in the training.
  • Figure 5: Training error and test error over epochs of Adam training with $\lambda=0.05$.
  • ...and 7 more figures

Theorems & Definitions (75)

  • Definition 3.1
  • Definition 3.2
  • Theorem 4.1: Large-batch Adam
  • Theorem 4.2: Mini-batch Adam
  • Corollary 4.3: Effective weight decay in Adam
  • Theorem 4.4: Large-batch AdamW
  • Theorem 4.5: Mini-batch AdamW
  • Corollary 4.6
  • Lemma 5.1
  • Lemma 5.2: Stage I
  • ...and 65 more