Towards Robust Out-of-Distribution Generalization Bounds via Sharpness

Yingtian Zou, Kenji Kawaguchi, Yingnan Liu, Jiashuo Liu, Mong-Li Lee, Wynne Hsu

TL;DR

A rigorous connection between sharpness and robustness is established, yielding better OOD guarantees for robust algorithms and theoretical backing for the claim that flat minima lead to better OOD generalization.

Abstract

Generalizing to out-of-distribution (OOD) data or unseen domains, termed OOD generalization, still lacks appropriate theoretical guarantees. Canonical OOD bounds focus on different distance measures between source and target domains but fail to consider the optimization properties of the learned model. As shown empirically in recent work, the sharpness of learned minima influences OOD generalization. To bridge this gap between optimization and OOD generalization, we study how sharpness affects a model's tolerance to data change under domain shift, a property usually captured by "robustness" in generalization. In this paper, we give a rigorous connection between sharpness and robustness, which yields better OOD guarantees for robust algorithms. It also provides theoretical backing for the claim that "flat minima lead to better OOD generalization". Overall, we propose a sharpness-based OOD generalization bound that takes robustness into account, resulting in a tighter bound than non-robust guarantees. Our findings are supported by experiments on a ridge regression model, as well as on deep learning classification tasks.
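The abstract links flatness to OOD tolerance and mentions ridge regression experiments. The snippet below is a minimal sketch, not the authors' code. It fits closed-form ridge regression for several regularization strengths $\beta$ and computes a relative sharpness score for each solution. It assumes that, for a single-output linear model, the relative flatness measure of Petzka et al. reduces to $\|w\|^2 \operatorname{Tr}(H)$, where $H = \frac{2}{n}X^\top X$ is the Hessian of the unregularized mean squared error. The synthetic data and the helper names `ridge_fit` and `relative_sharpness` are hypothetical.

```python
import numpy as np


def ridge_fit(X, y, beta):
    """Closed-form minimizer of mean((X @ w - y) ** 2) + beta * ||w||^2."""
    n, d = X.shape
    # Normal equations: (X^T X / n + beta * I) w = X^T y / n
    return np.linalg.solve(X.T @ X / n + beta * np.eye(d), X.T @ y / n)


def relative_sharpness(X, w):
    """Assumed single-output relative sharpness: ||w||^2 * Tr(H).

    H = (2 / n) X^T X is the Hessian of the unregularized mean squared error,
    which for a linear model does not depend on w.
    """
    n = X.shape[0]
    hessian = 2.0 * X.T @ X / n
    return float(w @ w) * float(np.trace(hessian))


# Hypothetical synthetic data: 200 samples, 5 features, small label noise.
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
w_true = rng.normal(size=5)
y = X @ w_true + 0.1 * rng.normal(size=200)

betas = [0.0, 0.1, 1.0, 10.0]
sharpness = [relative_sharpness(X, ridge_fit(X, y, b)) for b in betas]
for b, s in zip(betas, sharpness):
    print(f"beta={b:<5} relative sharpness={s:.3f}")
```

In this model $\operatorname{Tr}(H)$ does not depend on $w$, so the score tracks $\|w\|^2$. The norm of the ridge solution is non-increasing in $\beta$, so under this measure stronger regularization gives a flatter minimum.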

Paper Structure (44 sections, 20 theorems, 165 equations, 9 figures, 1 table, 1 algorithm)

Key Result

Proposition 2.1

Given a hypothesis class $\mathcal{F}$ with pseudo-dimension $\operatorname{Pdim}(\mathcal{F})=d'$ and unlabeled empirical datasets $\widehat{\mathcal{D}}_S$ and $\widehat{\mathcal{D}}_T$ of size $n$ each, drawn from the source and target distributions, with probability at least $1-\delta$, for all $f \in \mathcal{F}$: … where $d_{\mathcal{F} \Delta \mathcal{F}}(\widehat{\mathcal{D}}_T; \widehat{\mathcal{D}}_{S}) := 2 \ldots$
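The sketch below makes the divergence term in Proposition 2.1 concrete. It computes the empirical $\mathcal{H}\Delta\mathcal{H}$-divergence in the familiar Ben-David et al. form for binary classifiers, $2\sup_{h,h'}\lvert\Pr_S[h\neq h'] - \Pr_T[h\neq h']\rvert$. The proposition itself uses a pseudo-dimension (real-valued) version, and its full definition is truncated above, so treat this as an illustration under that assumption. The finite class of 1-D threshold classifiers and all function names are hypothetical.

```python
import itertools

import numpy as np


def threshold_classifiers(thresholds):
    """Hypothetical finite hypothesis class: h_t(x) = 1[x >= t]."""
    # Default argument binds each t at definition time.
    return [lambda x, t=t: (x >= t).astype(int) for t in thresholds]


def empirical_h_delta_h(xs_source, xs_target, hypotheses):
    """2 * max over pairs (h, h') of |P_S[h != h'] - P_T[h != h']| on the samples."""
    best = 0.0
    for h, h2 in itertools.product(hypotheses, repeat=2):
        p_s = np.mean(h(xs_source) != h2(xs_source))  # disagreement rate on source
        p_t = np.mean(h(xs_target) != h2(xs_target))  # disagreement rate on target
        best = max(best, abs(p_s - p_t))
    return 2.0 * best


source = np.array([0.1, 0.2, 0.3, 0.4])
target = source + 5.0  # a pure translation of the source sample
hyps = threshold_classifiers([0.0, 1.0, 10.0])

d_same = empirical_h_delta_h(source, source, hyps)
d_shift = empirical_h_delta_h(source, target, hyps)
print(d_same)   # 0.0
print(d_shift)  # 2.0
```

The shifted sample reaches the maximum value of 2 here. The pair of thresholds $(0, 1)$ disagrees on every source point and on no target point.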

Figures (9)

  • Figure 1: An example of a target distribution (red) directly translated from a source distribution (blue). The 1D density reflects the marginal distribution. Unlike existing works (left), we divide the distributions into disjoint partitions, since a small change in distribution is negligible for a robust model (right). The sharpness of the model determines its tolerance to change, which in turn affects the partitions. If two sub-distributions $S, T$ are shifted so little that they fall into the same partition (red grid), their robustness-aware distance $d^\prime(S, T)$ is zero.
  • Figure 2: OOD test losses increase with distributional shift. The X-axis is the shifting angle $\alpha$, and the Y-axis is the test loss of the model trained on distribution $\alpha=0$. Lines are average test losses, and shaded regions show the variance over 10 trials. Larger regularization $\beta$ (darker color) results in a smaller increase in test loss and smaller sharpness.
  • Figure 3: The relationship between out-of-domain test accuracy and model sharpness on the RotatedMNIST dataset. Four OOD environments are shown, using $15^{\circ}, 30^{\circ}, 45^{\circ}, 60^{\circ}$ rotations, respectively, as the OOD test set. Each marker denotes a minimum found by one algorithm with a specific seed. Markers of the same style denote models trained in the same environment.
  • Figure 4: Spurious feature synthetic experiment. Each dot represents a trained model. The dashed curves are smoothed fits to the test data points. The baseline is Proposition 2.1. (a), (d): generalization error of the logistic regression models as the model size / correlation probability increases. (b): concentration error term in the domain shift bound. (e): comparison of distribution distance bounds. (c), (f): comparison of generalization bounds. Note that model size $>500$ is the overparameterized regime. The further the correlation probability is from $0.5$, the greater the distributional shift.
  • Figure 5: The relationship between out-of-distribution (OOD) test accuracy on the test environment and model sharpness (of the last FC layer) on the Wilds-Camelyon17 dataset. Each marker denotes a model trained using ERM with a different seed and hyperparameters.
  • ...and 4 more figures
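Figure 2's setup can be mimicked in a few lines: a model is trained at $\alpha=0$ and then tested on increasingly rotated distributions. This is a hypothetical toy version. It uses 2-D Gaussian inputs, a fixed linear labelling function, and a ridge model fit on the unrotated data. Test inputs are rotated by $\alpha$ while the labels keep the source labelling rule. None of these choices are taken from the paper.

```python
import numpy as np


def rotation(alpha):
    """2-D rotation matrix for an angle alpha in radians."""
    c, s = np.cos(alpha), np.sin(alpha)
    return np.array([[c, -s], [s, c]])


def ridge_fit(X, y, beta):
    """Closed-form minimizer of mean((X @ w - y) ** 2) + beta * ||w||^2."""
    n, d = X.shape
    return np.linalg.solve(X.T @ X / n + beta * np.eye(d), X.T @ y / n)


rng = np.random.default_rng(0)
w_true = np.array([1.0, 0.0])  # hypothetical labelling function y = w_true . x

# Train only on the unshifted source domain (alpha = 0).
X_src = rng.normal(size=(500, 2))
y_src = X_src @ w_true + 0.1 * rng.normal(size=500)
w_hat = ridge_fit(X_src, y_src, beta=0.1)

# Evaluate on test inputs rotated by alpha; labels keep the source rule.
X_test = rng.normal(size=(500, 2))
y_test = X_test @ w_true
angles_deg = [0, 15, 30, 45, 60, 90]
ood_losses = []
for deg in angles_deg:
    X_shift = X_test @ rotation(np.deg2rad(deg)).T  # rotate each row (input)
    ood_losses.append(float(np.mean((X_shift @ w_hat - y_test) ** 2)))

for deg, loss in zip(angles_deg, ood_losses):
    print(f"alpha={deg:>2} deg  OOD test loss={loss:.3f}")
```

With these assumptions, the model on rotated inputs effectively predicts with $R_\alpha^\top \hat w$. The test loss therefore grows as that vector turns away from the true weights.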

Theorems & Definitions (43)

  • Proposition 2.1: Zhao et al. [zhao2018adversarial], Theorems 2 & 3.2
  • Definition 2.2: Robustness [xu2012robustness]
  • Theorem 3.1
  • Remark
  • Corollary 3.2
  • Definition 3.3: Sharpness [petzka2021relative]
  • Theorem 3.4
  • Remark
  • Corollary 3.5
  • Example 3.6
  • ...and 33 more
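Figure 1 and Definition 2.2 both rest on partitioning the sample space into disjoint cells. The sketch below illustrates the stated property: two samples that fall into the same cells are at distance zero. It uses a hypothetical axis-aligned grid as the partition and the total variation between per-cell empirical masses as a stand-in distance. This is not the paper's $d'(S,T)$, whose exact form is not given here. A coarser grid loosely plays the role of a model that tolerates larger changes, echoing Figure 1's point that sharpness shapes the partitions.

```python
from collections import Counter

import numpy as np


def cell_ids(points, cell_size):
    """Map each point to the index of its axis-aligned grid cell (hypothetical partition)."""
    return [tuple(idx) for idx in np.floor(points / cell_size).astype(int)]


def partition_tv_distance(source, target, cell_size):
    """Total variation between per-cell empirical masses (a stand-in, not the paper's d')."""
    mass_s = Counter(cell_ids(source, cell_size))
    mass_t = Counter(cell_ids(target, cell_size))
    cells = set(mass_s) | set(mass_t)
    return 0.5 * sum(
        abs(mass_s[c] / len(source) - mass_t[c] / len(target)) for c in cells
    )


source = np.array([[0.125, 0.125], [0.25, 0.375]])
target = source + 0.5  # a small translation of the whole sample

print(partition_tv_distance(source, target, cell_size=1.0))   # 0.0: same cell
print(partition_tv_distance(source, target, cell_size=0.25))  # 1.0: disjoint cells
```

The same translation is invisible at the coarse resolution and maximal at the fine one. This mirrors the caption's claim that shifts within a single partition contribute nothing to the robustness-aware distance.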