Removing Spurious Concepts from Neural Network Representations via Joint Subspace Estimation
Floris Holstege, Bram Wouters, Noud van Giersbergen, Cees Diks
TL;DR
The paper tackles spurious correlations in neural network embeddings by proposing Joint Subspace Estimation (JSE), which jointly identifies two orthogonal subspaces within last-layer representations: $\mathcal{Z}_{sp}$ for spurious concepts and $\mathcal{Z}_{mt}$ for main-task concepts. JSE estimates the direction vectors spanning each subspace through a constrained optimization that enforces orthogonality between them, and it uses statistical tests on predictions scored with binary cross-entropy to decide when to stop adding vectors. Empirically, JSE outperforms existing concept-removal methods on computer vision (Waterbirds, CelebA) and NLP (MultiNLI) tasks, removing spurious features more effectively while preserving main-task information. It also supports interpretable post-hoc analyses such as Grad-CAM. The approach improves out-of-distribution (OOD) generalization and interpretability, which is useful when deploying robust and transparent models.
Abstract
Out-of-distribution generalization in neural networks is often hampered by spurious correlations. A common mitigation strategy is to remove spurious concepts from the neural network representation of the data. Existing concept-removal methods tend to be overzealous, inadvertently eliminating features associated with the main task of the model and thereby harming model performance. We propose an iterative algorithm that separates spurious from main-task concepts by jointly identifying two low-dimensional orthogonal subspaces in the neural network representation. We evaluate the algorithm on benchmark datasets for computer vision (Waterbirds, CelebA) and natural language processing (MultiNLI), and show that it outperforms existing concept-removal methods.
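To make the idea of removing a concept subspace concrete, the following is a minimal NumPy sketch of the projection step only. It is not the paper's estimation procedure: the direction vectors here are random placeholders made orthonormal with a QR decomposition, whereas the paper estimates them from data. The names `orthonormal_basis`, `remove_subspace`, `v_sp`, and `v_mt` are hypothetical and chosen for illustration.

```python
import numpy as np


def orthonormal_basis(directions):
    """Return orthonormal columns spanning the given (d x k) direction vectors."""
    q, _ = np.linalg.qr(directions)  # reduced QR: q has shape (d, k)
    return q


def remove_subspace(z, basis):
    """Project embeddings z (n x d) onto the orthogonal complement of span(basis)."""
    return z - (z @ basis) @ basis.T


rng = np.random.default_rng(0)
d = 8  # embedding dimension (assumed for the example)

# Placeholder directions: the first two columns stand in for the spurious
# subspace, the last two for the main-task subspace. QR makes all four
# columns mutually orthonormal, so the two subspaces are orthogonal.
q = orthonormal_basis(rng.normal(size=(d, 4)))
v_sp, v_mt = q[:, :2], q[:, 2:]

z = rng.normal(size=(5, d))          # five toy embeddings
z_clean = remove_subspace(z, v_sp)   # spurious component projected out
```

Because the two subspaces are orthogonal in this sketch, projecting out `v_sp` leaves the coordinates of each embedding along `v_mt` unchanged, which is the property that motivates estimating the two subspaces jointly.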
