Spurious Correlation-Aware Embedding Regularization for Worst-Group Robustness

Subeen Park, Joowang Kim, Hakyung Lee, Sunjae Yoo, Kyungwoo Song

TL;DR

SCER addresses spurious correlations that degrade worst-group robustness under subpopulation shifts by directly regularizing the embedding space. It decomposes worst-group error into spurious and core components using group-wise mean embeddings and a Sigma-norm, then optimizes an embedding loss that penalizes spurious alignment while promoting core-aligned representations. Theoretical analysis links worst-group error to the product of alignment with spurious directions and their magnitudes, guiding the embedding regularization; empirically, SCER achieves state-of-the-art worst-group accuracy across vision and language benchmarks and remains effective when environment labels are inferred. The approach offers a practical, single-stage method for robust generalization under distribution shifts with strong cross-domain performance.
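The direction-finding step can be illustrated with a minimal sketch. It rests on several assumptions not confirmed by the text here: binary labels and domains encoded as 0/1, spurious directions taken as same-class, cross-domain differences of group means, and core directions taken as same-domain, cross-class differences. The function names `group_means` and `spurious_and_core_directions` are hypothetical, and the Sigma-norm is left out, so this is not the authors' implementation.

```python
import numpy as np

def group_means(emb, y, d):
    """Mean embedding of every (label, domain) group present in the data."""
    means = {}
    for c in np.unique(y):
        for e in np.unique(d):
            mask = (y == c) & (d == e)
            if mask.any():
                means[(int(c), int(e))] = emb[mask].mean(axis=0)
    return means

def spurious_and_core_directions(means):
    """Assumed binary case: labels in {0, 1}, domains in {0, 1}."""
    # Assumption: spurious direction = same class, different domain.
    spurious = [means[(c, 1)] - means[(c, 0)] for c in (0, 1)]
    # Assumption: core direction = same domain, different class.
    core = [means[(1, e)] - means[(0, e)] for e in (0, 1)]
    return spurious, core

# Toy embeddings: axis 0 encodes the label, axis 1 encodes the domain.
y = np.array([0, 0, 1, 1, 0, 1, 0, 1])
d = np.array([0, 1, 0, 1, 0, 0, 1, 1])
emb = np.stack([2.0 * y, 3.0 * d], axis=1)

means = group_means(emb, y, d)
spurious, core = spurious_and_core_directions(means)
print("spurious:", spurious[0], "core:", core[0])
```

In this toy setup the spurious directions lie entirely along the domain axis and the core directions along the label axis, which is the separation the regularizer is meant to exploit.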

Abstract

Deep learning models achieve strong performance across various domains but often rely on spurious correlations, making them vulnerable to distribution shifts. This issue is particularly severe in subpopulation shift scenarios, where models struggle in underrepresented groups. While existing methods have made progress in mitigating this issue, their performance gains remain limited, and they lack a rigorous theoretical framework connecting embedding space representations with worst-group error. To address this limitation, we propose Spurious Correlation-Aware Embedding Regularization for Worst-Group Robustness (SCER), a novel approach that directly regularizes feature representations to suppress spurious cues. We show theoretically that worst-group error is influenced by how strongly the classifier relies on spurious versus core directions, identified from differences in group-wise mean embeddings across domains and classes. By imposing theoretical constraints at the embedding level, SCER encourages models to focus on core features while reducing sensitivity to spurious patterns. Through systematic evaluation on multiple vision and language benchmarks, we show that SCER outperforms prior state-of-the-art methods in worst-group accuracy. Our code is available at https://github.com/MLAI-Yonsei/SCER.
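As an illustration of the evaluation metric, worst-group accuracy is the minimum accuracy over the (label, domain) subpopulations. The following is a minimal sketch on made-up predictions; the helper name `worst_group_accuracy` and the toy arrays are hypothetical and do not come from the paper's code.

```python
import numpy as np

def worst_group_accuracy(pred, y, d):
    """Return the minimum per-group accuracy and the per-group breakdown."""
    accs = {}
    for c in np.unique(y):
        for e in np.unique(d):
            mask = (y == c) & (d == e)
            if mask.any():
                accs[(int(c), int(e))] = float((pred[mask] == y[mask]).mean())
    return min(accs.values()), accs

# Toy data: group (label=0, domain=1) holds a single, misclassified example.
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])
d = np.array([0, 0, 0, 1, 0, 1, 1, 1])
pred = np.array([0, 0, 0, 1, 1, 1, 1, 1])

average_acc = float((pred == y).mean())
wga, per_group = worst_group_accuracy(pred, y, d)
print(average_acc, wga)
```

Here seven of eight predictions are correct, yet the minority group is entirely wrong, so the average accuracy hides a worst-group accuracy of zero.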

Paper Structure

This paper contains 59 sections, 5 theorems, 53 equations, 4 figures, 11 tables, 1 algorithm.

Key Result

Theorem 1

Consider a classification problem where each data point $x \in \mathcal{X}$ is associated with a label $y \in \mathcal{Y} = \{y_{-1}, y_{+1}\}$ and a domain $d \in \mathcal{D} = \{d_R, d_G\}$. Each pair $(y, d) \in \mathcal{Y} \times \mathcal{D}$ defines a subpopulation. We assume that the data follows a group-

Figures (4)

  • Figure 1: Overview of the Spurious Correlation-Aware Embedding Regularization (SCER) framework. SCER distinguishes spurious and core features via group-wise embedding differences, then regularizes to minimize spurious loss and maximize core loss for domain bias robustness.
  • Figure 2: Overview of the SCER training framework. Input data is encoded into embeddings, from which group-wise mean embeddings are computed to derive spurious and core directions. The framework combines worst-group classification loss with embedding regularization that penalizes spurious alignment and promotes core alignment.
  • Figure 3: Scatter plots of worst-group accuracy against the spurious and core metrics. Spurious Loss is negatively correlated with worst-group accuracy, while Core Loss is positively correlated with it.
  • Figure 4: Comparison of Grad-CAM visualizations on the Waterbirds dataset. SCER directs attention to meaningful features, reducing focus on spurious regions.
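The combination described in the Figure 2 caption (worst-group classification loss plus a regularizer that penalizes spurious alignment and promotes core alignment) can be sketched as follows. Everything specific here is an assumption made for illustration: the alignment term is taken as |cos(w, v)| times the Euclidean norm of v (a stand-in for the paper's Sigma-norm), the worst-group loss is taken as the maximum of per-group losses, and the names `alignment_term`, `regularized_objective`, `lam_spu`, and `lam_core` are hypothetical.

```python
import numpy as np

def alignment_term(w, direction, eps=1e-12):
    """|cos(w, direction)| * ||direction||, which equals |w . direction| / ||w||."""
    return abs(float(w @ direction)) / (float(np.linalg.norm(w)) + eps)

def regularized_objective(group_losses, w, spurious_dirs, core_dirs,
                          lam_spu=0.1, lam_core=0.1):
    """Assumed form: worst-group loss + lam_spu * spurious - lam_core * core."""
    worst = max(group_losses)  # assumption: worst group = largest group loss
    spu = float(np.mean([alignment_term(w, v) for v in spurious_dirs]))
    core = float(np.mean([alignment_term(w, v) for v in core_dirs]))
    return worst + lam_spu * spu - lam_core * core, spu, core

# Toy directions: axis 0 is the core direction, axis 1 the spurious one.
spurious_dirs = [np.array([0.0, 3.0])]
core_dirs = [np.array([2.0, 0.0])]
group_losses = [0.2, 0.9, 0.1, 0.3]

w_core = np.array([1.0, 0.0])      # classifier relying on the core axis
w_spurious = np.array([0.0, 1.0])  # classifier relying on the spurious axis

obj_core, spu_core, core_core = regularized_objective(
    group_losses, w_core, spurious_dirs, core_dirs)
obj_spu, spu_spu, core_spu = regularized_objective(
    group_losses, w_spurious, spurious_dirs, core_dirs)
print(obj_core < obj_spu)
```

Under these assumptions, a classifier aligned with the core direction receives a lower objective than one aligned with the spurious direction at the same classification loss, which is the preference the regularizer is meant to encode.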

Theorems & Definitions (9)

  • Theorem 1: Worst-group Error Decomposition
  • Proposition 1: ERM solution under Cross-Entropy Loss
  • proof
  • Theorem 1: Worst-group Error Decomposition
  • proof
  • Proposition 2: Multiclass/Multidomain ERM under Cross-Entropy Loss
  • proof
  • Theorem 2: Multiclass/Multidomain Worst-Group Error
  • proof