Group-robust Sample Reweighting for Subpopulation Shifts via Influence Functions

Rui Qiao, Zhaoxuan Wu, Jingtan Wang, Pang Wei Koh, Bryan Kian Hsiang Low

TL;DR

The paper tackles subpopulation shifts that undermine worst-group performance by proposing Group-robust Sample Reweighting (GSR), a two-stage method that uses high-quality group labels as a target to reweight group-unlabeled data. By combining Last-layer Retraining with an influence-function–driven implicit differentiation framework, GSR computes Hessian-informed sample weights efficiently without unrolling the full training trajectory. The approach yields state-of-the-art or competitive worst-group accuracy on several benchmarks (notably MultiNLI and CivilComments) and demonstrates robustness to label noise, while offering insights into weight dynamics and data cleaning effects. This provides a practical, scalable route to improve subpopulation robustness under limited annotation budgets, with clear avenues for future work in representation learning and fuller model optimization.

Abstract

Machine learning models often have uneven performance among subpopulations (a.k.a. groups) in the data distributions. This poses a significant challenge for the models to generalize when the proportions of the groups shift during deployment. To improve robustness to such shifts, existing approaches have developed strategies that train models or perform hyperparameter tuning using the group-labeled data to minimize the worst-case loss over groups. However, a non-trivial amount of high-quality labels is often required to obtain noticeable improvements. Given the costliness of the labels, we propose to adopt a different paradigm to enhance group label efficiency: utilizing the group-labeled data as a target set to optimize the weights of other group-unlabeled data. We introduce Group-robust Sample Reweighting (GSR), a two-stage approach that first learns the representations from group-unlabeled data, and then refines the model by iteratively retraining its last layer on the reweighted data using influence functions. Our GSR is theoretically sound, practically lightweight, and effective in improving the robustness to subpopulation shifts. In particular, GSR outperforms the previous state-of-the-art approaches that require the same amount or even more group labels.
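
The second stage described in the abstract (retrain the last layer on reweighted data, then update the weights with influence functions) can be illustrated with a minimal sketch. This is not the authors' implementation: it assumes a binary logistic-regression last layer on frozen features, an $\ell_2$ penalty `lam`, a plain projected gradient step on the weights, and hypothetical function names (`fit_weighted_logreg`, `weight_gradient`, `reweight`). The weight gradient uses the standard implicit-differentiation identity $\partial L_{\text{target}} / \partial w_i = -\nabla L_{\text{target}}^\top H^{-1} \nabla \ell_i$, evaluated at the retrained parameters.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def fit_weighted_logreg(X, y, w, lam, n_steps=50):
    """Newton's method on sum_i w_i * logloss_i(theta) + (lam / 2) * ||theta||^2."""
    d = X.shape[1]
    theta = np.zeros(d)
    for _ in range(n_steps):
        p = sigmoid(X @ theta)
        grad = X.T @ (w * (p - y)) + lam * theta
        H = (X * (w * p * (1 - p))[:, None]).T @ X + lam * np.eye(d)
        theta = theta - np.linalg.solve(H, grad)
    return theta

def weight_gradient(X, y, w, X_tgt, y_tgt, theta, lam):
    """Gradient of the mean target loss w.r.t. each training-sample weight.

    Implicit differentiation at the optimum theta*(w) gives
    dL_tgt/dw_i = -grad_tgt^T H^{-1} grad_i.
    """
    d = X.shape[1]
    p = sigmoid(X @ theta)
    H = (X * (w * p * (1 - p))[:, None]).T @ X + lam * np.eye(d)
    g_tgt = X_tgt.T @ (sigmoid(X_tgt @ theta) - y_tgt) / len(y_tgt)
    per_sample_grads = X * (p - y)[:, None]          # shape (n, d)
    return -per_sample_grads @ np.linalg.solve(H, g_tgt)

def reweight(X, y, X_tgt, y_tgt, lam=1.0, lr=1.0, n_rounds=20):
    """Alternate last-layer retraining with a projected gradient step on the weights.

    Assumption: weights start at 1 and are only clipped to be non-negative.
    """
    w = np.ones(len(y))
    for _ in range(n_rounds):
        theta = fit_weighted_logreg(X, y, w, lam)
        g = weight_gradient(X, y, w, X_tgt, y_tgt, theta, lam)
        w = np.clip(w - lr * g, 0.0, None)
    return w, fit_weighted_logreg(X, y, w, lam)
```

Because the inner problem is strongly convex, the optimum is unique and the Hessian solve is well posed, so no unrolling of the training trajectory is needed. Here `X_tgt, y_tgt` stand in for the small group-labeled target set and `X, y` for the group-unlabeled training features.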

Paper Structure

This paper contains 31 sections, 2 theorems, 23 equations, 5 figures, 10 tables, 1 algorithm.

Key Result

Proposition A.4

The first-order derivative matrix $\nabla_x \sigma(x)$ of the softmax function is positive semidefinite.
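
A quick numerical check can make this proposition concrete. The sketch below relies on the standard closed form of the softmax Jacobian, $\nabla_x \sigma(x) = \mathrm{diag}(\sigma(x)) - \sigma(x)\sigma(x)^\top$, which is not spelled out on this page; the helper names are illustrative. For any vector $v$, the quadratic form $v^\top \nabla_x \sigma(x)\, v = \sum_i \sigma_i v_i^2 - (\sum_i \sigma_i v_i)^2$ is the variance of $v$ under the distribution $\sigma(x)$, hence non-negative.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax of a 1-D array."""
    z = np.exp(x - np.max(x))
    return z / z.sum()

def softmax_jacobian(x):
    """Jacobian of softmax at x: diag(s) - s s^T, a symmetric matrix."""
    s = softmax(x)
    return np.diag(s) - np.outer(s, s)

def min_eigenvalue(x):
    """Smallest eigenvalue of the Jacobian; non-negative up to round-off if PSD."""
    return np.linalg.eigvalsh(softmax_jacobian(x)).min()
```

The all-ones vector is always in the null space, since shifting every logit by the same constant leaves the softmax unchanged, so the Jacobian is positive semidefinite but never positive definite.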

Figures (5)

  • Figure 1: The change in the sum of sample weights across different groups throughout the training. The minority groups are upweighted and the majority groups are generally downweighted. However, different groups do not have equal sums of weights.
  • Figure 2: The distribution of sample weights for each group that are used to train the best models. The minority groups have the weight distribution stretched out towards high values, while the majority-group weights are generally skewed towards 0.
  • Figure 3: Selected images from the held-out split of Waterbirds, shown with their class and background labels and chosen according to their sample weights. In panel (a), the top-5 most weighted instances are all from the minority groups with differing bird types and backgrounds. As highlighted in the red box in (a), the background of the waterbird is ambiguous as it blends features of both land and water and hence should be categorized as a minority. GSR correctly identifies it despite the suboptimal annotation. In panel (b), the top-5 least weighted instances are all from the majority groups where the background is spuriously correlated with the bird type.
  • Figure 4: In-depth study of our algorithm. In the noise-robustness panel, the worst-group test accuracy degrades only slightly even when up to 40% of the held-out set labels are corrupted. In the clean-weights panel, most of the uncorrupted minority samples receive higher weights than non-minority examples. In contrast, in the noisy-weights panel, the corrupted minority instances are correctly assigned close-to-0 weights. The validation-versus-test panel shows the relationship between validation and test worst-group accuracy on the CelebA dataset: the target and validation sets should be kept separate, since otherwise overfitting can easily occur.
  • Figure 5: An illustration of the Waterbirds dataset. A spurious correlation exists between the bird class and the background. The majority groups are highlighted in orange. The minority groups are in gray.

Theorems & Definitions (8)

  • Definition A.1: $\mu$-Strong convexity (Boyd and Vandenberghe, 2004)
  • Definition A.2: One-hot Vector
  • Definition A.3: Softmax
  • Proposition A.4
  • Proof of Proposition A.4
  • Definition A.5: Cross-entropy
  • Theorem A.6: Strong convexity of cross-entropy loss with $\ell_2$-regularization
  • Proof of Theorem A.6
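
The items above fit together in a short chain of reasoning, sketched below. This is a reconstruction under stated assumptions, not the paper's own proof: it assumes a linear last layer with logits $z = Wx$ on fixed features $x$, a one-hot label $y$, a regularizer written as $\frac{\lambda}{2}\lVert W \rVert_F^2$ (the paper's scaling convention may differ), and column-major vectorization of $W$.

```latex
% Cross-entropy for one example, as a function of the logits z = W x:
\ell(z) = -\, y^\top \log \sigma(z), \qquad
\nabla_z \ell = \sigma(z) - y, \qquad
\nabla_z^2 \ell = \nabla_z \sigma(z)
               = \mathrm{diag}(\sigma(z)) - \sigma(z)\sigma(z)^\top \succeq 0 .

% Chain rule through z = (x^\top \otimes I)\,\mathrm{vec}(W):
\nabla_{\mathrm{vec}(W)}^2 \ell
  = (x \otimes I)\, \nabla_z \sigma(z)\, (x^\top \otimes I)
  = (x x^\top) \otimes \nabla_z \sigma(z) \succeq 0 .

% Adding the regularizer shifts every eigenvalue up by \lambda:
F(W) = \frac{1}{n} \sum_{i=1}^{n} \ell_i(W) + \frac{\lambda}{2} \lVert W \rVert_F^2
\quad \Longrightarrow \quad
\nabla^2 F(W) \succeq \lambda I .
```

For a twice-differentiable function, $\nabla^2 F \succeq \mu I$ everywhere is equivalent to $\mu$-strong convexity in the sense of Definition A.1, so under these assumptions the regularized objective is strongly convex with $\mu = \lambda$. The positive semidefiniteness of the softmax derivative (Proposition A.4) is the step that makes the unregularized Hessian positive semidefinite.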