
Sharp analysis of out-of-distribution error for "importance-weighted" estimators in the overparameterized regime

Kuo-Wei Lai, Vidya Muthukumar

TL;DR

This work analyzes generalization under distribution shift for overparameterized linear models trained on a Gaussian mixture model with a spurious feature. It derives sharp, matching upper and lower bounds on the group-wise (including worst-group) generalization error of cost-sensitive minimum-norm interpolation (cMNI). These bounds show how the minority/majority counts, the total signal strength $R_+$, the spurious-signal difference $R_-$, and the importance weights $\Delta_\pm$ jointly govern in-distribution (ID) and out-of-distribution (OOD) performance. A key finding is a robustness-accuracy tradeoff: upweighting the minority group more aggressively (decreasing $\Delta_-$ relative to $\Delta_+$) improves OOD robustness at the cost of average accuracy, while milder upweighting favors average accuracy but degrades robustness. Ridge regularization improves the error by at most a constant factor in the exponent. The results leverage benign-overfitting techniques and an analysis based on the Woodbury identity and inverse-Wishart matrices to achieve sharp rates, and they apply to both cMNI and the cost-sensitive SVM (cSVM). This work informs the practical design of importance weights for improving worst-group generalization under structured distribution shifts.
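As a rough illustration of why the Woodbury identity is useful here, the sketch below assumes a simplified one-mean model $\boldsymbol{x}_i = y_i\boldsymbol{\mu} + \boldsymbol{z}_i$ (not the paper's two-direction model). The Gram matrix then splits exactly as $\boldsymbol{X}\boldsymbol{X}^\top = \boldsymbol{Z}\boldsymbol{Z}^\top + \boldsymbol{U}\boldsymbol{C}\boldsymbol{U}^\top$ with $\boldsymbol{U} = [\boldsymbol{Z}\boldsymbol{\mu}, \boldsymbol{y}]$ and $\boldsymbol{C} = \begin{pmatrix}0 & 1\\ 1 & \|\boldsymbol{\mu}\|^2\end{pmatrix}$, so its inverse is a rank-2 correction of the noise-only inverse $(\boldsymbol{Z}\boldsymbol{Z}^\top)^{-1}$, a Wishart matrix inverse when $\boldsymbol{Z}$ has i.i.d. standard Gaussian rows. The helper `woodbury_inverse` and all sizes are hypothetical; the code only checks the identity numerically and is not the paper's derivation.

```python
import numpy as np

def woodbury_inverse(A_inv, U, C, V):
    """Return (A + U C V)^{-1} from A^{-1} using the Woodbury identity."""
    capacitance = np.linalg.inv(C) + V @ A_inv @ U   # small k x k matrix
    return A_inv - A_inv @ U @ np.linalg.solve(capacitance, V @ A_inv)

rng = np.random.default_rng(0)
n, d = 20, 500                                 # hypothetical sizes with d >> n
y = rng.choice([-1.0, 1.0], size=n)            # class labels
mu = np.zeros(d)
mu[0] = 3.0                                    # hypothetical mean direction
Z = rng.standard_normal((n, d))                # isotropic Gaussian noise rows
X = Z + np.outer(y, mu)                        # x_i = y_i * mu + z_i

A = Z @ Z.T                                    # noise-only (Wishart) Gram matrix
U = np.column_stack([Z @ mu, y])               # n x 2 low-rank factor
C = np.array([[0.0, 1.0], [1.0, mu @ mu]])     # 2 x 2 coupling matrix
assert np.allclose(X @ X.T, A + U @ C @ U.T)   # the split is exact

direct = np.linalg.inv(X @ X.T)
via_woodbury = woodbury_inverse(np.linalg.inv(A), U, C, U.T)
print(np.max(np.abs(direct - via_woodbury)))   # should be at round-off level
```

Only a $2 \times 2$ system is solved beyond $(\boldsymbol{Z}\boldsymbol{Z}^\top)^{-1}$, which is what makes the mean terms tractable to isolate analytically.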

Abstract

Overparameterized models that achieve zero training error are observed to generalize well on average, but degrade in performance when faced with data that is under-represented in the training sample. In this work, we study an overparameterized Gaussian mixture model imbued with a spurious feature, and sharply analyze the in-distribution and out-of-distribution test error of a cost-sensitive interpolating solution that incorporates "importance weights". Compared to recent work (Wang et al., 2021; Behnia et al., 2022), our analysis is sharp with matching upper and lower bounds, and significantly weakens required assumptions on data dimensionality. Our error characterizations also apply to any choice of importance weights and unveil a novel tradeoff between worst-case robustness to distribution shift and average accuracy as a function of the importance weight magnitude.
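A minimal sketch of a cost-sensitive interpolator follows. It assumes (i) the data model $\boldsymbol{x} = y\boldsymbol{\mu}_c + (by)\boldsymbol{\mu}_s + \boldsymbol{z}$ with $\boldsymbol{z}\sim\mathcal{N}(\boldsymbol{0}, \boldsymbol{I}_d)$, where group $b=+1$ is the majority and $b=-1$ the minority, and (ii) that cMNI solves $\min \|\boldsymbol{w}\|_2$ subject to $\langle \boldsymbol{w}, \boldsymbol{x}_i\rangle = y_i/\Delta_{b_i}$, a margin convention borrowed from the cost-sensitive SVM literature rather than confirmed here. Under these assumptions, $y\langle\boldsymbol{w},\boldsymbol{x}\rangle$ for a fresh group-$b$ point is Gaussian with mean $\langle\boldsymbol{w},\boldsymbol{\mu}_c\rangle + b\langle\boldsymbol{w},\boldsymbol{\mu}_s\rangle$ and variance $\|\boldsymbol{w}\|_2^2$, so the group error is $\Phi\big(-(\langle\boldsymbol{w},\boldsymbol{\mu}_c\rangle + b\langle\boldsymbol{w},\boldsymbol{\mu}_s\rangle)/\|\boldsymbol{w}\|_2\big)$. All function names, the smaller $d$, and the orthogonal placement of $\boldsymbol{\mu}_c$ and $\boldsymbol{\mu}_s$ are hypothetical choices.

```python
import math
import numpy as np

def sample_gmm(n_plus, n_minus, mu_c, mu_s, rng):
    """Assumed model: x = y*mu_c + (b*y)*mu_s + z with z ~ N(0, I_d)."""
    b = np.concatenate([np.ones(n_plus), -np.ones(n_minus)])  # +1 majority, -1 minority
    y = rng.choice([-1.0, 1.0], size=b.size)                  # class labels
    a = b * y                                                 # spurious attribute
    Z = rng.standard_normal((b.size, mu_c.size))              # isotropic noise
    X = np.outer(y, mu_c) + np.outer(a, mu_s) + Z
    return X, y, b

def cmni(X, y, b, delta_plus, delta_minus):
    """Minimum-norm w with <w, x_i> = y_i / Delta_{b_i} (assumed convention)."""
    targets = y / np.where(b > 0, delta_plus, delta_minus)
    # w = X^T (X X^T)^{-1} t is the min-norm interpolator when X X^T is invertible
    return X.T @ np.linalg.solve(X @ X.T, targets)

def group_error(w, mu_c, mu_s, group):
    """Phi(-(<w,mu_c> + b<w,mu_s>) / ||w||), the exact error under Gaussian noise."""
    margin = w @ mu_c + group * (w @ mu_s)
    return 0.5 * math.erfc(margin / (np.linalg.norm(w) * math.sqrt(2.0)))

rng = np.random.default_rng(0)
d, n, n_minus = 4000, 200, 10             # much smaller d than in the figures
n_plus = n - n_minus
r_plus = d ** 0.6 / 4
mu_c = np.zeros(d)
mu_c[0] = math.sqrt(r_plus / 2)           # core direction e_1 (assumed)
mu_s = np.zeros(d)
mu_s[1] = math.sqrt(r_plus / 2)           # spurious direction e_2 (assumed)

X, y, b = sample_gmm(n_plus, n_minus, mu_c, mu_s, rng)
w_hat = cmni(X, y, b, n_plus / n, n_minus / n)   # Delta_pm = n_pm / n
errors = {g: group_error(w_hat, mu_c, mu_s, g) for g in (+1, -1)}
print(errors)
```

Under the assumed convention, $\Delta_\pm = n_\pm/n$ gives minority points the larger target margin $n/n_-$, which is one way to read "upweighting" the minority group.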

Paper Structure (26 sections, 12 theorems, 112 equations, 2 figures, 2 tables)

Key Result

Theorem 1

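Under Assumption asm:dataset_gen, with probability at least $1-\delta$, the generalization error $\mathcal{R}_b(\hat{\boldsymbol{w}})$ of each group $b \in \{+1,-1\}$ satisfies an upper bound $\mathcal{R}_b(\hat{\boldsymbol{w}}) \le \ldots$, where we define $\alpha_\pm \coloneqq \frac{n_\pm/\Delta_\pm^2}{\frac{n_+}{\Delta_+^2} + \frac{n_-}{\Delta_-^2}}=\frac{n_\pm/\Delta_\pm^2}{n_{\Delta}}$ with $n_\Delta \coloneqq \frac{n_+}{\Delta_+^2} + \frac{n_-}{\Delta_-^2}$. By construction, $0 < \alpha_\pm < 1$ and $\alpha_+ + \alpha_- = 1$.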
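The weights $\alpha_\pm$ depend only on the group counts and the importance weights, so they are easy to tabulate. The sketch below (the function name `alpha_weights` is hypothetical) evaluates them for two special cases that follow directly from the definition: $\Delta_\pm = 1$ gives $\alpha_\pm = n_\pm/n$, while $\Delta_\pm = n_\pm/n$ gives $\alpha_\pm \propto n^2/n_\pm$ and hence $\alpha_\pm = n_\mp/n$. The counts $n=200$, $n_-=10$ mirror the figure settings.

```python
import math

def alpha_weights(n_plus, n_minus, delta_plus, delta_minus):
    """Return (alpha_+, alpha_-) with alpha_pm = (n_pm / Delta_pm^2) / n_Delta."""
    w_plus = n_plus / delta_plus ** 2
    w_minus = n_minus / delta_minus ** 2
    n_delta = w_plus + w_minus          # n_Delta = n_+/Delta_+^2 + n_-/Delta_-^2
    return w_plus / n_delta, w_minus / n_delta

n, n_minus = 200, 10
n_plus = n - n_minus

# Delta_pm = 1 gives alpha_pm = n_pm / n
alpha_equal = alpha_weights(n_plus, n_minus, 1.0, 1.0)                    # approx (0.95, 0.05)
# Delta_pm = n_pm / n gives alpha_pm = n_mp / n: the roles of the groups flip
alpha_balanced = alpha_weights(n_plus, n_minus, n_plus / n, n_minus / n)  # approx (0.05, 0.95)

for alphas in (alpha_equal, alpha_balanced):
    assert all(0.0 < a < 1.0 for a in alphas)   # 0 < alpha_pm < 1
    assert math.isclose(sum(alphas), 1.0)       # alpha_+ + alpha_- = 1
print(alpha_equal, alpha_balanced)
```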

Figures (2)

  • Figure 1: The left panel plots the group-wise error as a function of $\Delta_{-}$. We fix $d=10^5$, $n=200$, $n_-=10$, $R_+=d^{0.6}/4$, $\boldsymbol{\mu}_c=\boldsymbol{\mu}_s=\sqrt{R_+/2}\boldsymbol{e}_1$ (as in Behnia et al., 2022), $\Delta_+=\frac{n_+}{n}$, and decrease $\Delta_-$ from $\frac{n_+}{n}$ to $\frac{n_-}{n}$. The worst-group error decreases as $\Delta_-$ decreases until $\Delta_-=\frac{n_-}{n}$; below this value the majority group becomes the worst group, and its error increases as $\Delta_{-}$ decreases further. The right panel plots the worst-group error as a function of $n$, fixing $d=2n^2$, $n_-=0.04n$, $R_+=d^{0.6}/4$, $\boldsymbol{\mu}_c=\boldsymbol{\mu}_s=\sqrt{R_+/2}\boldsymbol{e}_1$, and $\Delta_\pm=n_\pm/n$. Increasing the ridge regularization parameter $\tau$ improves the worst-group error, but only by a constant factor in the error exponent. These simulations were obtained by averaging over $10$ trials; error bars are too small to be visible.
  • Figure 2: The left panel plots the majority-group and minority-group errors as a function of the squared total signal strength ($R_{+}^2$) when the importance weights are equal, i.e. $\Delta_\pm=1$. While $\mathcal{R}_{+1}(\hat{\boldsymbol{w}})$ approaches $0$ once $R_+^2=\frac{d}{n}$, $\mathcal{R}_{-1}(\hat{\boldsymbol{w}})$ remains above $0.1$ until $R_+^2=\frac{dn}{n_-^2}$. The right panel plots the group-wise errors for the alternative choice $\Delta_\pm=\frac{n_\pm}{n}$; in this case, as expected, both $\mathcal{R}_{+1}(\hat{\boldsymbol{w}})$ and $\mathcal{R}_{-1}(\hat{\boldsymbol{w}})$ decay to $0$ at similar rates. For both plots, we fix $d=10^5$, $n=200$, $n_-=10$, and $\boldsymbol{\mu}_c=\boldsymbol{\mu}_s=\sqrt{R_+/2}\boldsymbol{e}_1$. These simulations were obtained by averaging over $10$ trials; error bars are too small to be visible.
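The Figure 1 experiment can be approximated at a much smaller scale. The sketch below assumes that the ridge-regularized estimator fits the same targets $t_i = y_i/\Delta_{b_i}$ assumed for cMNI, i.e. $\hat{\boldsymbol{w}}_\tau = \boldsymbol{X}^\top(\boldsymbol{X}\boldsymbol{X}^\top + \tau\boldsymbol{I}_n)^{-1}\boldsymbol{t}$, which reduces to the minimum-norm interpolator at $\tau = 0$. It also places $\boldsymbol{\mu}_c$ and $\boldsymbol{\mu}_s$ on orthogonal coordinates and uses a single draw instead of 10 trials. Function names, $d$, and the $\tau$ grid are hypothetical choices, so the printed numbers are illustrative and are not the paper's results.

```python
import math
import numpy as np

def sample_gmm(n_plus, n_minus, mu_c, mu_s, rng):
    """Assumed model: x = y*mu_c + (b*y)*mu_s + z, z ~ N(0, I_d); b=+1 is the majority."""
    b = np.concatenate([np.ones(n_plus), -np.ones(n_minus)])
    y = rng.choice([-1.0, 1.0], size=b.size)
    Z = rng.standard_normal((b.size, mu_c.size))
    return np.outer(y, mu_c) + np.outer(b * y, mu_s) + Z, y, b

def ridge_cmni(X, y, b, delta_plus, delta_minus, tau):
    """Ridge fit to targets y_i / Delta_{b_i}; tau = 0 gives the min-norm interpolator."""
    targets = y / np.where(b > 0, delta_plus, delta_minus)
    gram = X @ X.T + tau * np.eye(X.shape[0])
    return X.T @ np.linalg.solve(gram, targets)

def worst_group_error(w, mu_c, mu_s):
    """Max over b in {+1, -1} of Phi(-(<w,mu_c> + b<w,mu_s>) / ||w||)."""
    scale = np.linalg.norm(w) * math.sqrt(2.0)
    return max(0.5 * math.erfc((w @ mu_c + g * (w @ mu_s)) / scale) for g in (1, -1))

rng = np.random.default_rng(1)
d, n, n_minus = 4000, 200, 10
n_plus = n - n_minus
r_plus = d ** 0.6 / 4
mu_c = np.zeros(d)
mu_c[0] = math.sqrt(r_plus / 2)   # core direction (assumed e_1)
mu_s = np.zeros(d)
mu_s[1] = math.sqrt(r_plus / 2)   # spurious direction (assumed e_2)
X, y, b = sample_gmm(n_plus, n_minus, mu_c, mu_s, rng)

results = {}
for tau in (0.0, float(d)):
    # Sweep Delta_- from n_+/n down to n_-/n with Delta_+ = n_+/n held fixed
    for delta_minus in np.linspace(n_plus / n, n_minus / n, 5):
        w = ridge_cmni(X, y, b, n_plus / n, delta_minus, tau)
        err = worst_group_error(w, mu_c, mu_s)
        results[(tau, float(delta_minus))] = err
        print(f"tau={tau:g}  Delta_-={delta_minus:.3f}  worst-group error={err:.3f}")
```

Averaging over several seeds, as the paper does with 10 trials, would make comparisons across $\tau$ and $\Delta_-$ less noisy.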

Theorems & Definitions (17)

  • Theorem 1
  • Proposition 1
  • Corollary 1
  • Lemma 1
  • Definition 1
  • Definition 2
  • Lemma 2
  • Lemma 3
  • Lemma 4
  • Lemma 5 (Muthukumar et al., 2021, Lemma 21)
  • ...and 7 more