Out of spuriousity: Improving robustness to spurious correlations without group annotations
Phuong Quynh Le, Jörg Schlötterer, Christin Seifert
TL;DR
This work tackles the problem of spurious correlations causing poor generalization by proposing PruSC, a post-training approach that extracts a spurious-free subnetwork from a fully trained model. It relies on clustering in the model's representation space to identify and distort spurious-feature manifolds via a task-oriented contrastive loss, without requiring group annotations. The method constructs a class-balanced de-biasing dataset D_task through unsupervised clustering and optimizes a constrained subnetwork with L = L_mod + βL_task, followed by lightweight fine-tuning. Empirically, PruSC achieves strong worst-group accuracy on CelebA and ISIC, is competitive with annotated baselines, and demonstrates robustness to multiple spurious attributes, all while avoiding explicit spurious-feature labels. This presents a practical, annotation-free path to robust generalization via subnetworks that rely on invariant features.
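As a rough illustration of the two ingredients named above, the following minimal sketch builds a class-balanced subset from cluster assignments and combines two loss values as L = L_mod + βL_task. It is a sketch under stated assumptions, not the authors' implementation: the function names `build_balanced_subset` and `total_loss`, the `per_cell` parameter, and the rule of drawing the same number of points from every (class, cluster) cell are all hypothetical. Cluster ids are assumed to come from some unsupervised clustering of the trained model's representations, with the same number of clusters per class.

```python
import random
from collections import defaultdict


def build_balanced_subset(labels, cluster_ids, per_cell, seed=0):
    """Draw up to `per_cell` indices from every (class, cluster) cell.

    Assumption (not from the paper): cluster_ids were produced by an
    unsupervised clustering of representations, and each class has the
    same number of clusters, so equal draws per cell give class balance.
    """
    rng = random.Random(seed)
    cells = defaultdict(list)
    for idx, (label, cluster) in enumerate(zip(labels, cluster_ids)):
        cells[(label, cluster)].append(idx)

    chosen = []
    for key in sorted(cells):  # sorted for reproducible ordering
        members = cells[key]
        chosen.extend(rng.sample(members, min(per_cell, len(members))))
    return sorted(chosen)


def total_loss(l_mod, l_task, beta):
    """Combined objective L = L_mod + beta * L_task, as stated in the TL;DR."""
    return l_mod + beta * l_task


labels = [0] * 6 + [1] * 6
cluster_ids = [0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1]
subset = build_balanced_subset(labels, cluster_ids, per_cell=1)
```

In this toy input each class has one large and one small cluster; drawing one point per cell gives the minority clusters the same weight as the majority ones, which is the intuition behind using such a subset for de-biasing.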
Abstract
Machine learning models are known to learn spurious correlations, i.e., features that are strongly related to class labels but have no causal relation to them. Relying on these correlations leads to poor performance on data groups in which the correlations do not hold, and to poor generalization. To improve the robustness of machine learning models to spurious correlations, we propose an approach that extracts, from a fully trained network, a subnetwork that does not rely on spurious correlations. The approach rests on the assumption that, when a model is trained with empirical risk minimization (ERM), data points with the same spurious attribute lie close to each other in the representation space; we then employ a supervised contrastive loss in a novel way to force the model to unlearn the spurious connections. The resulting increase in worst-group performance strengthens the hypothesis that a fully trained dense network contains a subnetwork that uses only invariant features for classification, thereby erasing the influence of spurious features even in the presence of multiple spurious attributes and without prior knowledge of attribute labels.
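For readers unfamiliar with the supervised contrastive loss mentioned in the abstract, the sketch below computes its standard form in plain Python: each anchor is pulled toward samples that share its label and pushed away from all others. This is only the generic loss; the abstract does not say here how it is adapted, so the "novel way" is not reproduced. The function name `supcon_loss` and the default `temperature` are illustrative assumptions.

```python
import math


def supcon_loss(embeddings, labels, temperature=0.1):
    """Standard supervised contrastive loss over a small batch.

    Assumes every embedding is a non-zero vector. Anchors without any
    same-label partner are skipped.
    """
    # L2-normalise so that dot products are cosine similarities.
    normed = []
    for vec in embeddings:
        norm = math.sqrt(sum(x * x for x in vec))
        normed.append([x / norm for x in vec])

    n = len(normed)
    total, num_anchors = 0.0, 0
    for i in range(n):
        positives = [j for j in range(n) if j != i and labels[j] == labels[i]]
        if not positives:
            continue
        # Temperature-scaled similarity of anchor i to every other sample.
        sims = {
            j: sum(a * b for a, b in zip(normed[i], normed[j])) / temperature
            for j in range(n) if j != i
        }
        # Log-sum-exp with max subtraction for numerical stability.
        max_sim = max(sims.values())
        log_denom = max_sim + math.log(
            sum(math.exp(s - max_sim) for s in sims.values())
        )
        # Average negative log-probability of the positives.
        total += -sum(sims[p] - log_denom for p in positives) / len(positives)
        num_anchors += 1
    return total / num_anchors if num_anchors else 0.0


points = [[1.0, 0.0], [1.0, 0.01], [0.0, 1.0], [0.01, 1.0]]
aligned = supcon_loss(points, [0, 0, 1, 1])     # labels match geometry
misaligned = supcon_loss(points, [0, 1, 0, 1])  # labels cut across geometry
```

The loss is small when the labels used as positives agree with the geometry of the representation space and large when they cut across it, which is why the choice of what counts as a "positive" determines which structure the model is pushed to keep or to break.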
