
Improved Scaling Laws via Weak-to-Strong Generalization in Random Feature Ridge Regression

Diyuan Wu, Lehan Chen, Theodor Misiakiewicz, Marco Mondelli

TL;DR

This paper derives a deterministic equivalent for the excess test error of the student trained on labels obtained via the teacher and identifies regimes in which the scaling law of the student improves upon that of the teacher, unveiling that the improvement can be achieved both in bias-dominated and variance-dominated settings.

Abstract

It is increasingly common in machine learning to use learned models to label data and then employ such data to train more capable models. The phenomenon of weak-to-strong generalization exemplifies the advantage of this two-stage procedure: a strong student is trained on imperfect labels obtained from a weak teacher, and yet the strong student outperforms the weak teacher. In this paper, we show that the potential improvement is substantial, in the sense that it affects the scaling law followed by the test error. Specifically, we consider students and teachers trained via random feature ridge regression (RFRR). Our main technical contribution is to derive a deterministic equivalent for the excess test error of the student trained on labels obtained via the teacher. Via this deterministic equivalent, we then identify regimes in which the scaling law of the student improves upon that of the teacher, unveiling that the improvement can be achieved both in bias-dominated and variance-dominated settings. Strikingly, the student may attain the minimax optimal rate regardless of the scaling law of the teacher; in fact, it may do so even when the test error of the teacher does not decay with the sample size.
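
The two-stage procedure described above can be sketched in Python with NumPy as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: the tanh activation, the 1/sqrt(p) feature scaling, the ridge objective normalization, the independent draws of teacher and student features, the use of fresh inputs for the student, and all function names (sample_sphere, rf_features, fit_rfrr, weak_to_strong) are hypothetical choices made here.

```python
import numpy as np


def sample_sphere(m, d, rng):
    """Draw m i.i.d. points uniformly from the unit sphere S^{d-1}."""
    g = rng.standard_normal((m, d))
    return g / np.linalg.norm(g, axis=1, keepdims=True)


def rf_features(X, W, act=np.tanh):
    """Feature map act(<x, w_j>) / sqrt(p); the 1/sqrt(p) scaling is an assumed convention."""
    return act(X @ W.T) / np.sqrt(W.shape[0])


def fit_rfrr(X, y, W, lam, act=np.tanh):
    """Random feature ridge regression with frozen first-layer weights W.

    Solves min_a (1/n)||y - Z a||^2 + lam ||a||^2 (objective scaling assumed),
    i.e. (Z^T Z + n lam I) a = Z^T y, and returns the predictor x -> rf_features(x) @ a.
    """
    Z = rf_features(X, W, act)
    n, p = Z.shape
    a = np.linalg.solve(Z.T @ Z + n * lam * np.eye(p), Z.T @ y)
    return lambda X_new: rf_features(X_new, W, act) @ a


def weak_to_strong(X_t, y_t, X_s, p_t, p_s, lam_t, lam_s, rng, act=np.tanh):
    """Two-stage procedure: teacher on ground truth, student on teacher labels."""
    d = X_t.shape[1]
    # Stage 1: the teacher sees n_t ground-truth labels and uses p_t random features.
    teacher = fit_rfrr(X_t, y_t, sample_sphere(p_t, d, rng), lam_t, act)
    # Stage 2: the teacher labels the student's inputs (imperfect labels) ...
    y_s = teacher(X_s)
    # ... and the student fits them with its own, independently drawn p_s features.
    student = fit_rfrr(X_s, y_s, sample_sphere(p_s, d, rng), lam_s, act)
    return teacher, student
```

To compare the two models, one would evaluate teacher and student on held-out inputs against the noise-free target and vary $(n_\mathsf{s}, p_\mathsf{s}, \lambda_\mathsf{s})$; whether the student improves on the teacher depends on the regime, which is what the paper's deterministic equivalent characterizes.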

Paper Structure

This paper contains 60 sections, 16 theorems, 360 equations, and 2 figures.

Key Result

Theorem 1

(\cite{defilippis2024dimension}) There exist absolute constants $C_0,C_1>0$ such that the following holds. Under Assumptions \ref{ass:concentration-features} and \ref{ass:tech-assumption}, for any $D,K>0$, there exist constants $\eta_*\in(0,1/2)$ and $C_*>0$, depending only on $D,K$ and the constants in the assumptions, such that with probability at least $1 - \min(n_\mathsf{t}, p_\mathsf{t})^{-D}$ …, where the approximation …

Figures (2)

  • Figure 1: Excess test errors of various students trained on teacher labels, together with the corresponding deterministic equivalents (Theorem \ref{thm:additive_err}), as a function of the number of student features $p_\mathsf{s}$. The teacher and the student are random feature models with the same number of features ($p_\mathsf{t}=p_\mathsf{s}$), trained with the same sample size ($n_\mathsf{t}=n_\mathsf{s}$), and the teacher regularization is half that of the student ($\lambda_\mathsf{t} = \lambda_\mathsf{s}/2$). In the top two plots, the data is given by ${\boldsymbol x}_i \sim \mathcal{N}(0,I_d)$ and $y_i = \text{erf}(\langle {\boldsymbol x}_i, {\boldsymbol w}_*\rangle)+ \varepsilon_i$, with ${\boldsymbol w}_* \sim {\rm Unif} (\mathbb{S}^{d-1})$, and the random feature model is $\varphi({\boldsymbol x};{\boldsymbol w})=\tanh(\langle {\boldsymbol x}, {\boldsymbol w}\rangle)$, ${\boldsymbol w} \sim {\rm Unif} (\mathbb{S}^{d-1})$; in the bottom two plots, the data is obtained from the MNIST dataset and the random feature model is $\varphi({\boldsymbol x};{\boldsymbol w})=\text{erf}(\langle {\boldsymbol x}, {\boldsymbol w}\rangle)$, ${\boldsymbol w} \sim {\rm Unif} (\mathbb{S}^{d-1})$. In the left plots, we fix the sample size $n_\mathsf{s}$ and consider three values of the regularization $\lambda_\mathsf{s}$ (one color each); in the right plots, we fix the regularization $\lambda_\mathsf{s}$ and consider three values of the sample size $n_\mathsf{s}$ (one color each). We run 10 independent experiments and report the average together with a confidence interval of $1$ standard deviation. The circles represent the experimental test errors; they closely match the continuous curves, which represent the deterministic equivalents of Definition \ref{def:st}.
  • Figure 2: Excess test errors of various students and teachers, as a function of the number of ground-truth samples $n_\mathsf{t}$. We consider random feature ridge regression with weak-to-strong training, under the source and capacity conditions in \ref{eq:source-capa} and with the hyperparameter scaling in \ref{eq:relative-scale}. Each plot corresponds to a different choice of $(\alpha,r,\gamma_{p_\mathsf{t}},\gamma_{\lambda_\mathsf{t}})$. Points are empirical experiments, solid lines are deterministic equivalent predictions (Theorems \ref{thm:teacher-error} and \ref{thm:additive_err}), dotted lines are theoretical decay rates (Theorems \ref{thm:scaling-teacher} and \ref{thm:scaling-st}), and the grey dashed line is the minimax rate. The teacher test error is in blue, while the other colors correspond to student test errors for different choices of $(\gamma_{n_\mathsf{s}},\gamma_{p_\mathsf{s}}, \gamma_{\lambda_\mathsf{s}})$.
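
The synthetic setting of Figure 1 (top plots) can be mimicked with the following sketch. It follows the caption: Gaussian inputs, an erf single-index target with ${\boldsymbol w}_*$ uniform on the sphere, tanh random features with uniform-sphere weights, $p_\mathsf{t}=p_\mathsf{s}$, $n_\mathsf{t}=n_\mathsf{s}$, $\lambda_\mathsf{t}=\lambda_\mathsf{s}/2$, and 10 repetitions summarized by mean and standard deviation. The dimension d, noise level sigma, feature normalization, ridge scaling, the use of fresh student inputs, resampling ${\boldsymbol w}_*$ in every run, and the helper names (figure1_point, fit_rfrr, tanh_features) are assumptions of this sketch, not the paper's code. It reports only empirical errors and does not compute the deterministic equivalents of Definition \ref{def:st}.

```python
import math

import numpy as np

# Elementwise erf via the standard library (avoids a SciPy dependency).
erf = np.vectorize(math.erf, otypes=[float])


def sample_sphere(m, d, rng):
    """Draw m i.i.d. points uniformly from the unit sphere S^{d-1}."""
    g = rng.standard_normal((m, d))
    return g / np.linalg.norm(g, axis=1, keepdims=True)


def tanh_features(X, W):
    """tanh(<x, w_j>) / sqrt(p); the 1/sqrt(p) scaling is an assumed convention."""
    return np.tanh(X @ W.T) / np.sqrt(W.shape[0])


def fit_rfrr(X, y, W, lam):
    """Ridge on random features: min_a (1/n)||y - Z a||^2 + lam ||a||^2 (assumed scaling)."""
    Z = tanh_features(X, W)
    n, p = Z.shape
    a = np.linalg.solve(Z.T @ Z + n * lam * np.eye(p), Z.T @ y)
    return lambda X_new: tanh_features(X_new, W) @ a


def figure1_point(d, n, p, lam_s, sigma, runs, n_test, rng):
    """Mean and std of the student's excess test error over `runs` repetitions,
    with p_t = p_s = p, n_t = n_s = n and lam_t = lam_s / 2 as in Figure 1."""
    errs = []
    for _ in range(runs):
        w_star = sample_sphere(1, d, rng)[0]   # target direction ~ Unif(S^{d-1})
        X_t = rng.standard_normal((n, d))      # x_i ~ N(0, I_d)
        y_t = erf(X_t @ w_star) + sigma * rng.standard_normal(n)
        teacher = fit_rfrr(X_t, y_t, sample_sphere(p, d, rng), lam_s / 2)
        X_s = rng.standard_normal((n, d))      # fresh student inputs (assumption)
        student = fit_rfrr(X_s, teacher(X_s), sample_sphere(p, d, rng), lam_s)
        X_te = rng.standard_normal((n_test, d))
        # Excess error: mean squared distance to the noise-free target erf(<x, w_*>).
        errs.append(np.mean((student(X_te) - erf(X_te @ w_star)) ** 2))
    return float(np.mean(errs)), float(np.std(errs))


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    # Hypothetical sweep over p_s = p_t with n_s and lam_s held fixed (left panels).
    for p in (25, 50, 100, 200):
        mean, std = figure1_point(d=20, n=100, p=p, lam_s=1e-3, sigma=0.1,
                                  runs=10, n_test=2000, rng=rng)
        print(f"p_s = {p:4d}: excess test error {mean:.4f} +/- {std:.4f}")
```

Sweeping $n_\mathsf{s}$ instead of $\lambda_\mathsf{s}$ (right panels) only changes which argument is varied in the loop.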

Theorems & Definitions (32)

  • Definition 1: Teacher deterministic equivalent
  • Theorem 1
  • Definition 2: Student deterministic equivalent
  • Theorem 2
  • Theorem 3
  • Theorem 4: Student scaling laws
  • Corollary 1
  • Corollary 2
  • Corollary 3
  • Definition 3: Intrinsic dimension
  • ...and 22 more