Improving Group Robustness on Spurious Correlation Requires Preciser Group Inference
Yujin Han, Difan Zou
TL;DR
This work addresses the problem that ERM models exploit spurious correlations, which leads to poor worst-group performance when group labels are unavailable. It introduces GIC, a principled method that learns a spurious-attribute predictor with parameters $\mathbf{w}$ by comparing the training data against a comparison dataset whose group distribution is shifted. GIC optimizes a Correlation Term $I(y^{tr}; \hat{y}^{tr}_{s,\mathbf{w}})$ together with a Spurious Term $\mathrm{KL}(\mathbb{P}(y^{tr}|\hat{y}^{tr}_{s,\mathbf{w}}) \,\|\, \mathbb{P}(y^{c}|\hat{y}^{c}_{s,\mathbf{w}}))$ weighted by $\gamma$, and then passes the inferred groups $\hat{g}=(y, \hat{y}_{s})$ to downstream invariant-learning methods. GIC supports unlabeled or labeled comparison data, can be combined with Mixup, GroupDRO, Subsample, and Upsample, and improves worst-group accuracy on synthetic and real-world datasets. Its misclassifications also show notable semantic-consistency patterns, which help decouple spurious attributes from labels. The results show that accurate group inference via data comparison substantially narrows the gap to oracle-group-label performance, offering a practical, flexible, and scalable route to robust models when group information is scarce.
Abstract
Standard empirical risk minimization (ERM) models may prioritize learning spurious correlations between spurious features and true labels, leading to poor accuracy on groups where these correlations do not hold. Mitigating this issue often requires expensive spurious attribute (group) labels or relies on trained ERM models to infer group labels when group information is unavailable. However, the significant gap in worst-group accuracy between using pseudo group labels and using oracle group labels motivates us to further improve group robustness through more precise group inference. Therefore, we propose GIC, a novel method that accurately infers group labels, resulting in improved worst-group performance. GIC trains a spurious attribute classifier based on two key properties of spurious correlations: (1) high correlation between spurious attributes and true labels, and (2) variability in this correlation between datasets with different group distributions. Empirical studies on multiple datasets demonstrate the effectiveness of GIC in inferring group labels, and combining GIC with various downstream invariant learning methods improves worst-group accuracy, showcasing its flexibility. Additionally, by analyzing the misclassifications in GIC, we identify an interesting phenomenon called semantic consistency, which may contribute to better decoupling the association between spurious attributes and labels, thereby mitigating spurious correlation. The code for GIC is available at https://github.com/yujinhanml/GIC.
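
To make the two properties concrete, the following is a minimal sketch of how a score built from them could be evaluated on discrete predictions. It is not the authors' implementation (see the repository above for that). It assumes hard integer predictions $\hat{y}_s$ rather than a trained, differentiable classifier, reads property (1) as a mutual-information term and property (2) as a KL term between the conditionals $\mathbb{P}(y|\hat{y}_s)$ on the training and comparison sets, assumes both terms are to be maximized, and sums the KL over the values of $\hat{y}_s$. The function names, the smoothing constant `eps`, and the default `gamma=1.0` are all hypothetical choices made for illustration.

```python
import numpy as np


def joint_distribution(y, y_s, n_classes, n_attrs, eps=1e-8):
    """Empirical joint P(y, y_s) from integer label arrays, smoothed by eps."""
    y = np.asarray(y, dtype=int)
    y_s = np.asarray(y_s, dtype=int)
    counts = np.zeros((n_classes, n_attrs), dtype=np.float64)
    np.add.at(counts, (y, y_s), 1.0)  # accumulate co-occurrence counts
    counts += eps                     # avoid log(0) for empty cells
    return counts / counts.sum()


def mutual_information(joint):
    """I(y; y_s) in nats, computed from a joint probability table."""
    p_y = joint.sum(axis=1, keepdims=True)
    p_s = joint.sum(axis=0, keepdims=True)
    return float(np.sum(joint * np.log(joint / (p_y * p_s))))


def conditional_kl(joint_tr, joint_c):
    """Sum over y_s of KL(P_tr(y | y_s) || P_c(y | y_s)).

    Summing over y_s is an assumption of this sketch.
    """
    cond_tr = joint_tr / joint_tr.sum(axis=0, keepdims=True)
    cond_c = joint_c / joint_c.sum(axis=0, keepdims=True)
    return float(np.sum(cond_tr * np.log(cond_tr / cond_c)))


def gic_style_score(y_tr, ys_tr, y_c, ys_c, gamma=1.0, n_classes=2, n_attrs=2):
    """Correlation term plus gamma times the spurious (KL) term."""
    joint_tr = joint_distribution(y_tr, ys_tr, n_classes, n_attrs)
    joint_c = joint_distribution(y_c, ys_c, n_classes, n_attrs)
    return mutual_information(joint_tr) + gamma * conditional_kl(joint_tr, joint_c)


# Toy data: the predicted attribute agrees with the label on 75% of the
# training set but is independent of the label on the comparison set.
y_tr = [0, 0, 0, 0, 1, 1, 1, 1]
ys_tr = [0, 0, 0, 1, 1, 1, 1, 0]
y_c = [0, 0, 1, 1]
ys_c = [0, 1, 0, 1]
score = gic_style_score(y_tr, ys_tr, y_c, ys_c, gamma=1.0)
```

On the toy data, both terms are positive: the predicted attribute is informative about the label on the training set, and that relationship changes on the comparison set. A predictor whose relationship to the label is identical across the two sets would contribute nothing through the KL term.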
