Provable Weak-to-Strong Generalization via Benign Overfitting

David X. Wu, Anant Sahai

TL;DR

The paper analyzes weak-to-strong generalization in a stylized, overparameterized Gaussian setting where a weak teacher provides imperfect pseudolabels to train a strong student using minimum $\ell_2$-norm interpolation. It introduces a bi-level, overparameterized covariance model with a subset-relationship between weak and strong features and proves that the strong learner undergoes a sharp asymptotic transition between random guessing and perfect generalization as the amount of weakly labeled data grows, under precise regime conditions. A key technical contribution is a tight lower-tail bound for the maximum of correlated Gaussians, needed to characterize the misclassification probability, along with an extension to multilabel settings via a multilabel-softening approach using logits. The results illuminate when weak supervision can purify representations and drive high generalization in an otherwise benign-overfitting regime, and they connect to practical finetuning and NTK perspectives by focusing on linearized, kernel-like feature maps. Overall, the work provides provable insights into the conditions under which weak-to-strong training succeeds and identifies clear regimes where it fails, with implications for pseudolabeling and knowledge distillation in high-dimensional settings.

Abstract

The classic teacher-student model in machine learning posits that a strong teacher supervises a weak student to improve the student's capabilities. We instead consider the inverted situation, where a weak teacher supervises a strong student with imperfect pseudolabels. This paradigm was recently brought forth by Burns et al.'23 and termed \emph{weak-to-strong generalization}. We theoretically investigate weak-to-strong generalization for binary and multilabel classification in a stylized overparameterized spiked covariance model with Gaussian covariates where the weak teacher's pseudolabels are asymptotically like random guessing. Under these assumptions, we provably identify two asymptotic phases of the strong student's generalization after weak supervision: (1) successful generalization and (2) random guessing. Our techniques should eventually extend to weak-to-strong multiclass classification. Towards doing so, we prove a tight lower tail inequality for the maximum of correlated Gaussians, which may be of independent interest. Understanding the multilabel setting reinforces the value of using logits for weak supervision when they are available.
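The pipeline described above (a weak teacher fit on a few clean labels, then a strong student fit by minimum $\ell_2$-norm interpolation on the teacher's pseudolabels) can be sketched numerically. This is a minimal illustration under stated assumptions, not the authors' experimental code: the bi-level parametrization in `bilevel_scales` ($d = \lceil n^p \rceil$ features, $s = \lfloor n^r \rfloor$ favored features carrying a fraction $n^{-q}$ of the variance) follows a common convention in the benign-overfitting literature and is not spelled out in this summary; the weak features are taken to be an arbitrary coordinate subset containing the label direction; labels are the sign of the first feature; and hard (sign) pseudolabels are used. All function names are hypothetical.

```python
import numpy as np

def bilevel_scales(n, p, q, r):
    """Feature standard deviations for an ASSUMED bi-level ensemble.

    Assumption (not stated in this summary): d = ceil(n^p) features, of which
    s = floor(n^r) are favored and carry a fraction a = n^(-q) of the total
    variance d; the rest is spread evenly over the other d - s features.
    """
    d = int(np.ceil(n ** p))
    s = max(1, int(np.floor(n ** r)))
    a = n ** (-q)
    variances = np.full(d, (1.0 - a) * d / (d - s))
    variances[:s] = a * d / s
    return np.sqrt(variances)

def mni(X, y):
    """Minimum l2-norm interpolator: the smallest-norm w with X @ w = y."""
    return np.linalg.pinv(X) @ y

def accuracy(w, X, y):
    """Fraction of points whose predicted sign matches the label."""
    return float(np.mean(np.sign(X @ w) == y))

def weak_to_strong_demo(n=50, m=200, p=1.5, q=0.6, r=0.3, n_test=2000, seed=0):
    rng = np.random.default_rng(seed)
    scales = bilevel_scales(n, p, q, r)
    d = scales.size
    # Hypothetical weak feature set: every other coordinate, keeping feature 0.
    weak_idx = np.arange(0, d, 2)

    def sample(k):
        X = rng.standard_normal((k, d)) * scales
        return X, np.sign(X[:, 0])  # assumed 1-sparse label direction

    X_lab, y_lab = sample(n)      # n points with clean labels
    X_unl, _ = sample(m)          # m points the weak teacher will label
    X_te, y_te = sample(n_test)   # held-out test set with clean labels

    # Weak teacher: MNI on the weak features using the clean labels.
    w_weak = mni(X_lab[:, weak_idx], y_lab)
    # Hard pseudolabels from the weak teacher on the additional data.
    pseudo = np.sign(X_unl[:, weak_idx] @ w_weak)
    # Strong student: MNI on all features using only the pseudolabels.
    w_w2s = mni(X_unl, pseudo)

    return {
        "weak": accuracy(w_weak, X_te[:, weak_idx], y_te),
        "weak_to_strong": accuracy(w_w2s, X_te, y_te),
    }
```

Calling `weak_to_strong_demo()` returns the test accuracy of the weak teacher and of the weak-to-strong student for one random draw; sweeping `m` and the parameters `(p, q, r)` is one way to explore the two phases the abstract describes, though the outcome at small `n` need not match the asymptotic predictions.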

Paper Structure (31 sections, 21 theorems, 129 equations, 4 figures)

Key Result

Theorem 3.1

Suppose the strong features have bi-level covariance $\Sigma = \Sigma(p, q, r)$ (def:bilevel), where $q+r>1$, the true multiclass labels are $1$-sparse (assump:1-sparse), and the number of classes $k = \lfloor n^t \rfloor$ (def:scaling-k). Then the test error for ${\bm{f}}_{\mathsf{strong}}$ MNI-tra where $\tau_{\mathsf{strong}} \triangleq p+1 - 2(q+r)$. Furthermore, for binary classification (i.e
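The threshold $\tau_{\mathsf{strong}}$ defined in the statement above is a simple function of the bi-level parameters. As a worked instance with illustrative values chosen here (not taken from the paper), $p = 1.5$, $q = 0.6$, $r = 0.5$ satisfy the hypothesis $q + r > 1$ and give:

```latex
\tau_{\mathsf{strong}} = p + 1 - 2(q + r) = 1.5 + 1 - 2(0.6 + 0.5) = 2.5 - 2.2 = 0.3
```

How this value enters the test-error statement is cut off in the excerpt above, so it is not interpreted further here.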

Figures (4)

  • Figure 1: Visualization of the subset ensemble (\ref{assump:w2s-ensemble}) relating weak and strong features. Notice that the weak features are less strongly favored, and that within each category the weak features are a subset of the strong features. Hence, a linear model on the strong features can simulate one on the weak features (\ref{item:capability}). Moreover, the label-defining directions, represented by the green shaded box, lie in the span of both the weak and strong features.
  • Figure 2: (Top): Regime plots for weak-to-strong generalization based on \ref{thm:achieve}. The blue region is the successful w2s regime, and the red region is where w2s training fails. The white region corresponds to regimes where either the hypotheses of the theorem fail to hold or the parameter settings are invalid for the bi-level ensemble. (Bottom): Comparison of simulated MNI test accuracies for two different regimes with $n=50$. Observe that the weak accuracy is close to random guessing and that the weak-to-strong accuracy increases as $m$ increases. As corroborated by the plots, \ref{thm:achieve} predicts w2s success in \ref{fig:w2s-mni-1-main} and failure in \ref{fig:w2s-mni-1-fail-main}.
  • Figure 3: Comparison of test accuracies for four different models using averaging training. The $x$-axis shows $m$, the number of additional labeled datapoints. The models are trained using class averaging, which approximates the behavior of the first few gradient descent iterations. Note that the weak model has low accuracy, whereas the weak-to-strong model and ground truth have higher accuracies that increase as $m$ increases. The top row (\ref{fig:w2s-avg-1} and \ref{fig:w2s-avg-2}) is in a regime where we predict MNI weak-to-strong generalization to succeed, whereas the bottom row (\ref{fig:w2s-avg-1-fail} and \ref{fig:w2s-avg-2-fail}) depicts regimes where we expect MNI weak-to-strong generalization to fail.
  • Figure 4: Comparison of MNI test accuracies for four different models. Observe that the weak-to-strong accuracy increases as $m$ increases. Again, the top row (\ref{fig:w2s-mni-1} and \ref{fig:w2s-mni-2}) is in a regime where we predict MNI weak-to-strong generalization to succeed, whereas the bottom row (\ref{fig:w2s-mni-1-fail} and \ref{fig:w2s-mni-2-fail}) depicts regimes where we expect MNI weak-to-strong generalization to fail. The plots corroborate these theoretical predictions.

Theorems & Definitions (47)

  • Remark 2.1
  • Definition 1: Bi-level ensemble
  • Definition 2: Scaling for multiclass
  • Theorem 3.1: Regimes with clean labels
  • Theorem 3.2: Weak-to-strong generalization for subset ensemble
  • Theorem 3.3: Main result
  • Remark 3.4: Bonus desiderata
  • Theorem 3.5: Informal, see \ref{thm:weak-to-strong-multilabel}
  • Theorem 3.6: Lower tail for correlated Gaussians
  • Definition 3: Survival and contamination
  • ...and 37 more