Identifying Spurious Biases Early in Training through the Lens of Simplicity Bias
Yu Yang, Eric Gan, Gintare Karolina Dziugaite, Baharan Mirzasoleiman
TL;DR
This work investigates how gradient-descent optimization biases neural networks toward simple, spuriously correlated features, degrading worst-group generalization. It provides a theoretical analysis of early training dynamics showing that spurious features can dominate the network's output in the initial iterations and that majority and minority groups become separable early, especially when the spurious features are strongly correlated with the labels and large in magnitude. Building on these insights, it introduces SPARE (SePArate early and REsample), a lightweight clustering-based method that infers groups in the first few epochs and applies importance sampling to rebalance training, without extra data or heavy hyperparameter tuning. Empirically, SPARE achieves state-of-the-art worst-group accuracy on CMNIST, Waterbirds, CelebA, and UrbanCars and scales to large datasets such as Restricted ImageNet, while being significantly faster and easier to tune than existing group-inference methods.
Abstract
Neural networks trained with (stochastic) gradient descent have an inductive bias towards learning simpler solutions. This makes them highly prone to learning spurious correlations in the training data that may not hold at test time. In this work, we provide the first theoretical analysis of the effect of simplicity bias on learning spurious correlations. Notably, we show that examples with spurious features are provably separable based on the model's output early in training. We further illustrate that if spurious features have a small enough noise-to-signal ratio, the network's output on the majority of examples is almost exclusively determined by the spurious features, leading to poor worst-group test accuracy. Finally, we propose SPARE, which identifies spurious correlations early in training and uses importance sampling to alleviate their effect. Empirically, we demonstrate that SPARE outperforms state-of-the-art methods by up to 21.1% in worst-group accuracy, while being up to 12x faster. We also show that SPARE is a highly effective yet lightweight method for discovering spurious correlations.
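The "separate early, then resample" idea can be illustrated with a small NumPy sketch. It assumes that groups are inferred by running k-means separately within each class on the model's softmax outputs from an early epoch. It also assumes that each example is drawn with probability proportional to 1/(cluster size)^power, rescaled so every class keeps its original share of the data. The helpers `kmeans` and `spare_style_weights`, the `power` parameter, the farthest-point seeding, and the synthetic outputs are illustrative assumptions, not the authors' implementation.

```python
import numpy as np


def kmeans(x, k, n_iter=50):
    """Plain Lloyd's k-means on the rows of x; returns one cluster id per row."""
    x = np.asarray(x, dtype=float)
    # Farthest-point seeding: start from the first row, then repeatedly add the
    # row farthest from all chosen centers (a deterministic k-means++ cousin).
    centers = [x[0]]
    for _ in range(1, k):
        d = np.min([((x - c) ** 2).sum(axis=1) for c in centers], axis=0)
        centers.append(x[d.argmax()])
    centers = np.array(centers)
    for _ in range(n_iter):
        # Squared Euclidean distance from every row to every center.
        dists = ((x[:, None, :] - centers[None, :, :]) ** 2).sum(axis=-1)
        assign = dists.argmin(axis=1)
        for j in range(k):
            members = x[assign == j]
            if len(members) > 0:  # keep the old center if a cluster empties
                centers[j] = members.mean(axis=0)
    return assign


def spare_style_weights(outputs, labels, k=2, power=1.0):
    """Hypothetical helper: per-example sampling probabilities.

    Within each class, examples are clustered by their early-training outputs.
    Each example gets weight 1 / (cluster size) ** power, rescaled so that every
    class keeps its original share of the data.
    """
    outputs = np.asarray(outputs, dtype=float)
    labels = np.asarray(labels)
    weights = np.zeros(len(labels))
    for c in np.unique(labels):
        idx = np.flatnonzero(labels == c)
        assign = kmeans(outputs[idx], min(k, len(idx)))
        sizes = np.bincount(assign)
        w = 1.0 / sizes[assign] ** power
        weights[idx] = w / w.sum() * len(idx)  # preserve the class's total mass
    return weights / weights.sum()  # normalize into a distribution


# Toy usage with synthetic "early-epoch softmax outputs": each class has a large
# majority group and a small minority group whose outputs look different.
rng = np.random.default_rng(0)
spec = [([0.9, 0.1], 90, 0), ([0.4, 0.6], 10, 0),  # (mean output, size, label)
        ([0.1, 0.9], 80, 1), ([0.6, 0.4], 20, 1)]
outputs = np.concatenate(
    [np.array(m) + 0.02 * rng.standard_normal((n, 2)) for m, n, _ in spec])
labels = np.concatenate([np.full(n, y) for _, n, y in spec])
# True group ids, used only to inspect the result.
groups = np.repeat(np.arange(len(spec)), [n for _, n, _ in spec])
p = spare_style_weights(outputs, labels, k=2)
batch = rng.choice(len(p), size=64, p=p)  # indices of an importance-sampled mini-batch
```

With these synthetic outputs and `power=1.0`, each inferred group receives the same total sampling probability, so the 10 minority examples of class 0 are drawn as often in aggregate as its 90 majority examples. Raising `power` above 1 would shift even more mass toward small clusters.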
