The Rich and the Simple: On the Implicit Bias of Adam and SGD
Bhavya Vasudeva, Jung Whan Lee, Vatsal Sharan, Mahdi Soltanolkotabi
TL;DR
The paper investigates how Adam's implicit bias differs from that of SGD when training two-layer ReLU networks on Gaussian mixture data. By deriving population gradients and analyzing both gradient flow and Adam-like updates, it shows that SGD tends toward simple, linear decision boundaries, while Adam learns nonlinear boundaries closer to the Bayes optimal predictor, yielding better generalization under certain distribution shifts. Extensive experiments on synthetic data, MNIST-based spurious-feature tasks, and subgroup-robustness benchmarks corroborate that Adam's richer feature learning improves worst-group accuracy and core-feature decoding, suggesting practical advantages for handling spurious correlations. The work provides a principled contrast between the two optimization schemes, guides expectations about generalization in the presence of spurious features, and points to future work on broader architectures and regularization effects.
Abstract
Adam is the de facto optimization algorithm for several deep learning applications, but an understanding of its implicit bias and how it differs from other algorithms, particularly standard first-order methods such as (stochastic) gradient descent (GD), remains limited. In practice, neural networks (NNs) trained with SGD are known to exhibit simplicity bias, a tendency to find simple solutions. In contrast, we show that Adam is more resistant to such simplicity bias. First, we investigate the differences in the implicit biases of Adam and GD when training two-layer ReLU NNs on a binary classification task with Gaussian data. We find that GD exhibits a simplicity bias, resulting in a linear decision boundary with a suboptimal margin, whereas Adam leads to much richer and more diverse features, producing a nonlinear boundary that is closer to the Bayes optimal predictor. This richer decision boundary also allows Adam to achieve higher test accuracy both in-distribution and under certain distribution shifts. We theoretically prove these results by analyzing the population gradients. Next, to corroborate our theoretical findings, we present extensive empirical results showing that this property of Adam leads to superior generalization across various datasets with spurious correlations, where NNs trained with SGD are known to show simplicity bias and do not generalize well under certain distribution shifts.
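To make the contrast between the two optimizers concrete, it can help to compare a single update step. The sketch below is not the paper's experimental setup or analysis; it is a minimal pure-Python illustration of the standard GD and Adam update rules, and the names `gd_step`, `adam_step`, `w0`, and `grad`, as well as the numeric values, are assumptions chosen only for illustration. It shows that on the first step Adam's bias-corrected update has roughly the same magnitude (about `lr`) in every coordinate, whereas the GD step is proportional to the gradient.

```python
import math


def gd_step(w, grad, lr=0.1):
    """Plain gradient descent: the step is proportional to the gradient."""
    return [wi - lr * gi for wi, gi in zip(w, grad)]


def adam_step(w, grad, m, v, t, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam update with bias correction; t is the 1-based step count."""
    # Exponential moving averages of the gradient and the squared gradient.
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, grad)]
    # Bias correction compensates for the zero initialization of m and v.
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    # Per-coordinate normalization by the root of the second-moment estimate.
    w = [wi - lr * mh / (math.sqrt(vh) + eps)
         for wi, mh, vh in zip(w, m_hat, v_hat)]
    return w, m, v


# Hypothetical values: two coordinates whose gradients differ by a factor of 1000.
w0 = [0.0, 0.0]
grad = [10.0, -0.01]

gd_w = gd_step(w0, grad)
adam_w, m1, v1 = adam_step(w0, grad, m=[0.0, 0.0], v=[0.0, 0.0], t=1)

gd_delta = [a - b for a, b in zip(gd_w, w0)]
adam_delta = [a - b for a, b in zip(adam_w, w0)]

print("GD step:  ", gd_delta)
print("Adam step:", adam_delta)
```

In this toy example the GD step differs across the two coordinates by the same factor of 1000 as the gradients, while the Adam step has magnitude close to `lr` in both. The example only illustrates this per-coordinate normalization; it makes no claim about how the paper's results are derived.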
