Bias Amplification Enhances Minority Group Performance
Gaotang Li, Jiarui Liu, Wei Hu
TL;DR
This work tackles spurious correlations that cause minority subgroups to be poorly classified. It proposes BAM, a two-stage training algorithm that first amplifies the model's bias by training with a learnable auxiliary variable for each example, then upweights the examples this bias-amplified model misclassifies and continues training on the rebalanced data, all without requiring training-time group annotations. A key contribution is the ClassDiff stopping criterion, which enables annotation-free tuning by exploiting the observation that a smaller difference between per-class validation accuracies tends to coincide with higher worst-group accuracy. BAM achieves competitive or superior worst-group accuracy on Waterbirds, CelebA, MultiNLI, and CivilComments-WILDS, with results that are robust across hyperparameters and datasets, and shows that continuing to train the same model in Stage 2 (One-M) generally outperforms training a second model. The approach offers a practical path to improving fairness in real-world settings where group labels are scarce, and the auxiliary variables also provide insight into the bias-amplification mechanism.
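The ClassDiff criterion can be sketched in a few lines of Python. This is a minimal sketch under stated assumptions, not the authors' code: it assumes ClassDiff is the mean absolute difference between per-class validation accuracies over all pairs of classes (for two classes, simply |acc_0 - acc_1|), and that the checkpoint with the smallest value is selected. The function names `per_class_accuracy`, `class_diff`, and `select_checkpoint` are hypothetical. Only class labels are used, never group labels.

```python
from itertools import combinations


def per_class_accuracy(preds, labels):
    """Accuracy restricted to the examples of each true class."""
    correct, total = {}, {}
    for p, y in zip(preds, labels):
        total[y] = total.get(y, 0) + 1
        correct[y] = correct.get(y, 0) + int(p == y)
    return {c: correct[c] / total[c] for c in total}


def class_diff(preds, labels):
    """Assumed form: mean |acc_i - acc_j| over all pairs of classes i < j."""
    acc = per_class_accuracy(preds, labels)
    pairs = list(combinations(sorted(acc), 2))
    if not pairs:
        return 0.0  # a single class has no accuracy gap
    return sum(abs(acc[a] - acc[b]) for a, b in pairs) / len(pairs)


def select_checkpoint(val_preds_per_epoch, val_labels):
    """Index of the epoch whose validation predictions minimize ClassDiff."""
    scores = [class_diff(p, val_labels) for p in val_preds_per_epoch]
    return min(range(len(scores)), key=lambda k: scores[k])


if __name__ == "__main__":
    labels = [0, 0, 0, 0, 1, 1, 1, 1]
    epochs = [
        [0, 0, 0, 0, 0, 0, 0, 1],  # ClassDiff 0.75, average accuracy 5/8
        [0, 0, 0, 1, 1, 1, 1, 0],  # ClassDiff 0.0, average accuracy 6/8
        [0, 0, 0, 0, 1, 1, 1, 0],  # ClassDiff 0.25, average accuracy 7/8
    ]
    print(select_checkpoint(epochs, labels))  # 1
```

In this toy run the criterion picks epoch 1 even though epoch 2 has higher average accuracy, because epoch 1 equalizes the two class accuracies.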
Abstract
Neural networks produced by standard training are known to suffer from poor accuracy on rare subgroups despite achieving high accuracy on average, due to correlations between certain spurious features and labels. Previous approaches based on worst-group loss minimization (e.g., Group-DRO) are effective in improving worst-group accuracy but require expensive group annotations for all training samples. In this paper, we focus on the more challenging and realistic setting where group annotations are available only on a small validation set or not at all. We propose BAM, a novel two-stage training algorithm: in the first stage, the model is trained with a bias amplification scheme by introducing a learnable auxiliary variable for each training sample; in the second stage, we upweight the samples that the bias-amplified model misclassifies and continue training the same model on the reweighted dataset. Empirically, BAM achieves performance competitive with existing methods on spurious-correlation benchmarks in computer vision and natural language processing. Moreover, we find that a simple stopping criterion based on minimum class accuracy difference can remove the need for group annotations, with little or no loss in worst-group accuracy. We perform extensive analyses and ablations to verify the effectiveness and robustness of our algorithm under varying class and group imbalance ratios.
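The two-stage procedure can be illustrated with a toy binary logistic-regression model in pure Python. This is a hedged sketch, not the authors' implementation, and several details are assumptions: each training example gets one scalar auxiliary variable that is added to the model's logit, scaled by a coefficient `lam`; the error set is taken from the model's own predictions without the auxiliary variables; and misclassified examples are upweighted by a factor `mu`. The names `bam_train`, `predict`, `logit`, `lam`, and `mu`, and all default values, are hypothetical. Stage 2 continues training the same weights (the One-M variant) rather than a fresh model.

```python
import math


def sigmoid(z):
    """Numerically stable logistic function."""
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def logit(w, b, x):
    """Linear model score f(x) = w . x + b."""
    return sum(wj * xj for wj, xj in zip(w, x)) + b


def predict(w, b, X):
    return [int(logit(w, b, x) > 0) for x in X]


def bam_train(X, y, lam=1.0, mu=5.0, stage1_epochs=20, stage2_epochs=200, lr=0.5):
    """Toy two-stage sketch; lam, mu and the defaults are hypothetical choices."""
    n, d = len(X), len(X[0])
    w, b = [0.0] * d, 0.0
    aux = [0.0] * n  # assumed: one learnable scalar auxiliary logit per example

    # Stage 1 (bias amplification): gradient descent on the mean BCE of
    # f(x_i) + lam * aux_i, updating both the model and the aux variables.
    for _ in range(stage1_epochs):
        gw, gb = [0.0] * d, 0.0
        for i in range(n):
            r = sigmoid(logit(w, b, X[i]) + lam * aux[i]) - y[i]  # dLoss/dlogit
            for j in range(d):
                gw[j] += r * X[i][j] / n
            gb += r / n
            aux[i] -= lr * lam * r / n  # gradient step on this example's aux
        w = [wj - lr * g for wj, g in zip(w, gw)]
        b -= lr * gb

    # Error set: examples the bias-amplified model (without aux) misclassifies.
    errors = [i for i, p in enumerate(predict(w, b, X)) if p != y[i]]
    error_set = set(errors)
    weights = [mu if i in error_set else 1.0 for i in range(n)]
    total = sum(weights)

    # Stage 2: continue training the SAME model (no aux) on the reweighted loss.
    for _ in range(stage2_epochs):
        gw, gb = [0.0] * d, 0.0
        for i in range(n):
            r = weights[i] * (sigmoid(logit(w, b, X[i])) - y[i]) / total
            for j in range(d):
                gw[j] += r * X[i][j]
            gb += r
        w = [wj - lr * g for wj, g in zip(w, gw)]
        b -= lr * gb
    return w, b, errors


if __name__ == "__main__":
    X = [[-2.0], [-1.0], [1.0], [2.0]]
    y = [0, 0, 1, 1]
    w, b, errors = bam_train(X, y)
    print(predict(w, b, X), errors)
```

On real benchmarks the model would be a deep network and the number of Stage 1 epochs, `lam`, and `mu` would need tuning; the class-accuracy-difference stopping criterion from the abstract could then be used to pick a Stage 2 checkpoint without group labels.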
