Removing Spurious Concepts from Neural Network Representations via Joint Subspace Estimation

Floris Holstege, Bram Wouters, Noud van Giersbergen, Cees Diks

TL;DR

The paper tackles spurious correlations in neural network embeddings by proposing Joint Subspace Estimation (JSE), which jointly identifies two orthogonal subspaces within last-layer representations: $\mathcal{Z}_{sp}$ for spurious concepts and $\mathcal{Z}_{mt}$ for main-task concepts. JSE estimates multiple direction vectors through a constrained optimization that enforces orthogonality, and applies statistical tests based on the binary cross-entropy of its predictions to decide when to stop adding vectors. Empirically, JSE outperforms existing concept-removal methods on computer vision (Waterbirds, CelebA) and NLP (MultiNLI) benchmarks: it removes spurious features more effectively while preserving main-task information, and it supports interpretable post-hoc analyses such as Grad-CAM. The approach improves OOD generalization and interpretability, with practical benefits for deploying robust and transparent models.
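To make the idea of joint estimation concrete, here is a heavily simplified sketch, not the authors' implementation. It estimates a single spurious direction and a single main-task direction by logistic regression (binary cross-entropy), enforces orthogonality by projecting the main-task vector off the spurious one after every gradient step, and then removes the spurious direction by orthogonal projection. The projected-gradient scheme, the function names `fit_orthogonal_pair` and `remove_direction`, and the toy data are assumptions made for illustration; JSE itself estimates several vectors per subspace and uses statistical tests to decide when to stop.

```python
import numpy as np

def sigmoid(t):
    return 1.0 / (1.0 + np.exp(-t))

def fit_orthogonal_pair(Z, y_sp, y_mt, lr=0.1, steps=500):
    """Fit one spurious and one main-task logistic direction, kept orthogonal.

    Hypothetical simplification (not JSE itself): one vector per subspace,
    no statistical stopping test, and orthogonality enforced by projecting
    w_mt off w_sp after every gradient step on the binary cross-entropy.
    """
    n, d = Z.shape
    w_sp, w_mt = np.zeros(d), np.zeros(d)
    b_sp = b_mt = 0.0
    for _ in range(steps):
        # Residuals p - y give the gradient of the mean binary cross-entropy.
        r_sp = sigmoid(Z @ w_sp + b_sp) - y_sp
        r_mt = sigmoid(Z @ w_mt + b_mt) - y_mt
        w_sp -= lr * Z.T @ r_sp / n
        b_sp -= lr * r_sp.mean()
        w_mt -= lr * Z.T @ r_mt / n
        b_mt -= lr * r_mt.mean()
        # Projection step: keep w_mt orthogonal to the current w_sp.
        norm_sq = w_sp @ w_sp
        if norm_sq > 0:
            w_mt -= (w_mt @ w_sp) / norm_sq * w_sp
    return w_sp, b_sp, w_mt, b_mt

def remove_direction(Z, w):
    """Project the rows of Z onto the orthogonal complement of w."""
    u = w / np.linalg.norm(w)
    return Z - np.outer(Z @ u, u)

# Hypothetical toy data: the two labels agree 80% of the time; axis 0 carries
# the spurious feature, axis 1 the main-task feature, the other axes are noise.
rng = np.random.default_rng(0)
n, d = 1000, 5
y_mt = rng.integers(0, 2, size=n)
y_sp = np.where(rng.random(n) < 0.8, y_mt, 1 - y_mt)
Z = rng.normal(size=(n, d))
Z[:, 0] += 3.0 * (2 * y_sp - 1)
Z[:, 1] += 3.0 * (2 * y_mt - 1)

w_sp, b_sp, w_mt, b_mt = fit_orthogonal_pair(Z, y_sp, y_mt)
Z_clean = remove_direction(Z, w_sp)
acc_sp = np.mean((Z @ w_sp + b_sp > 0) == y_sp)
acc_mt = np.mean((Z @ w_mt + b_mt > 0) == y_mt)
# Because w_mt is orthogonal to w_sp, removal leaves the main-task scores unchanged.
print(np.allclose(Z_clean @ w_mt, Z @ w_mt), round(acc_sp, 3), round(acc_mt, 3))
```

Because the two directions are orthogonal, removing the spurious one leaves every example's main-task score unchanged; this is the property the orthogonality constraint is meant to secure.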

Abstract

Out-of-distribution generalization in neural networks is often hampered by spurious correlations. A common strategy is to mitigate this by removing spurious concepts from the neural network representation of the data. Existing concept-removal methods tend to be overzealous by inadvertently eliminating features associated with the main task of the model, thereby harming model performance. We propose an iterative algorithm that separates spurious from main-task concepts by jointly identifying two low-dimensional orthogonal subspaces in the neural network representation. We evaluate the algorithm on benchmark datasets for computer vision (Waterbirds, CelebA) and natural language processing (MultiNLI), and show that it outperforms existing concept-removal methods.

Paper Structure (39 sections, 38 equations, 21 figures, 14 tables, 2 algorithms)

Figures (21)

  • Figure 1: High-level overview of Joint Subspace Estimation (JSE) for concept removal: the input $\boldsymbol{x}$ is fed through a neural network $f(\boldsymbol{x})$, from which we can extract the vector representation $\boldsymbol{z}$. Within the vector representation, two orthogonal subspaces are identified: one related to the spurious concept (e.g. the background), and one to the main-task concept (e.g. animal type).
  • Figure 2: Illustration of JSE in comparison to INLP, based on the 20-dimensional Toy dataset ($d = 20$; see the Datasets section) with $\rho = 0.8$ and sample size $n = 2{,}000$. Two-dimensional slices of $\boldsymbol{z}$ are shown. Panels A and D have the spurious feature on the x-axis and the main-task feature on the y-axis; the remaining panels show the axes that best separate the main-task labels. JSE identifies a single spurious vector (panel D), and the remaining class separation is attributed to the main-task concept (panel E). INLP identifies (superpositions of) the main-task and spurious directions as spurious (panels A and B), and the main-task information is removed (panel C).
  • Figure 3: OOD generalization compared to other concept-removal methods: we plot the (worst-group) accuracy on a test set without spurious correlation as a function of the spurious correlation in the training set ($\rho$ for the Toy dataset, $p_\mathrm{train}(y_\mathrm{mt} = y \mid y_\mathrm{sp} = y)$ for the other datasets). Averages are based on 100, 5, 5 and 5 runs, respectively. The shaded area reflects the 95% confidence interval.
  • Figure 4: Ability to reconstruct main-task and spurious concept features after concept removal. We show the mean-squared error (MSE) of predicting $\boldsymbol{z}_\mathrm{mt}$ and $\boldsymbol{z}_\mathrm{sp}$ via OLS on the transformed embeddings. The dotted line in the right plot indicates the MSE when no information about $\boldsymbol{z}_\mathrm{sp}$ is left. Averages are based on 100 runs; the shaded area reflects the 95% confidence interval.
  • Figure 5: Grad-CAM for the last layer of ResNet-50 predicting the main-task label: red (green) patches indicate a contribution towards a prediction $y_\mathrm{mt}=0$ ($y_\mathrm{mt}=1$).
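Figures 3 and 4 rest on two evaluation quantities that are easy to state in code. The sketch below assumes, as is standard for Waterbirds-style benchmarks, that groups are the combinations of main-task and spurious labels, and that the Figure 4 reconstruction is an in-sample OLS fit with an intercept. Under that assumption the dotted "no information left" line would presumably correspond to the intercept-only MSE, i.e. the variance of the target. The function names are hypothetical.

```python
import numpy as np

def worst_group_accuracy(y_mt, y_sp, y_pred):
    """Lowest accuracy over the groups defined by (main-task label, spurious label)."""
    y_mt, y_sp, y_pred = map(np.asarray, (y_mt, y_sp, y_pred))
    accs = []
    for a in np.unique(y_mt):
        for s in np.unique(y_sp):
            mask = (y_mt == a) & (y_sp == s)
            if mask.any():  # skip empty groups
                accs.append(np.mean(y_pred[mask] == a))
    return float(min(accs))

def ols_reconstruction_mse(Z_tilde, target):
    """In-sample MSE of an OLS fit (with intercept) of `target` on transformed embeddings."""
    X = np.column_stack([np.ones(len(Z_tilde)), Z_tilde])
    coef, *_ = np.linalg.lstsq(X, target, rcond=None)
    return float(np.mean((target - X @ coef) ** 2))

# Worst-group vs. average accuracy on a tiny hand-made example.
y_mt   = [0, 0, 0, 1, 1, 1, 1, 0]
y_sp   = [0, 0, 1, 1, 1, 0, 0, 1]
y_pred = [0, 0, 1, 1, 1, 0, 1, 0]
print(np.mean(np.equal(y_pred, y_mt)), worst_group_accuracy(y_mt, y_sp, y_pred))  # 0.75 0.5

# Reconstruction MSE after an idealized removal that zeroes exactly the spurious axis.
rng = np.random.default_rng(0)
Z = rng.normal(size=(500, 4))
z_sp, z_mt = Z[:, 0], Z[:, 1]
Z_tilde = Z.copy()
Z_tilde[:, 0] = 0.0
mse_mt = ols_reconstruction_mse(Z_tilde, z_mt)  # ~0: main-task feature fully recoverable
mse_sp = ols_reconstruction_mse(Z_tilde, z_sp)  # close to np.var(z_sp): nothing left to use
```

The toy example shows why the worst-group metric matters: a classifier that leans on the spurious label can reach 75% average accuracy while getting only half of the minority groups right.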