Not Only the Last-Layer Features for Spurious Correlations: All Layer Deep Feature Reweighting

Humza Wajid Hameed, Geraldin Nanfack, Eugene Belilovsky

TL;DR

The paper tackles spurious correlations that degrade worst-group performance in ERM-trained models. It introduces H2T-DFR, a three-stage approach that uses Head2Toe to select transferable features from all network layers and then applies Deep Feature Reweighting on a balanced validation set to emphasize robust features. Empirical results on CelebA, Waterbirds, and HAM10000 (with a ResNet-50 backbone) show improvements in worst-group accuracy for CelebA (~2.6%) and HAM10000 (~2.4%), while worst-group accuracy on Waterbirds remains largely unchanged; mean group accuracy stays comparable. The findings suggest that combining multi-layer feature selection with group-balanced retraining can meaningfully improve robustness to spurious correlations on real-world benchmarks.
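The two metrics the TL;DR reports can be made concrete with a short sketch. It assumes each example carries an integer group id (one id per combination of target label and spurious attribute) and takes mean group accuracy to be the unweighted average of per-group accuracies; some works instead weight by group size, so treat that definition as an assumption. The function names are hypothetical.

```python
from collections import defaultdict


def group_accuracies(preds, labels, groups):
    """Map each group id to the accuracy of the predictions within that group."""
    correct, total = defaultdict(int), defaultdict(int)
    for p, y, g in zip(preds, labels, groups):
        total[g] += 1
        correct[g] += int(p == y)
    return {g: correct[g] / total[g] for g in total}


def worst_and_mean_group_accuracy(preds, labels, groups):
    """Worst-group accuracy and the unweighted mean over groups (assumed definition)."""
    accs = group_accuracies(preds, labels, groups)
    return min(accs.values()), sum(accs.values()) / len(accs)


# Toy example: four groups; the model errs once in group 3.
preds = [1, 1, 0, 0, 1, 0, 0, 1]
labels = [1, 1, 0, 0, 1, 0, 1, 1]
groups = [0, 0, 1, 1, 2, 2, 3, 3]
worst, mean = worst_and_mean_group_accuracy(preds, labels, groups)
print(worst, mean)  # → 0.5 0.875
```

A single weak group drags the worst-group number down to 0.5 even though the mean stays high, which is why the paper tracks both.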

Abstract

Spurious correlations are a major source of errors for machine learning models, in particular when aiming for group-level fairness. It has been recently shown that a powerful approach to combat spurious correlations is to re-train the last layer on a balanced validation dataset, isolating robust features for the predictor. However, key attributes can sometimes be discarded by neural networks towards the last layer. In this work, we thus consider retraining a classifier on a set of features derived from all layers. We utilize a recently proposed feature selection strategy to select unbiased features from all the layers. We observe this approach gives significant improvements in worst-group accuracy on several standard benchmarks.
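The last-layer retraining the abstract describes can be sketched in plain Python. This is a minimal sketch under stated assumptions: embeddings from the frozen backbone are already extracted, the task is binary, and the "last layer" is an L2-regularized logistic regression fit by full-batch gradient descent. That choice of classifier and solver is a simplification, not necessarily what the paper uses. The helpers `balanced_subsample` and `fit_last_layer`, the hyperparameters, and the two-dimensional toy embeddings are hypothetical.

```python
import math
import random
from collections import defaultdict


def balanced_subsample(X, y, groups, seed=0):
    """Keep the same number of examples (the smallest group's size) from every group."""
    rng = random.Random(seed)
    by_group = defaultdict(list)
    for i, g in enumerate(groups):
        by_group[g].append(i)
    n_min = min(len(idx) for idx in by_group.values())
    keep = sorted(i for idx in by_group.values() for i in rng.sample(idx, n_min))
    return [X[i] for i in keep], [y[i] for i in keep], [groups[i] for i in keep]


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def fit_last_layer(X, y, lr=0.5, epochs=500, l2=1e-3):
    """Binary logistic regression on frozen embeddings via full-batch gradient descent."""
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    for _ in range(epochs):
        grad_w, grad_b = [0.0] * d, 0.0
        for xi, yi in zip(X, y):
            err = sigmoid(sum(wj * xj for wj, xj in zip(w, xi)) + b) - yi
            for j in range(d):
                grad_w[j] += err * xi[j]
            grad_b += err
        w = [wj - lr * (gj / n + l2 * wj) for wj, gj in zip(w, grad_w)]
        b -= lr * grad_b / n
    return w, b


def make_split(counts):
    """Hypothetical toy embeddings: [core feature, spurious feature], each in {-1, +1}."""
    X, y, g = [], [], []
    for (label, spur), n in counts.items():
        for _ in range(n):
            X.append([2.0 * label - 1.0, 2.0 * spur - 1.0])
            y.append(label)
            g.append(2 * label + spur)  # group id encodes (label, spurious attribute)
    return X, y, g


# Groups (0, 0) and (1, 1) dominate, so the spurious feature tracks the label.
X, y, g = make_split({(0, 0): 8, (0, 1): 2, (1, 0): 2, (1, 1): 8})
w_unbal, _ = fit_last_layer(X, y)            # retrain on the unbalanced data
X_bal, y_bal, g_bal = balanced_subsample(X, y, g)
w_dfr, _ = fit_last_layer(X_bal, y_bal)      # retrain on group-balanced data
print("unbalanced weights:", w_unbal)
print("balanced weights:  ", w_dfr)
```

On the unbalanced data the weight on the spurious coordinate comes out clearly positive because the majority groups reward it. After group-balanced subsampling the spurious coordinate is uncorrelated with the label, and its weight stays at zero while the core weight grows.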

Paper Structure

This paper contains 7 sections, 1 equation, 2 figures, 4 tables.

Figures (2)

  • Figure 1: Top row (H2T-DFR): the training phases of H2T-DFR. A pre-trained network is first fine-tuned on the target task using unbalanced data, followed by Head2Toe feature selection with balanced training data. Balanced data contains equal counts of each group $G_i$, where each group is a unique combination of target and spurious attributes. Lastly, a classifier built on the selected features is trained on a held-out validation dataset not used in the earlier phases. Bottom row (DFR): the DFR baseline for comparison, which performs no feature selection.
  • Figure 2: HAM10000 with H2T-DFR: layerwise proportion of the overall top 5% of features, selected using balanced (top) and unbalanced (bottom) data. With layer depth on the x-axis, the plots show that few features are selected from early layers, while most selected features come from deeper layers.
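The quantity plotted in Figure 2 can be illustrated with a small sketch of the selection step. It assumes Head2Toe-style scoring: features pooled from every layer are concatenated, a linear head is trained on them (with group-lasso regularization in Head2Toe), and each feature is ranked by the L2 norm of its row of head weights. Training the head is omitted here; the weights below are invented for illustration, and `select_top_features` and `layerwise_proportion` are hypothetical names.

```python
import math
from collections import Counter


def select_top_features(weight_rows, fraction=0.05):
    """Indices of the top `fraction` of features, ranked by weight-row L2 norm."""
    norms = [math.sqrt(sum(w * w for w in row)) for row in weight_rows]
    k = max(1, round(fraction * len(weight_rows)))
    ranked = sorted(range(len(norms)), key=lambda i: norms[i], reverse=True)
    return sorted(ranked[:k])


def layerwise_proportion(selected, feature_layer):
    """Fraction of the selected features that come from each layer."""
    counts = Counter(feature_layer[i] for i in selected)
    return {layer: counts[layer] / len(selected) for layer in sorted(counts)}


# Hypothetical setup: 100 pooled features, 25 from each of 4 layers (0 = shallowest),
# and a 2-class linear head whose rows are mostly small.
feature_layer = [i // 25 for i in range(100)]
weight_rows = [[0.1, 0.1] for _ in range(100)]
weight_rows[10] = [0.5, 0.5]  # layer 0
weight_rows[60] = [1.0, 0.0]  # layer 2
weight_rows[80] = [2.0, 1.0]  # layer 3
weight_rows[90] = [0.0, 3.0]  # layer 3
weight_rows[99] = [1.5, 1.5]  # layer 3

selected = select_top_features(weight_rows, fraction=0.05)
proportions = layerwise_proportion(selected, feature_layer)
print(selected)     # → [10, 60, 80, 90, 99]
print(proportions)  # → {0: 0.2, 2: 0.2, 3: 0.6}
```

The 5% default mirrors the caption's "top 5%". Substituting the norms of a trained head and the true layer index of each feature would yield a per-layer breakdown of the kind shown in Figure 2.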