Changing the Training Data Distribution to Reduce Simplicity Bias Improves In-distribution Generalization

Dang Nguyen, Paymon Haddad, Eric Gan, Baharan Mirzasoleiman

TL;DR

This work rigorously proves that SAM learns different features more uniformly, particularly in early epochs, and empirically shows that USEFUL effectively improves the generalization performance on the original data distribution when training with various gradient methods, including (S)GD and SAM.

Abstract

Can we modify the training data distribution to encourage the underlying optimization method toward finding solutions with superior generalization performance on in-distribution data? In this work, we approach this question for the first time by comparing the inductive bias of gradient descent (GD) with that of sharpness-aware minimization (SAM). By studying a two-layer CNN, we rigorously prove that SAM learns different features more uniformly, particularly in early epochs. That is, SAM is less susceptible to simplicity bias than GD. We also show that examples containing features that are learned early are separable from the rest based on the model's output. Based on this observation, we propose a method, USEFUL, that (i) clusters examples based on the network output early in training, (ii) identifies a cluster of examples with similar network output, and (iii) upsamples the remaining examples only once to alleviate the simplicity bias. We show empirically that USEFUL effectively improves the generalization performance on the original data distribution when training with various gradient methods, including (S)GD and SAM. Notably, we demonstrate that our method can be combined with SAM variants and existing data augmentation strategies to achieve, to the best of our knowledge, state-of-the-art performance for training ResNet18 on CIFAR10, STL10, CINIC10, Tiny-ImageNet; ResNet34 on CIFAR100; and VGG19 and DenseNet121 on CIFAR10.
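The three steps listed in the abstract can be illustrated with a minimal sketch. It is not the authors' implementation: the function names (`two_means`, `useful_indices`), the choice of two clusters per class, Euclidean distance, and the rule that the "similar output" cluster is the one with the smaller average distance to its centroid are all assumptions made for illustration. `outputs` stands for per-example network outputs recorded early in training.

```python
from collections import defaultdict
from math import dist


def two_means(points, iters=20):
    """Plain 2-means with a deterministic start (assumed, for illustration)."""
    c0 = points[0]
    c1 = max(points, key=lambda p: dist(p, c0))  # farthest point from c0
    centroids = [c0, c1]
    labels = [0] * len(points)
    for _ in range(iters):
        # Assign each point to its nearest centroid.
        labels = [0 if dist(p, centroids[0]) <= dist(p, centroids[1]) else 1
                  for p in points]
        # Move each centroid to the mean of its members.
        for k in (0, 1):
            members = [p for p, lab in zip(points, labels) if lab == k]
            if members:
                centroids[k] = tuple(sum(c) / len(members) for c in zip(*members))
    return labels, centroids


def useful_indices(outputs, targets):
    """Return training indices in which the non-selected examples appear twice."""
    by_class = defaultdict(list)
    for i, y in enumerate(targets):
        by_class[y].append(i)

    indices = list(range(len(targets)))  # every example once
    for idxs in by_class.values():
        pts = [tuple(outputs[i]) for i in idxs]
        labels, centroids = two_means(pts)  # step (i): cluster the outputs
        spread = []
        for k in (0, 1):
            members = [p for p, lab in zip(pts, labels) if lab == k]
            spread.append(sum(dist(p, centroids[k]) for p in members) / len(members)
                          if members else float("inf"))
        # Step (ii), ASSUMPTION: the tighter cluster is the "similar output" one.
        similar = spread.index(min(spread))
        # Step (iii): add the remaining examples exactly one more time.
        indices += [i for i, lab in zip(idxs, labels) if lab != similar]
    return indices
```

The returned list contains every index once plus one extra copy of each example outside the selected cluster, so it can be used to index the training set for the rest of training (for instance through a subset or sampler object in the training framework of choice).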

Paper Structure

This paper contains 31 sections, 31 theorems, 100 equations, 15 figures, 9 tables, 1 algorithm.

Key Result

Theorem 3.2

Consider training a two-layer nonlinear CNN model initialized with $\pmb{W}^{(0)} \sim \mathcal{N}(0, \sigma_0^2)$ on the training dataset $D = \{ (\pmb{x}_i, y_i) \}_{i=1}^N$ drawn from the distribution $\mathcal{D}(\beta_e, \beta_d, \alpha)$ with $\alpha^{1/3} \beta_e > \beta_d$. For a small-enough learning rate, after training for $T_{\text{GD}}$ iterations, w.h.p., the model: (1) learns the fast-learnable feature ...

Figures (15)

  • Figure 1: Slow-learnable (top) and fast-learnable (bottom) examples in CIFAR-10 found by our method. Examples in the top row (slow-learnable) are harder to identify visually and look more ambiguous (only part of the object is in the image, or the object is smaller and the background area is larger). In contrast, examples in the bottom row (fast-learnable) are unambiguous, clear representatives of their class and hence very easy to classify visually (the entire object is in the image and the background area is small).
  • Figure 2: t-SNE visualization of output vectors. (left) ResNet18/CIFAR-10 at epoch 8. (right) CNN/toy data generated based on Definition 3.1 with $\beta_d = 0.2$, $\beta_e = 1$, $\alpha = 0.9$, at iteration 200.
  • Figure 3: GD (blue) vs. SAM (orange) on toy datasets. Data is generated based on Definition 3.1 with different $\beta_d$ and fixed $\beta_e = 1$, $\alpha = 0.9$. Dotted and dashed lines denote the alignment (i.e., inner product) of the fast-learnable ($\pmb{v}_e$) and slow-learnable ($\pmb{v}_d$) features with the model weight ($\pmb{w}_j^{(t)}$). (a), (b) GD and SAM first learn the fast-learnable feature. Notably, GD learns the fast-learnable feature very early. (c) Test accuracy of GD and SAM improves as the strength of the slow-learnable feature increases.
  • Figure 4: Test classification error of ResNet18 on CIFAR10, STL10, CINIC10, and TinyImageNet, and of ResNet34 on CIFAR100. The numbers below the bars indicate the approximate training cost, and the tick on top of each bar shows the standard deviation over three runs. USEFUL enhances the performance of SGD and SAM on all 5 datasets. TrivialAugment (TA) further boosts SAM's performance (except on CINIC10). Remarkably, USEFUL consistently boosts the performance across all scenarios and, when combined with SAM and TA, achieves (to our knowledge) SOTA performance for ResNet18 and ResNet34 on the selected datasets.
  • Figure 5: Test classification errors of different architectures on CIFAR10. USEFUL improves the performance of SGD and SAM when training different architectures. TrivialAugment (TA) further boosts SAM's capabilities. The results for the 3-layer MLP can be found in Figure \ref{fig:vary_architectures_app}.
  • ...and 10 more figures

Theorems & Definitions (52)

  • Definition 3.1: Data distribution
  • Theorem 3.2: GD Feature Learning
  • Theorem 3.3: SAM Feature Learning
  • Theorem 3.4: SAM learns features more evenly than GD
  • Theorem 3.5: One-shot upsampling
  • Lemma A.2: Gradient
  • proof
  • Lemma A.3: Fast-learnable feature update
  • proof
  • Lemma A.4: Slow-learnable feature update
  • ...and 42 more
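
Theorems 3.2 to 3.4 compare feature learning under GD and SAM. For orientation, the sketch below contrasts one plain GD step with one step of the commonly used SAM update, in which the gradient is evaluated at a point perturbed by `rho` along the normalized gradient direction. This is the textbook form of the two updates, not necessarily the exact variant analyzed in the paper; `gd_step`, `sam_step`, `grad_fn`, `lr`, and `rho` are illustrative names.

```python
from math import sqrt


def gd_step(w, grad_fn, lr):
    """One gradient-descent step: w <- w - lr * grad(w)."""
    g = grad_fn(w)
    return [wi - lr * gi for wi, gi in zip(w, g)]


def sam_step(w, grad_fn, lr, rho):
    """One SAM step: descend along the gradient taken at a perturbed point."""
    g = grad_fn(w)
    norm = sqrt(sum(gi * gi for gi in g))
    if norm == 0.0:
        # Zero gradient: the perturbation direction is undefined, so stay put.
        return list(w)
    # Ascent step of length rho along the normalized gradient.
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # Descent step from the ORIGINAL weights using the perturbed-point gradient.
    g_adv = grad_fn(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]
```

With `rho = 0` the SAM step reduces to the GD step, which makes the perturbation radius the single knob that separates the two optimizers in this sketch.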