Bias Amplification Enhances Minority Group Performance

Gaotang Li, Jiarui Liu, Wei Hu

TL;DR

This work tackles spurious correlations that cause minority subgroups to be poorly classified. It proposes BAM, a two-stage training algorithm that first amplifies the model's bias by attaching a learnable auxiliary variable to each training example, then upweights the examples the bias-amplified model misclassifies and continues training on the rebalanced dataset, all without requiring group annotations at training time. A key contribution is the ClassDiff stopping criterion, which enables annotation-free tuning based on the observation that a smaller difference between class accuracies on the validation set correlates with higher worst-group accuracy. BAM achieves competitive or superior worst-group accuracy on Waterbirds, CelebA, MultiNLI, and CivilComments-WILDS, with results that are robust across hyperparameters and datasets, and continuing to train the same model in Stage 2 (One-M) generally outperforms training a second model. The approach offers a practical path to improving fairness in real-world settings where group labels are scarce, and the auxiliary variables also give insight into the bias-amplification mechanism.
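The bias-amplification stage can be illustrated with a minimal sketch. It assumes a binary problem with labels in {-1, +1}, a linear scorer `w`, and a scalar auxiliary variable `b[i]` per training example that is added to the logit as `dot(w, x_i) + lam * b[i]`. The scaling factor `lam`, the function names, and the toy data are illustrative assumptions, not the authors' code; Figure 2 plots each auxiliary variable as a point with coordinates, so the paper's version is presumably vector-valued, and this scalar form is a simplification. Because each `b[i]` can absorb the loss of its own example, examples that the spurious feature gets wrong accumulate larger auxiliary variables.

```python
import math


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def train_stage1(xs, ys, lam=1.0, lr=0.1, epochs=100):
    """Bias-amplification sketch (hypothetical binary setup).

    A linear scorer w is trained jointly with one auxiliary scalar b[i]
    per training example. Labels ys are in {-1, +1}. The assumed amplified
    logit for example i is dot(w, x_i) + lam * b[i]; the loss is the mean
    logistic loss log(1 + exp(-y_i * z_i)), minimized by full-batch
    gradient descent.
    """
    n, dim = len(xs), len(xs[0])
    w = [0.0] * dim
    b = [0.0] * n  # one learnable auxiliary variable per example
    for _ in range(epochs):
        grad_w = [0.0] * dim
        grad_b = [0.0] * n
        for i, (x, y) in enumerate(zip(xs, ys)):
            z = sum(wj * xj for wj, xj in zip(w, x)) + lam * b[i]
            # d/dz log(1 + exp(-y z)) = -y * sigmoid(-y z), averaged over n
            g = -y * sigmoid(-y * z) / n
            for j in range(dim):
                grad_w[j] += g * x[j]
            grad_b[i] = lam * g
        w = [wj - lr * gj for wj, gj in zip(w, grad_w)]
        b = [bi - lr * gi for bi, gi in zip(b, grad_b)]
    return w, b


# Toy data: the single feature is a spurious attribute that agrees with
# the label for 6 "majority" examples and disagrees for 2 "minority" ones
# (indices 6 and 7).
xs = [[1.0]] * 3 + [[-1.0]] * 3 + [[1.0], [-1.0]]
ys = [1] * 3 + [-1] * 3 + [-1, 1]
w, b = train_stage1(xs, ys)
print("w =", w)
print("|b| majority:", abs(b[0]), " |b| minority:", abs(b[6]))
```

In this toy setup every `b[i]` moves in the direction of its own label, and the two minority examples end up with larger auxiliary variables than the majority ones, while `w` keeps a positive weight on the spurious feature.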

Abstract

Neural networks produced by standard training are known to suffer from poor accuracy on rare subgroups despite achieving high accuracy on average, due to the correlations between certain spurious features and labels. Previous approaches based on worst-group loss minimization (e.g., Group-DRO) are effective in improving worst-group accuracy but require expensive group annotations for all the training samples. In this paper, we focus on the more challenging and realistic setting where group annotations are only available on a small validation set or are not available at all. We propose BAM, a novel two-stage training algorithm: in the first stage, the model is trained with a bias amplification scheme that introduces a learnable auxiliary variable for each training sample; in the second stage, we upweight the samples that the bias-amplified model misclassifies, and then continue training the same model on the reweighted dataset. Empirically, BAM achieves competitive performance compared with existing methods on spurious correlation benchmarks in computer vision and natural language processing. Moreover, we find a simple stopping criterion based on minimum class accuracy difference that can remove the need for group annotations, with little or no loss in worst-group accuracy. We perform extensive analyses and ablations to verify the effectiveness and robustness of our algorithm under varying class and group imbalance ratios.
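The second stage can be sketched in the same style. The sketch assumes a binary linear scorer, takes the error set to be the training examples that the bias-amplified scorer misclassifies when evaluated without its auxiliary variables (an assumption), gives those examples a hypothetical upweighting factor `mu`, and then continues gradient descent on the same weights instead of training a second model. The starting weights `w_stage1` are hand-picked to mimic a bias-amplified model that uses only the spurious feature; they are illustrative, not taken from the paper.

```python
import math


def sigmoid(t):
    """Numerically stable logistic function."""
    if t >= 0:
        return 1.0 / (1.0 + math.exp(-t))
    e = math.exp(t)
    return e / (1.0 + e)


def predict(w, x):
    """Sign of the linear score, computed without any auxiliary variable."""
    return 1 if sum(wj * xj for wj, xj in zip(w, x)) > 0 else -1


def error_set(w, xs, ys):
    """Indices of training examples the bias-amplified scorer gets wrong."""
    return {i for i, (x, y) in enumerate(zip(xs, ys)) if predict(w, x) != y}


def train_stage2(w, xs, ys, mu=3.0, lr=0.1, epochs=50):
    """Continue training the same scorer w on an upweighted dataset.

    Examples in the error set get weight mu, all others weight 1; the loss
    is the weighted mean logistic loss (hypothetical binary setup).
    """
    errors = error_set(w, xs, ys)
    c = [mu if i in errors else 1.0 for i in range(len(xs))]
    total = sum(c)
    w = list(w)  # copy so the caller's Stage 1 weights are not modified
    for _ in range(epochs):
        grad = [0.0] * len(w)
        for ci, x, y in zip(c, xs, ys):
            z = sum(wj * xj for wj, xj in zip(w, x))
            # Weighted gradient of log(1 + exp(-y z)) with respect to z
            g = -ci * y * sigmoid(-y * z) / total
            for j in range(len(w)):
                grad[j] += g * x[j]
        w = [wj - lr * gj for wj, gj in zip(w, grad)]
    return w, errors


# Toy data: features are [core, spurious]; the core feature always equals
# 0.5 * y, while the spurious one agrees with y only for the majority.
ys = [1] * 3 + [-1] * 3 + [-1, 1]
spurious = [1.0] * 3 + [-1.0] * 3 + [1.0, -1.0]
xs = [[0.5 * y, s] for y, s in zip(ys, spurious)]
w_stage1 = [0.0, 1.0]  # hypothetical bias-amplified scorer: spurious feature only
w_stage2, errors = train_stage2(w_stage1, xs, ys)
print("error set:", sorted(errors))
print("w after Stage 2:", w_stage2)
```

Here the error set is exactly the two minority examples, so `mu = 3` balances the 6 majority examples against 6 units of minority weight; gradient descent then raises the weight on the core feature and shrinks the weight on the spurious one.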

Paper Structure

This paper contains 34 sections, 6 equations, 10 figures, 7 tables, and 1 algorithm.

Figures (10)

  • Figure 1: Grad-CAM (Selvaraju et al., 2017) visualization of the effect of the bias amplification and rebalanced training stages. After bias amplification, the classifier relies heavily on background information to make predictions; after the rebalanced training stage, it focuses on the useful feature (the bird) itself.
  • Figure 2: Distributions of the auxiliary variable for the waterbird class (left) and the landbird class (right) on the training set at $T=20$. Two distinct colors show the distributions of the two groups in each class. The position of data sample $i$ relative to the origin shows the bias learned by its auxiliary variable.
  • Figure 3: Epochs 1 and 100 in Stage 1 on Waterbirds. Logit 0 corresponds to the prediction on the waterbird class, and logit 1 corresponds to landbird. The group sizes are 1800, 200, 200, and 1800 in order.
  • Figure 4: Epochs 1 and 4 in Stage 1 on Colored-MNIST. Logit 0 corresponds to the prediction on class 0, and logit 1 corresponds to class 1.
  • Figure 5: Relation between the absolute validation class difference and worst-group validation accuracy in Stage 2 on Waterbirds, CelebA, CivilComments-WILDS, and MultiNLI. Minimizing the absolute validation class difference is roughly equivalent to maximizing worst-group accuracy on the validation set. Each dataset uses the same hyperparameters as in the main results table. Each line shows the mean over 3 seeds, and the shaded region shows the standard deviation.
  • ...and 5 more figures
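The ClassDiff criterion needs only class labels on the validation set, which is what makes it annotation-free. The sketch below assumes ClassDiff is the mean absolute difference between per-class validation accuracies over all class pairs; for two classes this is simply |acc_0 - acc_1|, the absolute class difference plotted in Figure 5. The function names and the per-epoch predictions are hypothetical.

```python
from itertools import combinations


def class_accuracies(preds, labels):
    """Per-class accuracy on the validation set, keyed by class label."""
    correct, total = {}, {}
    for p, y in zip(preds, labels):
        total[y] = total.get(y, 0) + 1
        correct[y] = correct.get(y, 0) + (p == y)
    return {c: correct[c] / total[c] for c in total}


def class_diff(preds, labels):
    """Mean absolute pairwise difference between per-class accuracies."""
    accs = list(class_accuracies(preds, labels).values())
    pairs = list(combinations(accs, 2))
    if not pairs:
        return 0.0
    return sum(abs(a - b) for a, b in pairs) / len(pairs)


def select_checkpoint(val_preds_per_epoch, labels):
    """Index of the epoch whose validation predictions minimize ClassDiff."""
    scores = [class_diff(p, labels) for p in val_preds_per_epoch]
    return min(range(len(scores)), key=scores.__getitem__)


labels = [0, 0, 0, 0, 1, 1, 1, 1]
val_preds = [
    [0, 0, 0, 0, 0, 0, 0, 1],  # epoch 0: class accuracies 1.00 / 0.25
    [0, 0, 0, 1, 1, 1, 1, 0],  # epoch 1: class accuracies 0.75 / 0.75
    [0, 0, 0, 0, 1, 1, 0, 0],  # epoch 2: class accuracies 1.00 / 0.50
]
best = select_checkpoint(val_preds, labels)
print("selected epoch:", best)
```

Epochs 1 and 2 have the same overall validation accuracy (6 of 8), but only epoch 1 has equal class accuracies, so ClassDiff selects it.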

Theorems & Definitions (2)

  • Claim 4.1: proved in the appendix
  • Proof: proof of Claim 4.1