Improving Group Robustness on Spurious Correlation Requires Preciser Group Inference

Yujin Han, Difan Zou

TL;DR

This work addresses the problem that ERM models exploit spurious correlations, which leads to poor worst-group performance when group labels are unavailable. It introduces GIC, a method that learns a spurious-attribute predictor by comparing the training set with a comparison dataset whose group distribution is shifted. GIC optimizes a Correlation Term $I(y^{tr}; \hat{y}^{tr}_{s,\mathbf{w}})$ and a Spurious Term $KL(\mathbb{P}(y^{tr}|\hat{y}^{tr}_{s,\mathbf{w}}) \,\|\, \mathbb{P}(y^{c}|\hat{y}^{c}_{s,\mathbf{w}}))$ weighted by $\gamma$, and then passes the inferred groups $\hat{g}=(y, \hat{y}_{s})$ to downstream invariant-learning methods. GIC supports unlabeled or labeled comparison data, can be combined with Mixup, GroupDRO, Subsample, and Upsample, and improves worst-group accuracy across synthetic and real-world datasets. Its misclassifications show a semantic-consistency pattern that helps decouple spurious attributes from labels. The results show that accurate group inference via data comparison substantially narrows the gap to oracle-group-label performance, offering a practical path to robust models when group information is scarce.
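
To make the two terms concrete, the following is a minimal sketch that evaluates an empirical Correlation Term and Spurious Term on discrete labels. It is not the authors' implementation: in GIC the predictions $\hat{y}_{s,\mathbf{w}}$ come from a trained predictor with parameters $\mathbf{w}$, whereas here hard integer labels stand in for them. The function names (`joint_distribution`, `mutual_information`, `conditional_kl`, `gic_style_score`) are hypothetical, and two modelling choices are assumptions: the KL between conditionals is averaged over the training distribution of $\hat{y}_s$, and the two terms are combined as a sum weighted by $\gamma$.

```python
import numpy as np


def joint_distribution(y, y_s, n_y, n_s):
    """Empirical joint P(y, y_s) as an (n_y, n_s) table of probabilities."""
    y = np.asarray(y, dtype=int)
    y_s = np.asarray(y_s, dtype=int)
    counts = np.zeros((n_y, n_s), dtype=float)
    np.add.at(counts, (y, y_s), 1.0)  # accumulate one count per sample
    return counts / counts.sum()


def mutual_information(joint):
    """Correlation Term: I(y; y_s) in nats, computed from a joint table."""
    p_y = joint.sum(axis=1, keepdims=True)   # shape (n_y, 1)
    p_s = joint.sum(axis=0, keepdims=True)   # shape (1, n_s)
    mask = joint > 0                          # skip zero cells (0 * log 0 = 0)
    ratio = joint[mask] / (p_y @ p_s)[mask]
    return float(np.sum(joint[mask] * np.log(ratio)))


def conditional_kl(joint_tr, joint_c, eps=1e-12):
    """Spurious Term: KL(P(y_tr | y_s) || P(y_c | y_s)).

    Assumption: the per-value KLs are averaged with the training
    marginal P_tr(y_s); eps smooths empty cells.
    """
    p_s_tr = joint_tr.sum(axis=0)
    cond_tr = (joint_tr + eps) / (joint_tr + eps).sum(axis=0, keepdims=True)
    cond_c = (joint_c + eps) / (joint_c + eps).sum(axis=0, keepdims=True)
    kl_per_s = np.sum(cond_tr * np.log(cond_tr / cond_c), axis=0)
    return float(np.sum(p_s_tr * kl_per_s))


def gic_style_score(y_tr, s_tr, y_c, s_c, gamma=1.0, n_y=2, n_s=2):
    """Assumed combination: Correlation Term + gamma * Spurious Term."""
    joint_tr = joint_distribution(y_tr, s_tr, n_y, n_s)
    joint_c = joint_distribution(y_c, s_c, n_y, n_s)
    return mutual_information(joint_tr) + gamma * conditional_kl(joint_tr, joint_c)


if __name__ == "__main__":
    # Training set: attribute perfectly aligned with the label.
    y_tr = [0, 0, 0, 0, 1, 1, 1, 1]
    s_tr = [0, 0, 0, 0, 1, 1, 1, 1]
    # Comparison set: attribute independent of the label.
    y_c = [0, 0, 1, 1, 0, 0, 1, 1]
    s_c = [0, 1, 0, 1, 0, 1, 0, 1]
    print(gic_style_score(y_tr, s_tr, y_c, s_c, gamma=1.0))
```

In this toy setup a candidate attribute scores highly when it tracks the label in the training set but that relationship changes in the comparison set, which mirrors the two properties the method relies on.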

Abstract

Standard empirical risk minimization (ERM) models may prioritize learning spurious correlations between spurious features and true labels, leading to poor accuracy on groups where these correlations do not hold. Mitigating this issue often requires expensive spurious attribute (group) labels or relies on trained ERM models to infer group labels when group information is unavailable. However, the significant performance gap in worst-group accuracy between using pseudo group labels and using oracle group labels inspires us to consider further improving group robustness through preciser group inference. Therefore, we propose GIC, a novel method that accurately infers group labels, resulting in improved worst-group performance. GIC trains a spurious attribute classifier based on two key properties of spurious correlations: (1) high correlation between spurious attributes and true labels, and (2) variability in this correlation between datasets with different group distributions. Empirical studies on multiple datasets demonstrate the effectiveness of GIC in inferring group labels, and combining GIC with various downstream invariant learning methods improves worst-group accuracy, showcasing its powerful flexibility. Additionally, through analyzing the misclassifications in GIC, we identify an interesting phenomenon called semantic consistency, which may contribute to better decoupling the association between spurious attributes and labels, thereby mitigating spurious correlation. The code for GIC is available at https://github.com/yujinhanml/GIC.
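
As an illustration of how inferred groups feed into evaluation and reweighting, the sketch below encodes each pair $(y, \hat{y}_s)$ as a group id, computes worst-group accuracy, and derives group-balanced sampling weights in the spirit of upsampling. The helper names (`group_ids`, `worst_group_accuracy`, `upsample_weights`) are hypothetical and are not taken from the GIC repository; the weighting scheme is one common choice, assumed here for illustration.

```python
import numpy as np


def group_ids(y, y_s_hat, n_s):
    """Encode each pair (y, y_s_hat) as a single integer group id."""
    return np.asarray(y, dtype=int) * n_s + np.asarray(y_s_hat, dtype=int)


def worst_group_accuracy(y_true, y_pred, groups):
    """Return (minimum per-group accuracy, dict of per-group accuracies)."""
    groups = np.asarray(groups)
    correct = np.asarray(y_true) == np.asarray(y_pred)
    accs = {int(g): float(correct[groups == g].mean()) for g in np.unique(groups)}
    return min(accs.values()), accs


def upsample_weights(groups):
    """Sampling weights inversely proportional to group size.

    Every group ends up with the same total probability mass.
    """
    groups = np.asarray(groups)
    ids, counts = np.unique(groups, return_counts=True)
    size = dict(zip(ids.tolist(), counts.tolist()))
    w = np.array([1.0 / size[int(g)] for g in groups])
    return w / w.sum()


if __name__ == "__main__":
    y = [0, 0, 0, 1, 1, 1]
    y_s_hat = [0, 0, 1, 1, 1, 0]   # inferred spurious attribute (toy values)
    y_pred = [0, 0, 1, 1, 1, 1]    # predictions of some downstream model
    g = group_ids(y, y_s_hat, n_s=2)
    worst, per_group = worst_group_accuracy(y, y_pred, g)
    print(g, worst, per_group)
    print(upsample_weights(g))
```

Because the weights depend only on the group ids, errors in the inferred attribute translate directly into misallocated sampling mass, which is why the accuracy of group inference matters for the downstream methods.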

Paper Structure (31 sections, 5 theorems, 28 equations, 13 figures, 8 tables, 1 algorithm)

Key Result

Theorem 3.1

[Lower Bound of Spurious Term without $y^{c}$] Given representations $\mathbf{z}^{tr}$ and $\mathbf{z}^{c}$, the spurious term is lower bounded by the following expression as:

Figures (13)

  • Figure 1: Decision boundary visualization. $f_{\mathrm{ERM}}$ underperforms $f_{\mathrm{GIC}}$ in recognizing spurious attributes and $f_{\mathrm{robust}}$ in identifying invariant attributes. Classes $0$ and $1$ are represented by colors (red and blue), respectively, with shapes marking spurious attributes.
  • Figure 2: Evaluation of group label inference. Compared to baseline methods such as ERM and EIIL, GIC significantly improves the recall for minority group label inference while maintaining a relatively high precision.
  • Figure 3: Visualization of evaluated datasets with minority groups marked by red boxes. The spurious attribute and targets exhibit strong spurious correlations, while these correlations typically do not hold for minority groups.
  • Figure 4: Misclassified samples on CelebA. The semantic consistency in GIC leads to the misclassification of women with short hair (a typical characteristic of males) as males (91.7%), and men with long hair (a typical characteristic of females) as females (8.3%).
  • Figure 5: GIC generates better mixed images. By leveraging the high semantic consistency in image recognition, spurious attributes and true labels are decoupled in mixed images generated by GIC. Various mixing techniques are employed to handle different datasets, as detailed in the section "Downstream invariant learning methods".
  • ...and 8 more figures
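
The precision and recall of minority-group inference mentioned for Figure 2 can be sketched as follows. This is an illustrative computation, not the paper's evaluation code. It assumes binary labels coded so that majority samples satisfy $y = y_s$ and minority samples satisfy $y \neq y_s$; the function name `minority_precision_recall` is hypothetical.

```python
import numpy as np


def minority_precision_recall(y, y_s_true, y_s_hat):
    """Precision and recall of identifying minority-group samples.

    Assumption: a sample is in a minority group when its label differs
    from its spurious attribute (y != y_s).
    """
    y, y_s_true, y_s_hat = (np.asarray(a, dtype=int) for a in (y, y_s_true, y_s_hat))
    true_minority = y != y_s_true   # minority according to oracle attributes
    pred_minority = y != y_s_hat    # minority according to inferred attributes
    tp = int(np.sum(true_minority & pred_minority))
    precision = tp / max(int(pred_minority.sum()), 1)  # guard against 0/0
    recall = tp / max(int(true_minority.sum()), 1)
    return float(precision), float(recall)


if __name__ == "__main__":
    y = [0, 0, 0, 0, 1, 1, 1, 1]
    y_s_true = [0, 0, 0, 1, 1, 1, 1, 0]
    y_s_hat = [0, 0, 1, 1, 1, 1, 1, 1]
    print(minority_precision_recall(y, y_s_true, y_s_hat))
```

High recall means few minority samples are missed, while high precision means few majority samples are wrongly flagged as minority.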

Theorems & Definitions (10)

  • Theorem 3.1
  • Lemma 1.1
  • proof
  • Lemma 1.3
  • proof
  • Lemma 1.4
  • proof
  • Theorem 1.5
  • proof
  • Example 2.1