
Freeze then Train: Towards Provable Representation Learning under Spurious Correlations and Feature Noise

Haotian Ye, James Zou, Linjun Zhang

TL;DR

The paper investigates why last-layer probing can fail under spurious correlations and identifies non-realizable noise as a key factor. It introduces Freeze then Train (FTT), a two-stage approach that first freezes salient features found by an unsupervised method and then trains the remaining features with supervision, preserving core features that are useful for test-time probing. The authors prove that FTT preserves features that are more beneficial to test-time probing and validate these claims with experiments on spurious-correlation benchmarks and general distribution-shift benchmarks. FTT outperforms ERM, IRM, JTT, and CVaR-DRO on two spurious-correlation datasets, with a 4.5% accuracy improvement when feature noise is large, suggesting a practical path to robust representation learning in the presence of spurious correlations.

Abstract

The existence of spurious correlations such as image backgrounds in the training environment can make empirical risk minimization (ERM) perform poorly in the test environment. To address this problem, Kirichenko et al. (2022) empirically found that the core features that are related to the outcome can still be learned well even in the presence of spurious correlations. This suggests a promising strategy: first train a feature learner rather than a classifier, and then perform linear probing (last-layer retraining) in the test environment. However, a theoretical understanding of when and why this approach works is lacking. In this paper, we find that core features are only learned well when their associated non-realizable noise is smaller than that of spurious features, which is not necessarily true in practice. We provide both theory and experiments to support this finding and to illustrate the importance of non-realizable noise. Moreover, we propose an algorithm called Freeze then Train (FTT), which first freezes certain salient features and then trains the rest of the features using ERM. We theoretically show that FTT preserves features that are more beneficial to test-time probing. Across two commonly used spurious correlation datasets, FTT outperforms ERM, IRM, JTT, and CVaR-DRO, with a substantial improvement in accuracy (4.5%) when the feature noise is large. FTT also performs better on general distribution shift benchmarks.
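The three steps described above (freeze salient features, train the rest with ERM, retrain the last layer in the test environment) can be sketched with linear features in NumPy. This is a minimal, hypothetical sketch, not the authors' implementation: it assumes an identity "pretrained extractor" on raw inputs, uses PCA as the unsupervised step (one of the options named in Figure 2), trains the remaining features with full-batch gradient descent on squared loss, and refits the head by ridge-regularized least squares. The helper names (`unsupervised_salient_features`, `train_remaining_features`, `retrain_last_layer`, `make_env`) and the synthetic data model are illustrative assumptions.

```python
import numpy as np

def unsupervised_salient_features(X, k):
    """Stage 1 (PCA assumed): top-k principal directions of the training inputs."""
    Xc = X - X.mean(axis=0)
    _, _, Vt = np.linalg.svd(Xc, full_matrices=False)
    return Vt[:k].T  # shape (d, k); these columns stay frozen afterwards

def train_remaining_features(X, y, W_frozen, m, steps=2000, lr=0.05, seed=0):
    """Stage 2: ERM on m new linear features plus a linear head (squared loss,
    full-batch gradient descent). W_frozen is never updated."""
    rng = np.random.default_rng(seed)
    n, d = X.shape
    k = W_frozen.shape[1]
    W_free = 0.1 * rng.standard_normal((d, m))
    head = np.zeros(k + m)
    Z_frozen = X @ W_frozen  # frozen features are computed once
    for _ in range(steps):
        Z = np.hstack([Z_frozen, X @ W_free])
        resid = Z @ head - y                               # shape (n,)
        grad_head = Z.T @ resid / n                        # gradient of 0.5 * mean(resid^2)
        grad_W_free = X.T @ np.outer(resid, head[k:]) / n  # only the free block gets a gradient
        head -= lr * grad_head
        W_free -= lr * grad_W_free
    return W_free  # the training head is discarded; it is refit at test time

def retrain_last_layer(X_te, y_te, W_frozen, W_free, ridge=1e-6):
    """Test-time probing: refit only the linear head on test-environment data."""
    Z = np.hstack([X_te @ W_frozen, X_te @ W_free])
    A = Z.T @ Z + ridge * np.eye(Z.shape[1])
    return np.linalg.solve(A, Z.T @ y_te)

# Hypothetical data: coordinate 0 is a core feature, coordinate 1 is spurious
# (predictive only in training), coordinates 2-3 are pure noise.
def make_env(n, spurious, rng):
    y = rng.choice([-1.0, 1.0], size=n)
    core = y + 0.5 * rng.standard_normal(n)
    spu = (y if spurious else 0.0) + 0.3 * rng.standard_normal(n)
    noise = rng.standard_normal((n, 2))
    return np.column_stack([core, spu, noise]), y

rng = np.random.default_rng(0)
X_tr, y_tr = make_env(2000, spurious=True, rng=rng)
X_te, y_te = make_env(2000, spurious=False, rng=rng)

W_frozen = unsupervised_salient_features(X_tr, k=2)
W_free = train_remaining_features(X_tr, y_tr, W_frozen, m=1)
head = retrain_last_layer(X_te, y_te, W_frozen, W_free)
Z_te = np.hstack([X_te @ W_frozen, X_te @ W_free])
test_mse = float(np.mean((Z_te @ head - y_te) ** 2))
```

The demo only shows that the pipeline runs end to end on toy data; it makes no claim about reproducing the paper's results. In a deep-network version, the frozen block would typically be a subset of channels or a projection of a pretrained backbone, and the free block would be trained with the backbone by ERM or a robust training algorithm.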

Paper Structure

This paper contains 48 sections, 10 theorems, 61 equations, 7 figures, 5 tables, and 1 algorithm.

Key Result

Lemma 1

For all ${\bm{W}} \in \mathbb{R}^{d\times m}, {\bm{b}} \in \mathbb{R}^{m}$, we have (equation not shown), where ${\bm{v}}^*_{tr} = (\alpha\beta^\top,(1-\alpha)\gamma^\top)^\top$ is the optimal coefficient for training, and $\alpha = \frac{\eta_{spu}^2}{\eta_{core}^2 + \eta_{spu}^2}$.
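The weight $\alpha$ in Lemma 1 has the familiar inverse-variance form. A minimal sketch of where that form comes from, under an assumed simplification that is not the paper's exact data model: suppose the core and spurious readouts are each $y$ plus independent zero-mean noise with standard deviations $\eta_{core}$ and $\eta_{spu}$. Among combinations $a\,s_{core} + (1-a)\,s_{spu}$ whose weights sum to one, the squared error is $a^2\eta_{core}^2 + (1-a)^2\eta_{spu}^2$, which is minimized at $a = \eta_{spu}^2/(\eta_{core}^2+\eta_{spu}^2) = \alpha$. The code checks this by grid search and by simulation; all function names are hypothetical.

```python
import random

def optimal_core_weight(eta_core: float, eta_spu: float) -> float:
    """alpha from Lemma 1: the inverse-variance weight on the core readout."""
    return eta_spu**2 / (eta_core**2 + eta_spu**2)

def combined_mse(a: float, eta_core: float, eta_spu: float) -> float:
    """Exact MSE of a*s_core + (1 - a)*s_spu when s_core = y + e_core and
    s_spu = y + e_spu with independent zero-mean noises (assumed model):
    the error is a*e_core + (1 - a)*e_spu, whose variance is returned."""
    return a**2 * eta_core**2 + (1 - a) ** 2 * eta_spu**2

def monte_carlo_mse(a, eta_core, eta_spu, n=100_000, seed=0):
    """Empirical check of combined_mse with Gaussian noise and labels y in {-1, +1}."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n):
        y = rng.choice((-1.0, 1.0))
        s_core = y + rng.gauss(0.0, eta_core)
        s_spu = y + rng.gauss(0.0, eta_spu)
        total += (a * s_core + (1 - a) * s_spu - y) ** 2
    return total / n

eta_core, eta_spu = 0.5, 1.0
alpha = optimal_core_weight(eta_core, eta_spu)  # 0.25 / 1.25 = 0.8
grid = [i / 100 for i in range(101)]
best_a = min(grid, key=lambda a: combined_mse(a, eta_core, eta_spu))
```

With $\eta_{core} = 0.5$ and $\eta_{spu} = 1$, the grid minimizer coincides with $\alpha = 0.8$: the less noisy readout receives the larger weight, which is why ERM leans on whichever feature has smaller non-realizable noise.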

Figures (7)

  • Figure 1: The change in last-layer retraining accuracy (%) from before to after ERM training on the Dominoes dataset (Shah et al., 2020). The model is initialized with ImageNet-pretrained parameters. The x-axis and y-axis represent the noise levels of the spurious and core features, respectively. ERM training helps (harms) performance when the non-realizable noise of the core features is smaller (greater) than that of the spurious features. Experiment settings are described in the experiments section.
  • Figure 2: An illustration of our method, Freeze then Train (FTT). We start with a pretrained feature extractor (e.g. a CNN) and find dataset-specific salient features using any unsupervised method, such as contrastive learning or PCA (the orange part). We then freeze these features and learn the rest of the features using any supervised method, such as ERM or a robust training algorithm (the blue part). In the test environment, the last layer is retrained. The pseudo-code can be found in the appendix.
  • Figure 3: A toy example illustrating when and why ERM can perform well after retraining in ${\mathcal{E}_{te}}$ ($d = 2, m = 1$). Assume the core feature $\beta$ is vertical and the spurious feature $\gamma$ is horizontal. Both features can predict $y$ in ${\mathcal{E}_{tr}}$, while $\gamma$ is useless in ${\mathcal{E}_{te}}$ since ${\bm{x}}_2 = \epsilon_{spu}$. We initialize our single feature ${\bm{W}}(0)$ and obtain ${\bm{W}}(t)$ after training on ${\mathcal{E}_{tr}}$. We then retrain the last layer (probing) on ${\mathcal{E}_{te}}$, i.e., rescale ${\bm{W}}(t)$ to obtain ${\bm{v}}_{test}$. When $\eta_{core} < \eta_{spu}$, ${\bm{W}}(t)$ relies more on $\beta$ (the blue flow); after probing, ${\bm{v}}_{test}$ can recover $\beta$ (small approximation error) without suffering much from the spurious noise $\epsilon_{spu}$ along the direction of $\gamma$ (small spurious noise error). In contrast, when $\eta_{core}$ is large, ${\bm{W}}(t)$ follows the red flow, which leads to a trade-off between the two error terms, and ERM performs much worse. The flows in the figure are only illustrative: in practice, probing can either lengthen or shorten ${\bm{W}}(t)$, depending on the concrete form of the two error terms.
  • Figure 4: Test-time probing accuracy gap between the trained model $\mathcal{M}_{erm}$ and the initialized model $\mathcal{M}_{init}$ on Waterbirds and CelebA. The x-axis is the core noise and the y-axis is the accuracy improvement. In both datasets, the improvements in worst-group accuracy and average accuracy both decrease as $\eta_{core}$ increases. In Waterbirds, a large $\eta_{core}$ can even make ERM training harmful.
  • Figure 5: Worst group accuracy on Waterbirds and CelebA for FTT. The x-axis is the feature dimension in log scale.
  • ...and 2 more figures
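The toy example in Figure 3 can be made quantitative under an assumed data model (hypothetical, chosen only to match the figure's geometry with $d = 2$, $m = 1$, $\beta = e_1$, $\gamma = e_2$): in training $x_1 = y + \epsilon_{core}$ and $x_2 = y + \epsilon_{spu}$; at test time $x_2 = \epsilon_{spu}$, as in the caption; $y = \pm 1$ with equal probability, and the noises are independent and zero-mean. Under this model the least-squares training direction is proportional to $(\eta_{spu}^2, \eta_{core}^2)$, i.e. to $(\alpha, 1-\alpha)$ from Lemma 1. Because probing only rescales the single feature, the sketch computes the test mean squared error after the best rescaling and compares it with probing the pure core direction $\beta$. Function names are illustrative.

```python
def erm_direction(eta_core: float, eta_spu: float) -> tuple[float, float]:
    """Training-optimal direction (alpha, 1 - alpha) from Lemma 1, with beta = e1 and
    gamma = e2. Only the direction matters, because probing rescales the feature."""
    alpha = eta_spu**2 / (eta_core**2 + eta_spu**2)
    return alpha, 1.0 - alpha

def probe_mse(a: float, b: float, eta_core: float, eta_spu: float) -> float:
    """Test MSE after optimally rescaling the single feature z = a*x1 + b*x2.

    Assumed test environment: x1 = y + e_core, x2 = e_spu, y = +/-1 equiprobable,
    independent zero-mean noises. Then E[yz] = a and
    E[z^2] = a^2 (1 + eta_core^2) + b^2 eta_spu^2; the best scale c = E[yz] / E[z^2]
    leaves MSE = E[y^2] - E[yz]^2 / E[z^2] with E[y^2] = 1.
    """
    e_yz = a
    e_zz = a**2 * (1.0 + eta_core**2) + b**2 * eta_spu**2
    return 1.0 - e_yz**2 / e_zz

for eta_core, eta_spu in [(0.5, 1.0), (1.0, 0.5)]:
    erm = probe_mse(*erm_direction(eta_core, eta_spu), eta_core, eta_spu)
    core = probe_mse(1.0, 0.0, eta_core, eta_spu)
    print(f"eta_core={eta_core}, eta_spu={eta_spu}: ERM direction {erm:.3f}, beta {core:.3f}")
```

With $(\eta_{core}, \eta_{spu}) = (0.5, 1.0)$, probing the ERM direction gives an MSE of about 0.24 versus 0.20 for $\beta$; swapping the noise levels to $(1.0, 0.5)$ gives about 0.83 versus 0.50. This mirrors, in squared error rather than accuracy, the caption's claim that a large $\eta_{core}$ pushes ${\bm{W}}(t)$ toward $\gamma$ and hurts probing, but it is only a two-dimensional illustration.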

Theorems & Definitions (10)

  • Lemma 1
  • Theorem 1: Upper Bound
  • Theorem 2: Lower Bound
  • Theorem 3: FTT Bound
  • Lemma 2
  • Lemma 3
  • Lemma 4
  • Lemma 5
  • Lemma 6
  • Lemma 7