Understanding the Generalization of Stochastic Gradient Adam in Learning Neural Networks
Xuan Tang, Han Zhang, Yuan Cao, Difan Zou
TL;DR
This paper analyzes how batch size affects the generalization of stochastic gradient Adam and AdamW when training two-layer over-parameterized CNNs. It shows that large-batch training leads the network to memorize noise and perform poorly at test time, while mini-batch training with weight decay achieves near-zero test error. The paper gives rigorous results for both Adam and AdamW and shows that Adam's adaptive normalization imposes a tighter upper bound on the effective weight decay than AdamW does, which makes tuning more delicate. The authors connect stochastic Adam dynamics to SignSGD in the small-learning-rate regime and describe training in two stages: an initial feature-learning stage followed by a regularization-dominated convergence stage. Experiments on CIFAR-10 and ImageNet-1K subsets validate the theory. Overall, the findings highlight the critical roles of batch size and explicit weight decay in reaching generalizable solutions with adaptive optimizers, and they offer practical guidance for tuning in real-world training.
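The simplest intuition behind the SignSGD connection can be sketched in Python, assuming the textbook Adam update (exponential moving averages of the gradient and squared gradient, bias correction, per-coordinate normalization). The helpers adam_step and sign_sgd_step are hypothetical illustrations, not code from the paper. With β1 = β2 = 0 and a negligible ε, the normalized direction m/(√v + ε) collapses to g/|g|, i.e. the sign of each gradient coordinate; the paper's small-learning-rate analysis is more general than this limiting case.

```python
import math


def adam_step(theta, grad, m, v, t, lr=1e-3, beta1=0.9, beta2=0.999, eps=1e-8):
    """One textbook Adam step on lists of scalar parameters (hypothetical helper).

    t is the 1-based step count used for bias correction.
    Returns the updated (theta, m, v).
    """
    new_theta, new_m, new_v = [], [], []
    for th, g, mi, vi in zip(theta, grad, m, v):
        mi = beta1 * mi + (1 - beta1) * g        # first-moment moving average
        vi = beta2 * vi + (1 - beta2) * g * g    # second-moment moving average
        m_hat = mi / (1 - beta1 ** t)            # bias-corrected first moment
        v_hat = vi / (1 - beta2 ** t)            # bias-corrected second moment
        # Per-coordinate normalization: each coordinate's step is scaled by 1/sqrt(v_hat).
        new_theta.append(th - lr * m_hat / (math.sqrt(v_hat) + eps))
        new_m.append(mi)
        new_v.append(vi)
    return new_theta, new_m, new_v


def sign_sgd_step(theta, grad, lr=1e-3):
    """SignSGD: move each coordinate by lr against the sign of its gradient.

    Coordinates with exactly zero gradient are left unchanged.
    """
    def sign(x):
        return 1.0 if x > 0 else (-1.0 if x < 0 else 0.0)

    return [th - lr * sign(g) for th, g in zip(theta, grad)]
```

For example, calling adam_step with beta1=0.0, beta2=0.0 and eps=1e-12 on gradients of very different magnitudes (say 0.3, -2.5 and 1e-4) moves every coordinate by essentially lr, exactly as sign_sgd_step does. With nonzero β1 and β2 the moment estimates average gradients across steps, so the Adam update is no longer exactly sign(g).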
Abstract
Adam is a widely used adaptive gradient method in deep learning, and it has also received considerable attention in theoretical research. However, most existing theoretical work analyzes its full-batch version, which differs fundamentally from the stochastic variant used in practice. Unlike SGD, stochastic Adam does not converge to its full-batch counterpart even with infinitesimal learning rates. We present the first theoretical characterization of how batch size affects Adam's generalization, analyzing two-layer over-parameterized CNNs on image data. Our results reveal that while both large-batch Adam and AdamW with proper weight decay $λ$ converge to solutions with poor test error, their mini-batch variants can achieve near-zero test error. We further prove that Adam has a strictly smaller effective weight decay bound than AdamW, which theoretically explains why Adam requires more careful tuning of $λ$. Extensive experiments validate our findings, demonstrating the critical role of batch size and weight decay in Adam's generalization performance.
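To make the Adam versus AdamW distinction concrete, the Python sketch below assumes the common convention that "Adam with weight decay" adds λθ to the gradient (coupled L2 regularization), while AdamW applies λθ directly to the parameters (decoupled weight decay). The helpers adam_l2_step and adamw_step are hypothetical and illustrative only; they are not the paper's algorithm and do not reproduce its definition of an effective weight decay bound. The toy case exposes one mechanism: in Adam the decay term passes through the adaptive normalizer, so for a coordinate with zero loss gradient the first step has magnitude close to the learning rate regardless of λ, whereas AdamW's shrinkage grows linearly with λ.

```python
import math


def adam_l2_step(theta, grad, m, v, t, lr, lam, beta1=0.9, beta2=0.999, eps=1e-8):
    """Adam with coupled L2 weight decay (hypothetical helper).

    lam * theta is added to the loss gradient before the moment estimates,
    so it is divided by sqrt(v_hat) just like the loss gradient.
    """
    out_theta, out_m, out_v = [], [], []
    for th, g, mi, vi in zip(theta, grad, m, v):
        g = g + lam * th                          # weight decay enters the gradient
        mi = beta1 * mi + (1 - beta1) * g
        vi = beta2 * vi + (1 - beta2) * g * g
        m_hat = mi / (1 - beta1 ** t)
        v_hat = vi / (1 - beta2 ** t)
        out_theta.append(th - lr * m_hat / (math.sqrt(v_hat) + eps))
        out_m.append(mi)
        out_v.append(vi)
    return out_theta, out_m, out_v


def adamw_step(theta, grad, m, v, t, lr, lam, beta1=0.9, beta2=0.999, eps=1e-8):
    """AdamW with decoupled weight decay (hypothetical helper).

    The moments see only the loss gradient; lr * lam * theta is subtracted
    directly from the parameters and is never normalized.
    """
    out_theta, out_m, out_v = [], [], []
    for th, g, mi, vi in zip(theta, grad, m, v):
        mi = beta1 * mi + (1 - beta1) * g
        vi = beta2 * vi + (1 - beta2) * g * g
        m_hat = mi / (1 - beta1 ** t)
        v_hat = vi / (1 - beta2 ** t)
        adaptive = m_hat / (math.sqrt(v_hat) + eps)
        out_theta.append(th - lr * (adaptive + lam * th))
        out_m.append(mi)
        out_v.append(vi)
    return out_theta, out_m, out_v
```

For example, with θ = 1, zero loss gradient and lr = 0.01, raising λ from 0.1 to 10 leaves the first adam_l2_step update near 0.01, while the adamw_step update scales from 0.001 to 0.1. In this sketch the coupled decay is capped by the normalizer, which is one intuitive reason λ behaves differently under the two optimizers.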
