The Group Robustness is in the Details: Revisiting Finetuning under Spurious Correlations

Tyler LaBonte, John C. Hill, Xinchen Zhang, Vidya Muthukumar, Abhishek Kumar

TL;DR

This paper identifies surprising and nuanced behavior of finetuned models on worst-group accuracy via comprehensive experiments on four well-established benchmarks across vision and language tasks, and proposes a mixture method that can outperform both subsetting and upsampling/upweighting for class balancing.

Abstract

Modern machine learning models are prone to over-reliance on spurious correlations, which can often lead to poor performance on minority groups. In this paper, we identify surprising and nuanced behavior of finetuned models on worst-group accuracy via comprehensive experiments on four well-established benchmarks across vision and language tasks. We first show that the commonly used class-balancing techniques of mini-batch upsampling and loss upweighting can induce a decrease in worst-group accuracy (WGA) with training epochs, leading to performance no better than without class-balancing. While removing data to create a class-balanced subset is more effective in some scenarios, we show that this depends on group structure and propose a mixture method that can outperform both techniques. Next, we show that scaling pretrained models is generally beneficial for worst-group accuracy, but only in conjunction with appropriate class-balancing. Finally, we identify spectral imbalance in finetuning features as a potential source of group disparities: minority group covariance matrices incur a larger spectral norm than those of majority groups once conditioned on the classes. Our results show more nuanced interactions of modern finetuned models with group robustness than was previously known. Our code is available at https://github.com/tmlabonte/revisiting-finetuning.
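
The class-balancing schemes discussed here (subsetting, upsampling, upweighting) and the mixture method differ only in how class counts are turned into a training subset, per-example sampling probabilities, or per-example loss weights. The NumPy sketch below illustrates one way to compute each, assuming integer class labels. All function names are hypothetical, and the mixture step assumes the "class-imbalanced subset" is formed by capping every class at a hypothetical `max_ratio` times the smallest class size; this is one reading of the description, not the authors' implementation in the linked repository.

```python
import numpy as np


def subset_indices(labels, rng):
    """Subsetting: keep a uniformly random subset of each class, sized to the smallest class."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    n_min = counts.min()
    keep = [rng.choice(np.flatnonzero(labels == c), size=n_min, replace=False) for c in classes]
    return np.sort(np.concatenate(keep))


def upsampling_probs(labels):
    """Upsampling: per-example probability proportional to 1 / (size of its class),
    so every class gets equal total mass and mini-batches are balanced in expectation."""
    labels = np.asarray(labels)
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    probs = 1.0 / counts[inverse]
    return probs / probs.sum()


def upweighting_weights(labels):
    """Upweighting: per-example loss weight = (largest class size) / (size of its class),
    i.e. the smaller classes are scaled by their class-imbalance ratio."""
    labels = np.asarray(labels)
    _, inverse, counts = np.unique(labels, return_inverse=True, return_counts=True)
    return counts.max() / counts[inverse]


def mixture_indices_and_probs(labels, max_ratio, rng):
    """Mixture (ASSUMED reading): draw a class-imbalanced subset by capping each class at
    max_ratio * (smallest class size), uniformly at random, then upsample within it.
    max_ratio is a hypothetical knob: 1 recovers subsetting; a large value recovers upsampling."""
    labels = np.asarray(labels)
    classes, counts = np.unique(labels, return_counts=True)
    cap = int(max_ratio * counts.min())
    keep = [
        rng.choice(np.flatnonzero(labels == c), size=min(n, cap), replace=False)
        for c, n in zip(classes, counts)
    ]
    keep = np.sort(np.concatenate(keep))
    return keep, upsampling_probs(labels[keep])


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    labels = np.array([0] * 90 + [1] * 10)  # toy 9:1 class imbalance
    print(len(subset_indices(labels, rng)))        # 20 (10 per class)
    print(np.unique(upweighting_weights(labels)))  # [1. 9.]
    keep, probs = mixture_indices_and_probs(labels, max_ratio=3, rng=rng)
    print(len(keep))                               # 40 (30 from class 0, all 10 from class 1)
    batch = rng.choice(keep, size=32, p=probs)     # one mini-batch, class-balanced in expectation
```

In PyTorch, per-example probabilities like these are commonly passed to `torch.utils.data.WeightedRandomSampler`, and per-example loss weights can be applied by computing the loss with `reduction="none"` and multiplying before averaging.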

Paper Structure

This paper contains 31 sections, 1 equation, 18 figures, 5 tables.

Figures (18)

  • Figure 1: Class-balanced upsampling and upweighting experience catastrophic collapse. We compare three class-balancing methods: subsetting, wherein data is removed so that every class has the same size as the smallest class; upsampling, wherein the sampling probabilities of each class are adjusted so that mini-batches are class-balanced in expectation; and upweighting, wherein the loss for the smaller classes is scaled by the class-imbalance ratio. We observe a catastrophic collapse in WGA over the course of training for upsampling and upweighting on CelebA and CivilComments, the more class-imbalanced datasets. Subsetting reduces WGA on Waterbirds because it removes data from the small minority group within the majority class. MultiNLI is class-balanced a priori, so we do not include it here.
  • Figure 2: Mixture balancing mitigates catastrophic collapse of upsampling and upweighting. We propose a class-balanced mixture method that combines subsetting and upsampling: it first draws a class-imbalanced subset uniformly at random from the dataset, then adjusts sampling probabilities so that mini-batches are class-balanced in expectation. Our method increases exposure to majority-class data without over-sampling the minority class. Remarkably, mixture balancing outperforms all three class-balancing methods on Waterbirds and CivilComments, and while it does not outperform subsetting on CelebA, it significantly alleviates the WGA collapse experienced by upsampling and upweighting. MultiNLI is class-balanced a priori, so we do not include it here.
  • Figure 3: Scaling class-balanced pretrained models can improve worst-group accuracy. We finetune each model size starting from pretrained checkpoints and plot the test worst-group accuracy (WGA) as well as the interpolation threshold, where the model reaches 100% training accuracy. We find that model scaling is generally beneficial for WGA only in conjunction with appropriate class-balancing, and that scaling on imbalanced datasets or with the wrong balancing method can harm robustness. Note that MultiNLI is class-balanced a priori and is not interpolated. See the appendix for training accuracy plots.
  • Figure 4: Class-balancing greatly affects the ResNet scaling results of Pham et al. (2021). We contrast the ResNet scaling behavior reported by Pham et al. (2021), who do not use class-balancing, with the scaling of class-balanced ResNets. We finetune each model size starting from pretrained checkpoints and plot the test worst-group accuracy (WGA), as well as the interpolation threshold, where the model reaches 100% training accuracy. On Waterbirds, we find that class-balancing enables a much more beneficial trend during model scaling. On CelebA, class-balancing greatly increases baseline WGA but does not affect scaling behavior (in contrast to the ConvNeXt-V2 plots in Figure 3). We use SGD for last-layer training and AdamW for full finetuning. See the appendix for training accuracy plots.
  • Figure 5: Group disparities are visible in the top eigenvalues of the group covariance matrices. We visualize the mean, across 3 experimental trials, of the top 10 eigenvalues of the group covariance matrices for a ConvNeXt-V2 Nano finetuned on Waterbirds and CelebA and a BERT Small finetuned on CivilComments and MultiNLI. Standard deviations are omitted for clarity. Each model is finetuned using the best class-balancing method for its dataset, as determined in the class-balancing section. Group numbers are detailed in the dataset table, and the minority groups within each class are marked with an asterisk. In each case, the largest λ₁ belongs to a minority group, though not necessarily the worst group, and within the same class, minority-group eigenvalues are overall larger than majority-group eigenvalues.
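
Figure 5 and the abstract's spectral-imbalance claim rest on per-group quantities: worst-group accuracy and the eigenvalues of each group's feature covariance matrix. The sketch below shows one way to compute both with NumPy, assuming `features` holds finetuned embeddings (for example, penultimate-layer outputs) with one row per example, plus integer class and group labels. The function names are hypothetical, and this is not the authors' evaluation code. Because a covariance matrix is symmetric positive semidefinite, its spectral norm equals its largest eigenvalue λ₁, so conditioning on the class amounts to comparing λ₁ only among groups that share a label.

```python
import numpy as np


def worst_group_accuracy(preds, labels, groups):
    """WGA: the minimum, over groups, of per-group accuracy."""
    preds, labels, groups = map(np.asarray, (preds, labels, groups))
    return min(float(np.mean(preds[groups == g] == labels[groups == g])) for g in np.unique(groups))


def group_covariance_spectra(features, groups, k=10):
    """Top-k eigenvalues (descending) of each group's feature covariance matrix.
    Each group needs at least 2 examples for the sample covariance to be defined."""
    features, groups = np.asarray(features, dtype=float), np.asarray(groups)
    spectra = {}
    for g in np.unique(groups):
        # atleast_2d keeps the 1-feature case as a (1, 1) matrix
        cov = np.atleast_2d(np.cov(features[groups == g], rowvar=False))
        eigvals = np.linalg.eigvalsh(cov)  # real and ascending, since cov is symmetric
        spectra[g.item()] = eigvals[::-1][:k]
    return spectra


def spectral_norms_by_class(features, labels, groups):
    """Condition on the class: within each class, map every group to the spectral
    norm (largest eigenvalue, lambda_1) of its feature covariance matrix."""
    features, labels, groups = map(np.asarray, (features, labels, groups))
    out = {}
    for c in np.unique(labels):
        mask = labels == c
        spectra = group_covariance_spectra(features[mask], groups[mask], k=1)
        out[c.item()] = {g: float(s[0]) for g, s in spectra.items()}
    return out


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Synthetic class-0 features: majority group 0 and a smaller, more spread-out group 1.
    # The spread is built in by construction, only to exercise the code.
    features = np.vstack([rng.normal(scale=1.0, size=(500, 5)),
                          rng.normal(scale=3.0, size=(50, 5))])
    labels = np.zeros(550, dtype=int)
    groups = np.array([0] * 500 + [1] * 50)
    print(spectral_norms_by_class(features, labels, groups))
    print(worst_group_accuracy([0, 0, 1, 1], [0, 1, 1, 1], [0, 0, 1, 1]))  # 0.5
```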