
Identifying Spurious Biases Early in Training through the Lens of Simplicity Bias

Yu Yang, Eric Gan, Gintare Karolina Dziugaite, Baharan Mirzasoleiman

TL;DR

This work investigates how gradient-descent optimization biases neural networks toward simple, spuriously correlated features, which degrades worst-group generalization. It provides a theoretical analysis of early training dynamics showing spurious features can dominate outputs in the initial iterations and that majority/minority groups become separable early, especially when spurious features have strong correlation and magnitude. Building on these insights, it introduces SPARE (SePArate early and REsample), a lightweight, clustering-based approach that identifies groups in the first few epochs and applies importance sampling to rebalance training without extra data or heavy tuning. Empirically, SPARE achieves state-of-the-art worst-group accuracy across benchmarks (CMNIST, Waterbirds, CelebA, UrbanCars) and scales to large datasets like Restricted ImageNet, while being significantly faster and easier to tune than existing group-inference methods.
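The SPARE recipe summarized above (cluster early-training outputs within each class, then resample) can be sketched in a few lines. This is a minimal sketch under stated assumptions, not the authors' implementation. It runs plain k-means with a fixed, hypothetical number of clusters `k` on each class's early-epoch output vectors. It then gives each example a sampling weight of 1/(size of its cluster), so every inferred group receives the same total sampling mass before normalization. The names `kmeans` and `spare_style_weights` are illustrative. How SPARE actually chooses the number of clusters or tempers its importance weights is not specified in this summary.

```python
import random
from collections import Counter


def kmeans(points, k, iters=50, seed=0):
    """Plain Lloyd's k-means on equal-length tuples; returns one cluster id per point."""
    rng = random.Random(seed)
    centers = rng.sample(points, k)
    labels = [0] * len(points)
    for _ in range(iters):
        # Assignment step: nearest center by squared Euclidean distance.
        labels = [
            min(range(k), key=lambda j: sum((a - b) ** 2 for a, b in zip(p, centers[j])))
            for p in points
        ]
        # Update step: move each center to the mean of its members (keep it if empty).
        for j in range(k):
            members = [p for p, lab in zip(points, labels) if lab == j]
            if members:
                centers[j] = tuple(sum(col) / len(members) for col in zip(*members))
    return labels


def spare_style_weights(outputs, labels, k=2):
    """Hypothetical helper: cluster early-epoch outputs within each class and weight
    each example by 1 / (size of its cluster); weights are normalized to sum to 1."""
    weights = [0.0] * len(outputs)
    groups = [None] * len(outputs)
    for c in sorted(set(labels)):
        idx = [i for i, y in enumerate(labels) if y == c]
        cluster_ids = kmeans([outputs[i] for i in idx], k)
        sizes = Counter(cluster_ids)
        for i, g in zip(idx, cluster_ids):
            groups[i] = (c, g)           # inferred group = (class, cluster)
            weights[i] = 1.0 / sizes[g]  # each non-empty cluster gets total mass 1
    total = sum(weights)
    return [w / total for w in weights], groups


# Toy early-epoch softmax outputs: per class, 8 "majority" and 2 "minority" examples.
outputs = [(0.9, 0.1)] * 8 + [(0.4, 0.6)] * 2 + [(0.1, 0.9)] * 8 + [(0.6, 0.4)] * 2
labels = [0] * 10 + [1] * 10
weights, groups = spare_style_weights(outputs, labels)
# Draw a rebalanced mini-batch of example indices using the importance weights.
batch = random.Random(1).choices(range(len(outputs)), weights=weights, k=16)
```

In this toy run each minority example ends up with four times the sampling weight of a majority example. Passing the weights to `random.choices` (or to a weighted sampler in a training framework) then draws mini-batches in which the inferred minority groups are no longer underrepresented.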

Abstract

Neural networks trained with (stochastic) gradient descent have an inductive bias towards learning simpler solutions. This makes them highly prone to learning spurious correlations in the training data that may not hold at test time. In this work, we provide the first theoretical analysis of the effect of simplicity bias on learning spurious correlations. Notably, we show that examples with spurious features are provably separable based on the model's output early in training. We further illustrate that if spurious features have a small enough noise-to-signal ratio, the network's output on the majority of examples is almost exclusively determined by the spurious features, leading to poor worst-group test accuracy. Finally, we propose SPARE, which identifies spurious correlations early in training and utilizes importance sampling to alleviate their effect. Empirically, we demonstrate that SPARE outperforms state-of-the-art methods by up to 21.1% in worst-group accuracy, while being up to 12x faster. We also show that SPARE is a highly effective but lightweight method to discover spurious correlations.
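A toy illustration of why a large, well-correlated spurious feature dominates early: with the mean logistic loss and weights initialized at $w = 0$, the gradient is $-\tfrac{1}{2}\,\mathbb{E}[y\,x]$. The first step therefore moves each weight in proportion to its feature's correlation with the label times its magnitude. The sketch below is only an analogy under stated assumptions. It uses a linear model, not the two-layer network the paper analyzes, with a small core feature (magnitude 0.2) that always agrees with the label and a unit-magnitude spurious feature that agrees on 90% of examples. The magnitudes, learning rate, step count, and helper names are illustrative choices.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def make_toy_data(n_major=9, n_minor=1, core_mag=0.2, spur_mag=1.0):
    """Examples are ((core, spurious), y, is_minority); the spurious feature agrees
    with y on majority examples and is flipped on minority examples."""
    data = []
    for y in (+1, -1):
        data += [((core_mag * y, spur_mag * y), y, False)] * n_major
        data += [((core_mag * y, -spur_mag * y), y, True)] * n_minor
    return data


def train_linear_logistic(data, lr=0.5, steps=5):
    """Full-batch gradient descent on the mean logistic loss of f(x) = w . x, from w = 0."""
    w = [0.0, 0.0]
    n = len(data)
    for _ in range(steps):
        grad = [0.0, 0.0]
        for x, y, _ in data:
            margin = y * (w[0] * x[0] + w[1] * x[1])
            # d/dw log(1 + exp(-margin)) = -y * sigmoid(-margin) * x
            coef = -y * sigmoid(-margin) / n
            grad[0] += coef * x[0]
            grad[1] += coef * x[1]
        w = [w[0] - lr * grad[0], w[1] - lr * grad[1]]
    return w


w = train_linear_logistic(make_toy_data())
# Outputs on class +1: majority has spurious = +1, minority has spurious = -1.
major_out = 0.2 * w[0] + 1.0 * w[1]
minor_out = 0.2 * w[0] - 1.0 * w[1]
# Fraction of the majority output contributed by the spurious coordinate.
spurious_share = abs(w[1]) / (abs(0.2 * w[0]) + abs(w[1]))
```

After five steps the spurious weight is more than three times the core weight. Majority and minority examples of the same class receive outputs of opposite sign, so the output alone separates them. More than 90% of the majority output comes from the spurious coordinate.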

Paper Structure

This paper contains 56 sections, 7 theorems, 52 equations, 9 figures, 10 tables, and 1 algorithm.

Key Result

Theorem 4.1

Let $\alpha \in (0, \frac{1}{4})$ be a fixed constant. Suppose the number of training samples $n$ and the network width $m$ satisfy $n \gtrsim d^{1+\alpha}$ and $m \gtrsim d^{1+\alpha}$. Let $n_c$ be the number of examples in class $c$, and let $n_{c,s} = |g_{c,s}|$ be the size of group $g_{c,s}$ with … where $c' = \mathcal{C} \setminus c$, and $\zeta$ is the expected gradient of the activation function …
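The group notation in Theorem 4.1 can be made concrete with a small helper that counts $n_c$ and $n_{c,s}$ from per-example class labels and spurious attributes. This is an illustrative sketch: the helper names are hypothetical. Treating the largest group in each class as its majority group is an assumption that follows the Colored MNIST description in Figure 1; the truncated theorem statement does not define it.

```python
from collections import Counter


def group_sizes(labels, spurious):
    """Return (n_c per class c, n_{c,s} = |g_{c,s}| per (class, spurious value) pair)."""
    n_c = Counter(labels)
    n_cs = Counter(zip(labels, spurious))
    return n_c, n_cs


def majority_spurious_value(n_cs):
    """For each class c, the spurious value s with the largest n_{c,s} (assumed majority)."""
    best = {}
    for (c, s), size in n_cs.items():
        if c not in best or size > n_cs[(c, best[c])]:
            best[c] = s
    return best


# Toy Colored-MNIST-like data: class 0 is mostly red, class 1 is mostly green.
labels = [0] * 10 + [1] * 10
colors = ["red"] * 9 + ["green"] + ["green"] * 8 + ["red"] * 2
n_c, n_cs = group_sizes(labels, colors)
majority = majority_spurious_value(n_cs)
# Within-class share of each majority group, n_{c,s} / n_c.
shares = {c: n_cs[(c, s)] / n_c[c] for c, s in majority.items()}
```

Because the groups $g_{c,s}$ partition class $c$, the counts satisfy $n_c = \sum_s n_{c,s}$. The within-class share $n_{c,s}/n_c$ of the majority group is one simple way to quantify how strongly the spurious attribute correlates with the label.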

Figures (9)

  • Figure 1: Colored MNIST as an example of datasets containing spurious correlations. Each digit is a class; the majority of digits in a class have a particular color, and the remaining digits are in other colors. Models trained with ERM learn to rely on spurious features (colors) instead of the core feature (digits) and thus do not perform well on groups of examples where the spurious correlation does not hold.
  • Figure 2: Training LeNet-5 on Colored MNIST. Top: up to epoch 2, the network output is almost exclusively determined by the color red (the spurious feature of the majority group). Bottom: majority and minority groups are separable based on the network output, e.g., via clustering. Minority groups whose spurious feature matches the majority group of another class (yellow, purple, blue, green) are also separable from each other. Similar results on Waterbirds are shown in Figure \ref{fig:rebuttal-pred-diff}.
  • Figure 3: Number of minority examples inferred as majority. JTT and EIIL infer many minority examples as majority and thus mistakenly downweight them. SPARE identifies minority groups more accurately and correctly upweights them.
  • Figure 4: Spurious correlation between "green leaf" and "insect" in Restricted ImageNet, discovered by SPARE.
  • Figure 5: A comparison between the losses of a two-layer network and a simple linear model on the training set, spurious features (color only), and core feature (digit only).
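The quantity in Figure 3 can be computed from inferred groups and ground-truth group labels. The sketch below is a minimal version under an assumption that reflects this reading of the caption, not a protocol stated in the source: within each class, the largest inferred cluster counts as the inferred majority group. Any true minority example placed there would be downweighted rather than upweighted. The function name is hypothetical.

```python
from collections import Counter


def count_minority_inferred_as_majority(labels, cluster_ids, is_minority):
    """Within each class, treat the largest inferred cluster as the inferred majority
    group, then count ground-truth minority examples that were placed in it."""
    sizes = Counter(zip(labels, cluster_ids))
    largest = {}
    for (c, g), size in sizes.items():
        if c not in largest or size > sizes[(c, largest[c])]:
            largest[c] = g
    return sum(
        1 for c, g, minority in zip(labels, cluster_ids, is_minority)
        if minority and g == largest[c]
    )


labels = [0] * 6 + [1] * 6
cluster_ids = [0, 0, 0, 0, 1, 1] + [0, 0, 0, 0, 0, 1]
is_minority = [False, False, False, True, True, True] + [False, False, False, False, True, True]
# One minority example per class lands in that class's largest cluster: 2 in total.
missed = count_minority_inferred_as_majority(labels, cluster_ids, is_minority)
```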

Theorems & Definitions (7)

  • Theorem 4.1
  • Corollary 4.1: Separability of majority and minority groups
  • Theorem 4.2
  • Theorem A.5 (Hu et al., 2020)
  • Theorem B.1
  • Corollary B.0: Separability of majority and minority groups
  • Theorem B.1