On the Unreasonable Effectiveness of Last-layer Retraining

John C. Hill, Tyler LaBonte, Xinchen Zhang, Vidya Muthukumar

TL;DR

ERM often leverages spurious correlations that harm minority-group performance. This work investigates last-layer retraining (LLR) as an efficient remedy, testing whether mitigation of neural collapse, together with the implicit bias of gradient descent, explains its effectiveness. Across four benchmarks, neural collapse did not consistently occur during ERM, and LLR’s gains correlate strongly with improved group balance in the held-out data rather than with margin-based dynamics. The findings show that CB-LLR and AFR achieve robust worst-group performance primarily through implicit group-balancing, guiding practical use of LLR when group annotations are limited and highlighting the importance of data balance in held-out sets.

Abstract

Last-layer retraining (LLR) methods, wherein the last layer of a neural network is reinitialized and retrained on a held-out set following ERM training, have garnered interest as an efficient approach to rectify dependence on spurious correlations and improve performance on minority groups. Surprisingly, LLR has been found to improve worst-group accuracy even when the held-out set is an imbalanced subset of the training set. We initially hypothesize that this "unreasonable effectiveness" of LLR is explained by its ability to mitigate neural collapse through the held-out set, resulting in the implicit bias of gradient descent benefiting robustness. Our empirical investigation does not support this hypothesis. Instead, we present strong evidence for an alternative hypothesis: that the success of LLR is primarily due to better group balance in the held-out set. We conclude by showing how the recent algorithms CB-LLR and AFR perform implicit group-balancing to elicit a robustness improvement.
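The LLR procedure described in the abstract can be illustrated with a minimal sketch. It assumes binary labels and features already extracted from a frozen, ERM-trained network. The function names, the toy data, and the hyperparameters (`steps`, `lr`) are hypothetical and are not taken from the paper, which uses ResNet-50 and BERT-Base models. Worst-group accuracy (WGA) is computed here as the minimum per-group accuracy.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def retrain_last_layer(features, labels, steps=500, lr=0.5, seed=0):
    """Reinitialize a linear head and fit it on frozen features.

    Assumption: binary labels in {0, 1}, full-batch gradient descent on
    the logistic loss. The feature extractor itself is never updated.
    """
    rng = random.Random(seed)
    dim = len(features[0])
    n = len(features)
    w = [rng.gauss(0.0, 0.01) for _ in range(dim)]  # freshly initialized head
    b = 0.0
    for _ in range(steps):
        grad_w = [0.0] * dim
        grad_b = 0.0
        for x, y in zip(features, labels):
            p = sigmoid(sum(wi * xi for wi, xi in zip(w, x)) + b)
            err = p - y  # derivative of the logistic loss w.r.t. the logit
            for j in range(dim):
                grad_w[j] += err * x[j] / n
            grad_b += err / n
        w = [wi - lr * g for wi, g in zip(w, grad_w)]
        b -= lr * grad_b
    return w, b


def predict(w, b, x):
    # Predict class 1 when the logit is positive.
    return 1 if sum(wi * xi for wi, xi in zip(w, x)) + b > 0 else 0


def worst_group_accuracy(w, b, features, labels, groups):
    # Minimum accuracy over all groups present in the data.
    correct, total = {}, {}
    for x, y, g in zip(features, labels, groups):
        total[g] = total.get(g, 0) + 1
        correct[g] = correct.get(g, 0) + (predict(w, b, x) == y)
    return min(correct[g] / total[g] for g in total)


# Toy group-balanced held-out set (hypothetical data): feature 0 is the
# core feature, feature 1 is spurious; a group is (label, spurious sign).
heldout_x = [[1.0, 1.0], [1.0, -1.0], [-1.0, 1.0], [-1.0, -1.0]]
heldout_y = [1, 1, 0, 0]
heldout_g = [(y, int(x[1] > 0)) for x, y in zip(heldout_x, heldout_y)]

w, b = retrain_last_layer(heldout_x, heldout_y)
wga = worst_group_accuracy(w, b, heldout_x, heldout_y, heldout_g)
```

Because the toy held-out set is group-balanced, the retrained head has no incentive to rely on the spurious feature. Rerunning the sketch with a skewed group composition is a simple way to probe the paper's alternative hypothesis on synthetic data.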

Paper Structure

This paper contains 31 sections, 7 equations, 11 figures, 6 tables, 1 algorithm.

Figures (11)

  • Figure 1: Visualization of our initial hypothesis. In Figure 1(a), we represent training set features collapsing to their class means, which are dominated by majority group data, resulting in a biased ERM classifier. In Figure 1(b), we show a biased ERM classifier that performs poorly on the uncollapsed minority group features within the held-out set at the beginning of the second stage of training. In Figure 1(c), we present the final aspect of our hypothesis. During LLR, the features are not collapsed on the held-out set, and so the implicit bias of gradient descent elicits a maximum-margin classifier which is invariant to the spurious feature. Importantly, our empirical investigation does not support this hypothesis. Instead, we find that the success of LLR is primarily explained by better group balance in the held-out set.
  • Figure 2: Collapse of class feature variability occurs after standard ERM training, if at all. We plot a stochastic estimate of the empirical metric of neural collapse $\mathcal{NC}_1$ using Algorithm \ref{alg:memory_efficient_nc1} throughout training a ResNet-50 on Waterbirds and CelebA and a BERT-Base on CivilComments and MultiNLI. For Waterbirds and CelebA, $\mathcal{NC}_1$ is computed using $m=10$ random vectors, while for CivilComments and MultiNLI, $\mathcal{NC}_1$ is computed using $m=3$ random vectors. Each plot displays the mean and standard deviation for $\mathcal{NC}_1$ computed across $3$ experimental seeds. We also display the mean $\mathcal{NC}_1$ metric computed on the features of the held-out set at the end of training (EoT).
  • Figure 3: Convergence of LLR to the maximum-margin SVM solution is extremely slow. We plot the mean and standard deviation over $3$ experimental seeds of the directional error $\widehat{\mathrm{Err}}$ between the last layer weights of a neural network model and an SVM (both trained on the features of the held-out set). We use a ResNet-50 for Waterbirds and CelebA and a BERT-Base for CivilComments and MultiNLI. Here, $\widehat{\mathrm{Err}} := \left\| \frac{\theta_{\text{NN}}}{\|\theta_{\text{NN}}\|_2} - \frac{\theta_{\text{SVM}}}{\|\theta_{\text{SVM}}\|_2} \right\|_2$, where $\theta_{\text{NN}}$ denotes the last layer weights and $\theta_{\text{SVM}}$ denotes the weights of an SVM trained on the held-out set features.
  • Figure 4: LLR performance is determined by held-out set group balance. We compare the test WGA of ERM and LLR models while controlling the group balance of the training and held-out sets. We use ResNet-50 for the vision datasets and BERT-Base for the language datasets, and we plot the mean and standard deviation over 3 experimental seeds. We compute the Pearson correlation coefficient between the ERM test WGA and the LLR test WGA for each dataset; the presented coefficients are averaged over all 5 group ratios. We find that LLR worst-group accuracy correlates strongly with change in group balance; in particular, LLR tends to improve over ERM if and only if the held-out set has better group balance.
  • Figure 5: LLR can recover optimally class-balanced WGA. We compare the test WGA of ERM (trained on 100% of the data) against LLR (trained on a held-out subset with the same group ratio). For both approaches, we evaluate four class-balancing strategies: no class balancing (No CB), upsampling, subsetting, and upweighting (defined in Section \ref{sec:prelim}). We use ResNet-50 for the vision datasets and BERT-Base for the language datasets. We plot mean and standard deviation across 3 seeds. Each $x$-axis group corresponds to the class-balancing method used to train the ERM baseline (gray bar). Within each group, we compare this ERM model to the four LLR variants (colored bars), each using the corresponding class-balancing method.
  • ...and 6 more figures
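
The directional error defined in the Figure 3 caption is the Euclidean distance between the two weight vectors after each is scaled to unit $\ell_2$ norm. A minimal sketch follows, assuming both weight vectors are given as flat lists of floats; the function names are hypothetical, and fitting the SVM itself is not shown.

```python
import math


def l2_normalize(v):
    # Scale a vector to unit Euclidean norm.
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0.0:
        raise ValueError("cannot normalize the zero vector")
    return [x / norm for x in v]


def directional_error(theta_nn, theta_svm):
    """Distance between the unit-normalized last-layer and SVM weights.

    Assumption: both inputs are flat sequences of equal length.
    """
    if len(theta_nn) != len(theta_svm):
        raise ValueError("weight vectors must have the same length")
    a = l2_normalize(theta_nn)
    b = l2_normalize(theta_svm)
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))
```

The value lies between 0 (identical directions, regardless of scale) and 2 (opposite directions), so it tracks how closely the retrained last layer aligns with the maximum-margin direction independently of the weight norm.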