Elastic Representation: Mitigating Spurious Correlations for Group Robustness

Tao Wen, Zihan Wang, Quan Zhang, Qi Lei

TL;DR

Elastic Representation (ElRep) targets spurious correlations that undermine group robustness when training data are limited or imbalanced. It regularizes the last-layer representation with a nuclear-norm penalty and a Frobenius-norm penalty. The method yields a leaner, more invariant feature set while still retaining correlated invariant features, and it can be plugged into existing training regimes without requiring group labels. Theoretical analysis shows that ElRep has minimal negative impact on in-distribution predictions, and empirical results on synthetic and real datasets show consistent improvements in worst-group accuracy with modest or no loss in average accuracy. Overall, ElRep is a simple, broadly compatible way to improve domain generalization and group robustness by shaping the learned representation rather than manipulating the classifier alone.

Abstract

Deep learning models can suffer severe performance degradation when they rely on spurious correlations between input features and labels: the models perform well on training data but predict poorly for minority groups. This problem arises especially when training data are limited or imbalanced. While most prior work focuses on learning invariant features (those with consistent correlations to y), it overlooks the potential harm of spurious correlations between features. We propose Elastic Representation (ElRep), which learns features by imposing nuclear- and Frobenius-norm penalties on the representation from the last layer of a neural network. Similar to the elastic net, ElRep enjoys the benefits of learning important features without losing feature diversity. The proposed method is simple yet effective. It can be integrated into many deep learning approaches to mitigate spurious correlations and improve group robustness. Moreover, we show theoretically that ElRep has minimal negative impact on in-distribution predictions. This is a notable advantage over approaches that prioritize minority groups at the cost of overall performance.
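
The penalty described in the abstract can be illustrated with a minimal NumPy sketch. It is not the authors' implementation. It assumes the representation is a (batch, dim) matrix of last-layer features, that the Frobenius term is squared (by analogy with the elastic net's squared L2 term), and that the two weights are called `lam_nuc` and `lam_fro`; all of these names and choices are hypothetical. The sketch also shows the elastic-net connection: the nuclear norm is the L1 norm of the singular values, and the squared Frobenius norm is their squared L2 norm.

```python
import numpy as np


def elrep_penalty(features, lam_nuc, lam_fro):
    """Nuclear-norm plus squared-Frobenius-norm penalty on a feature matrix.

    features: (batch, dim) array of last-layer representations (assumed layout).
    lam_nuc, lam_fro: hypothetical names for the two penalty weights.
    """
    singular_values = np.linalg.svd(features, compute_uv=False)
    nuclear = singular_values.sum()                # L1 norm of the singular values
    frobenius_sq = (singular_values ** 2).sum()    # squared L2 norm of the singular values
    return lam_nuc * nuclear + lam_fro * frobenius_sq


# Toy feature matrix standing in for a mini-batch of representations.
rng = np.random.default_rng(0)
features = rng.normal(size=(8, 4))

penalty = elrep_penalty(features, lam_nuc=1e-2, lam_fro=1e-3)

# Cross-check against NumPy's built-in matrix norms.
expected = (1e-2 * np.linalg.norm(features, "nuc")
            + 1e-3 * np.linalg.norm(features, "fro") ** 2)
```

In a training loop, a term of this kind would be added to the task loss of whichever base method is being used; how the weights are chosen is not specified here.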

Paper Structure

This paper contains 31 sections, 4 theorems, 29 equations, 5 figures, 5 tables.

Key Result

Theorem 5.1

Under Assumptions assum:data and assum:subgaussian, fix a failure probability $\delta$ and choose suitable $\lambda_1, \lambda_2, \lambda_3$. Then, with probability at least $1-\delta$ over the training samples, the prediction difference between our approach and the ground truth satisfies: where $R=\|\theta^*\|_1$ and logarithmic factors are omitted.

Figures (5)

  • Figure 1: A long-tailed jaeger, a waterbird on a land background, from the Waterbirds dataset (Sagawa et al., 2019). The heat maps depict the pixel contributions to bird-type prediction using Grad-CAM (Selvaraju et al., 2019). From left to right: the original image, ERM, ERM with the nuclear norm, and ERM with the nuclear and Frobenius norms. ERM learns features that include background areas. ERM with the nuclear norm focuses on the head, and ERM with both norms evenly emphasizes the head and the wing.
  • Figure 2: Connections between ElRep and elastic net.
  • Figure 3: Left: The difference in worst-group accuracy between the baseline methods with and without ElRep. The improvement appears for every method compared on all three datasets. Right: The difference in average accuracy between the baseline methods with and without ElRep. Usually, an increase in worst-group accuracy comes with a decrease in average accuracy. Our approach can also improve the average accuracy for some baselines on the image datasets.
  • Figure 4: Accuracy per group and average accuracy against the log of $\lambda_1$ (left) and $\lambda_2$ (right). As their value increases, the accuracy of the two minority groups will gradually increase and eventually surpass the average accuracy. The trend is reversed for the two majority groups.
  • Figure 5: The two majority groups downsampled to about 1%. Reversed trends are observed.
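
The figures above report per-group, worst-group, and average accuracy. As a rough illustration of how these quantities relate, the sketch below computes them on made-up toy data. The function names and group names (`maj_a`, `min_a`, and so on) are hypothetical, and the numbers are not results from the paper.

```python
from collections import defaultdict


def group_accuracies(labels, predictions, groups):
    """Return a dict mapping each group to its accuracy."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for y, y_hat, g in zip(labels, predictions, groups):
        total[g] += 1
        correct[g] += int(y == y_hat)
    return {g: correct[g] / total[g] for g in total}


def summarize(labels, predictions, groups):
    """Return (per-group accuracy, average accuracy, worst-group accuracy)."""
    per_group = group_accuracies(labels, predictions, groups)
    average = sum(int(y == p) for y, p in zip(labels, predictions)) / len(labels)
    worst = min(per_group.values())
    return per_group, average, worst


# Toy data: two majority groups with four samples each, two minority groups with one each.
groups = ["maj_a"] * 4 + ["maj_b"] * 4 + ["min_a", "min_b"]
labels = [1, 1, 1, 1, 0, 0, 0, 0, 1, 0]
predictions = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]

per_group, average, worst = summarize(labels, predictions, groups)
```

On this toy data the average accuracy is high because the majority groups dominate the sample, while the worst-group accuracy is set by the single misclassified minority example. This is why the two metrics can move in opposite directions.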

Theorems & Definitions (7)

  • Theorem 5.1
  • Proposition 5.2
  • Lemma B.1
  • proof
  • Lemma B.2
  • proof
  • Proof of Theorem 5.1