
Theoretical Guarantees of Data Augmented Last Layer Retraining Methods

Monica Welfert, Nathan Stromberg, Lalitha Sankar

TL;DR

The paper studies data augmentation strategies for last-layer retraining that aim to maximize worst-group accuracy, modeling the latent representations as Gaussian for each subpopulation. It proves that downsampling (DS) and upweighting (UW) are risk-equivalent for any loss, derives closed-form optimal parameters under Gaussian assumptions, and characterizes the worst-group error (WGE) under various augmentation schemes. Under an orthogonality condition, SRM incurs a higher WGE than DS, UW, and mixup (MU), while MU attains the same WGE as DS and UW. Finite-sample rates are established for empirical risk minimization (ERM), DS, UW, and MU. Empirical results on synthetic data and on public datasets (CMNIST, CelebA, Waterbirds) corroborate the theory: DS, UW, and MU outperform ERM and achieve similar WGE in practice.
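The worst-group error of a fixed linear last layer can be computed in closed form when each group's latent representation is Gaussian. The sketch below is not taken from the paper. It assumes isotropic covariance $\sigma^2 I$ for every group, labels in $\{-1,+1\}$, and a nonzero weight vector. Under these assumptions, $w^\top x + b \sim \mathcal{N}(w^\top\mu_g + b,\ \sigma^2\|w\|^2)$, so group $g$ with label $y$ has error $\Phi\big(-y(w^\top\mu_g + b)/(\sigma\|w\|)\big)$, and the WGE is the maximum of these errors over groups. The function names are hypothetical.

```python
import math
from typing import Sequence, Tuple

# Each group: (mean vector mu_g, label y_g in {-1, +1}).
# Assumption: isotropic covariance sigma^2 * I shared by all groups.
Group = Tuple[Sequence[float], int]


def std_normal_cdf(z: float) -> float:
    """Standard normal CDF Phi(z), written via the complementary error function."""
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def group_error(w: Sequence[float], b: float, mu: Sequence[float], y: int, sigma: float) -> float:
    """P(y * (w.x + b) <= 0) for x ~ N(mu, sigma^2 I); requires w != 0."""
    margin = y * (sum(wi * mi for wi, mi in zip(w, mu)) + b)
    norm_w = math.sqrt(sum(wi * wi for wi in w))
    return std_normal_cdf(-margin / (sigma * norm_w))


def worst_group_error(w: Sequence[float], b: float, groups: Sequence[Group], sigma: float) -> float:
    """WGE = maximum of the per-group misclassification probabilities."""
    return max(group_error(w, b, mu, y, sigma) for mu, y in groups)


if __name__ == "__main__":
    # Hypothetical toy setup: class along axis 0, spurious attribute along axis 1.
    groups = [((1.0, 1.0), +1), ((1.0, -1.0), +1), ((-1.0, 1.0), -1), ((-1.0, -1.0), -1)]
    print(worst_group_error((1.0, 0.0), 0.0, groups, sigma=1.0))  # ignores the spurious axis
    print(worst_group_error((1.0, 1.0), 0.0, groups, sigma=1.0))  # relies on the spurious axis
```

In this toy example, the classifier that ignores the spurious direction has the same error $\Phi(-1)$ on all four groups. The classifier that also uses the spurious direction does better on the majority-like groups but reaches chance level on the conflicting groups, so its WGE is 0.5.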

Abstract

Ensuring fair predictions across many distinct subpopulations in the training data can be prohibitive for large models. Recently, simple linear last layer retraining strategies, in combination with data augmentation methods such as upweighting, downsampling and mixup, have been shown to achieve state-of-the-art performance for worst-group accuracy, which quantifies accuracy for the least prevalent subpopulation. For linear last layer retraining and the abovementioned augmentations, we present the optimal worst-group accuracy when modeling the distribution of the latent representations (input to the last layer) as Gaussian for each subpopulation. We evaluate and verify our results for both synthetic and large publicly available datasets.
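For readers unfamiliar with mixup, the following is a minimal sketch of the standard recipe. It builds a convex combination of two samples and their labels, with the mixing weight drawn from a symmetric Beta distribution. This is an assumption for illustration only. The paper's MU scheme may restrict which pairs are mixed (for example, across groups), and the function names and the `alpha` parameter are hypothetical.

```python
import random
from typing import List, Sequence, Tuple


def sample_lambda(alpha: float, rng: random.Random) -> float:
    """Draw a mixing weight lambda ~ Beta(alpha, alpha)."""
    return rng.betavariate(alpha, alpha)


def mixup_pair(
    x1: Sequence[float], y1: float, x2: Sequence[float], y2: float, lam: float
) -> Tuple[List[float], float]:
    """Convex combination of two (feature, label) pairs with weight lam on the first."""
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    x = [lam * a + (1.0 - lam) * b for a, b in zip(x1, x2)]
    y = lam * y1 + (1.0 - lam) * y2
    return x, y


if __name__ == "__main__":
    rng = random.Random(0)
    lam = sample_lambda(0.4, rng)
    print(mixup_pair([0.0, 2.0], 1.0, [4.0, 0.0], 0.0, lam))
```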

Paper Structure

This paper contains 16 sections, 7 theorems, 66 equations, 6 figures, and 1 table.

Key Result

Theorem 1

For any given $P_{X,Y,D}$ and loss $\ell$, the objectives of the general optimization problem (eq:gen-opt), when modified appropriately for DS and UW, are identical. Therefore, if a minimizer exists for one of them, it also minimizes the other, i.e., $\theta^*_\text{DS}=\theta^*_\text{UW}$.
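A small discrete illustration of the idea behind Theorem 1 follows. It is a sketch under stated assumptions, not the paper's proof. It assumes that UW weights each sample by the inverse of its group size, so that every group contributes equally. It also assumes that DS keeps a uniformly random subset of $m = \min_g n_g$ samples from each group. Fix the per-sample losses of some parameter $\theta$. The expected downsampled risk, computed exactly by enumerating every possible subset, then equals the upweighted risk, while the plain ERM average generally differs. All function names are hypothetical.

```python
from fractions import Fraction
from itertools import combinations, product
from typing import List, Sequence


def erm_risk(losses_by_group: Sequence[Sequence[int]]) -> Fraction:
    """Plain average over all samples (ERM)."""
    all_losses = [l for ls in losses_by_group for l in ls]
    return Fraction(sum(all_losses), len(all_losses))


def upweighted_risk(losses_by_group: Sequence[Sequence[int]]) -> Fraction:
    """Each sample weighted by 1/n_g, so every group contributes equally (assumed UW scheme)."""
    G = len(losses_by_group)
    return sum(Fraction(sum(ls), len(ls)) for ls in losses_by_group) / G


def expected_downsampled_risk(losses_by_group: Sequence[Sequence[int]]) -> Fraction:
    """Exact expectation of the balanced-subset average (assumed DS scheme).

    Every group is reduced to m = min_g n_g samples chosen uniformly at random;
    all joint subset choices are enumerated, so this is only for tiny examples.
    """
    m = min(len(ls) for ls in losses_by_group)
    G = len(losses_by_group)
    per_group_subsets: List[list] = [list(combinations(ls, m)) for ls in losses_by_group]
    total = Fraction(0)
    count = 0
    for choice in product(*per_group_subsets):
        chosen = [l for subset in choice for l in subset]
        total += Fraction(sum(chosen), G * m)
        count += 1
    return total / count


if __name__ == "__main__":
    # Hypothetical per-sample losses: a majority group of 4 and a minority group of 2.
    losses = [[1, 2, 3, 6], [5, 9]]
    print(erm_risk(losses), upweighted_risk(losses), expected_downsampled_risk(losses))
```

The equality holds because the expected mean of a uniformly random subset equals the full group mean. Balancing by sampling and balancing by weighting therefore target the same group-averaged risk.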

Figures (6)

  • Figure 1: $\Delta_C$ and $\Delta_D$, shown as line segments between group means, overlaid on data sampled from Gaussian mixtures satisfying the latent-normality, equal-priors, mean-difference, and orthogonality assumptions.
  • Figure 2: The optimal prediction planes for DS, UW, MU, and SRM, overlaid on data sampled from Gaussian mixtures satisfying the latent-normality, equal-priors, mean-difference, and orthogonality assumptions. The SRM model largely ignores the minority group in each class.
  • Figure 3: WGE for UW, DS, MU, and ERM on the data in Figure 2. As the number of samples $n$ increases, UW, DS, and MU outperform ERM.
  • Figure 4: Zoomed-in version of Figure 3, highlighting the differences between the data augmentation methods, especially for small $n$.
  • Figure 5: Mean squared error between the weights estimated from data and the expected weights. Each method converges quickly to the expected weights as $n$ grows.
  • ...and 1 more figure

Theorems & Definitions (15)

  • Theorem 1
  • Proof sketch
  • Remark 1
  • Remark 2
  • Proposition 1
  • Proof sketch
  • Corollary 1
  • Theorem 2
  • Proof sketch
  • Theorem 3
  • ...and 5 more