Ask Your Distribution Shift if Pre-Training is Right for You

Benjamin Cohen-Wang, Joshua Vendrow, Aleksander Madry

TL;DR

This study suggests that, as a rule of thumb, pre-training can help mitigate poor extrapolation but not dataset biases, and explores two of its implications for developing robust models: (1) pre-training and interventions designed to prevent exploiting biases have complementary robustness benefits, and (2) fine-tuning on a (very) small, non-diverse but de-biased dataset can result in significantly more robust models than fine-tuning on a large and diverse but biased dataset.

Abstract

Pre-training is a widely used approach to develop models that are robust to distribution shifts. However, in practice, its effectiveness varies: fine-tuning a pre-trained model improves robustness significantly in some cases but not at all in others (compared to training from scratch). In this work, we seek to characterize the failure modes that pre-training can and cannot address. In particular, we focus on two possible failure modes of models under distribution shift: poor extrapolation (e.g., they cannot generalize to a different domain) and biases in the training data (e.g., they rely on spurious features). Our study suggests that, as a rule of thumb, pre-training can help mitigate poor extrapolation but not dataset biases. After providing theoretical motivation and empirical evidence for this finding, we explore two of its implications for developing robust models: (1) pre-training and interventions designed to prevent exploiting biases have complementary robustness benefits, and (2) fine-tuning on a (very) small, non-diverse but de-biased dataset can result in significantly more robust models than fine-tuning on a large and diverse but biased dataset. Code is available at https://github.com/MadryLab/pretraining-distribution-shift-robustness.

Paper Structure (81 sections, 6 theorems, 22 equations, 18 figures, 1 table)

Key Result

Theorem 4.1

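Suppose that we start with initial weights $w_\text{init}\in\mathbb{R}^d$ and run gradient descent to minimize $L_\text{ref}(w)$. With an appropriately chosen learning rate, gradient descent converges to weights $\hat{w}$ that minimize $L_\text{ref}$. Furthermore, $\hat{w}$ can be written as

$$\hat{w} = w^*_\text{ref} + \text{proj}_{W_\text{ref}^\perp} w_\text{init}.$$

Here, $w^*_\text{ref}$ is a property of the reference dataset $S_\text{ref}$ and lies within the reference subspace $W_\text{ref}$, while $\text{proj}_{W_\text{ref}^\perp} w_\text{init}$ is the component of $w_\text{init}$ that is orthogonal to $W_\text{ref}$.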
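
The role of the initialization in this setting can be checked numerically. The sketch below is not the authors' code. It assumes unregularized logistic regression with labels in {-1, +1}, a hypothetical 2-dimensional reference subspace of $\mathbb{R}^5$, and synthetic noisy data; all names and constants are illustrative. Because every gradient of the loss is a linear combination of the training inputs, gradient descent can only move the weights within the subspace spanned by the data, so the part of the initialization orthogonal to that subspace should be left untouched.

```python
import numpy as np

rng = np.random.default_rng(0)
d, k, n = 5, 2, 200  # ambient dim, subspace dim, number of examples (hypothetical)

# Orthonormal basis B of a k-dimensional reference subspace W_ref of R^d.
B, _ = np.linalg.qr(rng.normal(size=(d, k)))
P_ref = B @ B.T               # projection onto W_ref
P_perp = np.eye(d) - P_ref    # projection onto the orthogonal complement

# Reference data lying inside W_ref. Labels are noisy so that the data are not
# linearly separable and the logistic loss has a finite minimizer.
Z = rng.normal(size=(n, k))
X = Z @ B.T
y = np.sign(Z @ np.array([1.0, -1.0]) + rng.normal(size=n))

def grad_logistic(w):
    """Gradient of the mean logistic loss log(1 + exp(-y * <x, w>))."""
    margins = y * (X @ w)
    return X.T @ (-y / (1.0 + np.exp(margins))) / n

def gradient_descent(w_init, lr=1.0, steps=5000):
    """Plain gradient descent from w_init with a fixed learning rate."""
    w = w_init.copy()
    for _ in range(steps):
        w -= lr * grad_logistic(w)
    return w

# Two different initializations, trained on the same reference data.
w_init_a = rng.normal(size=d)
w_init_b = rng.normal(size=d)
w_hat_a = gradient_descent(w_init_a)
w_hat_b = gradient_descent(w_init_b)

# Inside W_ref the two solutions should coincide; outside W_ref each solution
# should keep the orthogonal component of its own initialization.
in_subspace_gap = np.linalg.norm(P_ref @ (w_hat_a - w_hat_b))
perp_drift_a = np.linalg.norm(P_perp @ (w_hat_a - w_init_a))
perp_drift_b = np.linalg.norm(P_perp @ (w_hat_b - w_init_b))
print("gap inside W_ref:", in_subspace_gap)
print("drift outside W_ref:", perp_drift_a, perp_drift_b)
```

Under these assumptions, both printed quantities should be numerically close to zero, while the two solutions still differ outside the reference subspace. This is the sense in which initialization can matter only for shifts that leave that subspace.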

Figures (18)

  • Figure 1: The robustness benefits of pre-training vary. On the ImageNet-V2 distribution shift (left), different pre-trained models all exhibit very little effective robustness (ER), i.e., little improvement over the linear trend of models trained from scratch (see Section \ref{sec:background}). Meanwhile, on the ImageNet Sketch distribution shift (right), some of these pre-trained models exhibit substantial effective robustness. We report average effective robustness with a 95% confidence interval in the top left of each plot.
  • Figure 2: Illustration of logistic regression setting. (a) Consider a reference dataset that lies within a subspace $W_\text{ref}$ of $\mathbb{R}^d$. (b) Models trained from different initializations all learn the same (optimal) decision boundary in $W_\text{ref}$, but may behave differently outside of $W_\text{ref}$. (c) Under shifts within $W_\text{ref}$, models with different initializations are equally robust. (d) Under shifts outside of $W_\text{ref}$, initialization can affect robustness.
  • Figure 3: Examples of in-support and out-of-support shifts. One example of an in-support shift (left) is a shift in which the indoor/outdoor frequencies of animal appearances change, but the possible combinations of animal and setting remain the same. An example of an out-of-support shift (right) is a shift from day to night: the nighttime setting is entirely novel.
  • Figure 4: Robustness of pre-trained models to synthetic in-support and out-of-support shifts. For each of two in-support shifts (left) and two out-of-support shifts (right) constructed by modifying ImageNet, the reference and shifted accuracies of models trained from scratch (in blue) are linearly correlated. Pre-trained models exhibit little effective robustness (ER), i.e., little improvement over the linear trend (see Section \ref{sec:background}), on the in-support shifts, but have significant effective robustness on the out-of-support shifts (averages with 95% confidence intervals in the top left of each plot). Error bars denote 95% confidence intervals over 4 random trials.
  • Figure 5: Dividing shifts of ImageNet into in-support and out-of-support splits. We divide each of the ImageNet-V2, ImageNet Sketch and ImageNet-R datasets into an in-support split containing examples that look like ImageNet examples and an out-of-support split containing examples that look unlike ImageNet examples (see Appendix \ref{sec:splitting_appendix} for a description of the splitting method). We display samples from each split of ImageNet Sketch in Figure \ref{fig:splitting_example} and report the average effective robustness of pre-trained models in Figure \ref{fig:splitting_results}. See Appendix \ref{sec:splitting_scatterplots} for scatterplots of reference vs. shifted accuracy.
  • ...and 13 more figures
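
Several captions above report effective robustness (ER), described as the improvement over the linear trend of models trained from scratch. A minimal sketch of that computation follows. It assumes a least-squares line fit on raw accuracies (analyses of this kind often fit the trend on rescaled accuracies instead, and this excerpt does not say which is used); the function names and all accuracy values are hypothetical and chosen only for illustration.

```python
import numpy as np

def fit_linear_trend(ref_acc, shifted_acc):
    """Fit shifted ~ slope * ref + intercept on baseline (from-scratch) models."""
    slope, intercept = np.polyfit(
        np.asarray(ref_acc, dtype=float),
        np.asarray(shifted_acc, dtype=float),
        deg=1,
    )
    return slope, intercept

def effective_robustness(ref_acc, shifted_acc, slope, intercept):
    """Shifted accuracy minus the shifted accuracy predicted by the baseline trend."""
    ref_acc = np.asarray(ref_acc, dtype=float)
    shifted_acc = np.asarray(shifted_acc, dtype=float)
    return shifted_acc - (slope * ref_acc + intercept)

# Hypothetical baseline models whose accuracies lie exactly on one line.
baseline_ref = np.array([0.50, 0.60, 0.70, 0.80])
baseline_shifted = 0.9 * baseline_ref - 0.1

slope, intercept = fit_linear_trend(baseline_ref, baseline_shifted)

# Hypothetical pre-trained model: the trend predicts 0.9 * 0.75 - 0.1 = 0.575
# shifted accuracy, while the model reaches 0.675.
pretrained_ref = np.array([0.75])
pretrained_shifted = np.array([0.675])
er = effective_robustness(pretrained_ref, pretrained_shifted, slope, intercept)
print("effective robustness:", er)
```

A positive value means the model sits above the baseline trend at its level of reference accuracy; a value near zero means it is no more robust than a from-scratch model with the same reference accuracy would be expected to be.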

Theorems & Definitions (10)

  • Theorem 4.1
  • Theorem A.1
  • Lemma A.1
  • proof
  • Lemma A.2
  • proof
  • Lemma A.3
  • proof
  • Lemma A.4
  • proof