
Residual-as-Teacher: Mitigating Bias Propagation in Student--Teacher Estimation

Kakei Yamamoto, Martin J. Wainwright

Abstract

We study statistical estimation in a student--teacher setting, where predictions from a pre-trained teacher are used to guide a student model. A standard approach is to train the student to directly match the teacher's outputs, which we refer to as student soft matching (SM). This approach directly propagates any systematic bias or mis-specification present in the teacher, thereby degrading the student's predictions. We propose and analyze an alternative scheme, which we call residual-as-teacher (RaT), in which the teacher is used to estimate residuals in the student's predictions. Our analysis shows how the student can thereby emulate a proximal gradient scheme for solving an oracle optimization problem, which provably reduces the effect of teacher bias. For general student--teacher pairs, we establish non-asymptotic excess risk bounds for any RaT fixed point, along with convergence guarantees for the student--teacher iterative scheme. For kernel-based student--teacher pairs, we prove a sharp separation: the RaT method achieves the minimax-optimal rate, while the SM method incurs constant prediction error for any sample size. Experiments on both synthetic data and ImageNette classification under covariate shift corroborate our theoretical findings.
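
To make the distinction concrete, the following is a minimal sketch (our own illustration, not the paper's code) of the two schemes described above, assuming squared-error regression and generic scikit-learn-style estimators; `residual_as_teacher`, the stepsize `eta`, and the iteration count are hypothetical choices standing in for the proximal-gradient-style update analyzed in the paper.

```python
# Illustrative sketch: SM fits the student to the teacher's (possibly biased)
# outputs, whereas RaT repeatedly asks a fresh teacher to explain the student's
# residuals and applies a damped additive correction.
import numpy as np
from sklearn.kernel_ridge import KernelRidge


def soft_matching(student, teacher_predict, X):
    """SM: train the student to reproduce the teacher's outputs directly."""
    student.fit(X, teacher_predict(X))
    return student


def residual_as_teacher(make_teacher, student, X, y, eta=0.5, n_iters=20):
    """RaT (sketch): the teacher models the residual y - student(x); the
    student's fitted values receive a damped correction of size eta."""
    student.fit(X, y)                                  # initial student fit on the labels
    fitted = student.predict(X)
    for _ in range(n_iters):
        teacher = make_teacher()
        teacher.fit(X, y - fitted)                     # teacher estimates current residuals
        fitted = fitted + eta * teacher.predict(X)     # proximal-style correction
        student.fit(X, fitted)                         # project back onto the student class
        fitted = student.predict(X)
    return student


# Toy usage with a kernel student/teacher pair (bandwidths chosen arbitrarily).
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 1))
y = np.sin(3.0 * X[:, 0]) + 0.1 * rng.normal(size=200)
student = residual_as_teacher(
    make_teacher=lambda: KernelRidge(kernel="laplacian", alpha=1e-2),
    student=KernelRidge(kernel="rbf", alpha=1e-2),
    X=X, y=y,
)
```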

Paper Structure

This paper contains 81 sections, 7 theorems, 178 equations, 8 figures.

Key Result

Theorem 1

Suppose that the oracle risk function $\bar{R}(f) = \tfrac{1}{m} \{ \bar{L}_m(f) + \operatorname{Pen}(f) \}$ is convex in the fitted values, and let $f^\dagger$ be a minimizer. Then for any RaT fixed point $\widehat{f}_{\texttt{RaT}}$:
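
As a concrete illustration (our own assumption; the excerpt above does not fix the loss or penalty), one standard convex instance takes $\bar{L}_m$ to be the squared-error loss over the $m$ samples and $\operatorname{Pen}$ a ridge-type penalty,

$$\bar{R}(f) \;=\; \frac{1}{m}\Bigl\{ \sum_{i=1}^{m}\bigl(y_i - f(x_i)\bigr)^2 \;+\; \lambda\,\|f\|_{\mathscr{F}}^{2} \Bigr\},$$

which, for a kernel-based student class $\mathscr{F}$, reduces via the representer theorem to a convex quadratic in the fitted values $(f(x_1),\dots,f(x_m))$, so the convexity hypothesis of the theorem holds.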

Figures (8)

  • Figure 1: Comparison of the RaT and SM estimates when the student class $\mathscr{F}$ is a two-layer neural network with $128$ hidden units. Shown are results for four different covariate shifts (rows 1--4), and three different teacher models (columns 2--4). Each row corresponds to a different source–target distribution pair, shown in the leftmost column via their marginal densities. The remaining columns report results for three teacher classes: Boosting, Kernel ridge regression (KRR), and ReLU neural network fits. In each panel, the true regression function is shown as a dotted orange line, while the estimates obtained by SM and RaT are shown in red and blue, respectively. Gray points indicate source samples. Intermediate RaT iterates are shown as faint curves to illustrate the refinement process.
  • Figure 2: Convergence behavior of the Picard iteration when using a neural network teacher to construct the gradient estimate $\widehat{\mathcal{G}}$, for a multinomial classifier over $K = 10$ classes (see the ImageNette classification experiments for more details). Each plot shows the operator defect norm $\|\widehat{\mathcal{D}}_\eta(f^{k})\|$ versus the iteration number $k$. The three curves correspond to three-layer neural-net teachers with three different architectures; the hidden-unit counts $(h_1, h_2)$ are marked in the labels. Solid lines show the mean defect over $T = 100$ random trials, with shaded areas showing 95% confidence intervals. Stepsizes $\eta$ are marked below each plot.
  • Figure 3: Mean-squared errors of RaT and SM under Gaussian covariate shift using a Hermite feature model, target distribution $\mathbb{Q}=N(0,1)$, and source distributions $\mathbb{P}=N(0,\sigma_P^2)$ with $\sigma_P\in\{0.9,1.0,1.1\}$. Solid curves show the median MSE with interquartile bands over repetitions, and dashed lines indicate the predicted scaling laws. Consistent with the separation result (Theorem 2), RaT achieves the rate $n^{-2\beta/(2\alpha+1)}$, whereas SM exhibits a non-vanishing error floor due to the distribution shift.
  • Figure 4: Mean-squared errors of the RaT and SM estimators for Laplace kernel ridge regression, using target distribution $\mathbb{Q}=\mathrm{Beta}(1,1)$ and source distributions $\mathbb{P}=\mathrm{Beta}(\alpha,1)$ with $\alpha\in\{0.2, 1.0, 2.0\}$; a sketch of this covariate-shift evaluation protocol appears after this figure list. The teacher and student use Laplace kernels with differing bandwidths, so that the teacher is mis-specified relative to the student. Solid curves show the median MSE with interquartile bands over repeated trials. Each curve is normalized by its value at the smallest sample size. The dashed line represents a least-squares linear fit on a log--log scale to the RaT curve.
  • Figure 5: Mean-squared errors (MSE) of RaT and SM estimators for neural network students with a regression-stump teacher, and Beta covariate shift. (a) Target test MSE versus source sample size on a log--log scale. Solid lines show the median over repeated runs, and shaded regions indicate the interquartile range. The dashed line is a least-squares linear fit in log--log coordinates to the RaT median curve. (b) Function estimates for a single run, with the ground-truth function (black dashed), source samples (gray points), SM teacher (dark red dotted), and final function estimates (solid blue for RaT, solid red for SM). Faint blue curves correspond to intermediate RaT iterates. The reported MSE values are computed on held-out target data.
  • ...and 3 more figures
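
The Beta covariate-shift protocol referenced in the Figure 4 entry can be sketched as follows (a hypothetical reconstruction from the captions; the helper name `target_mse`, the noise level, and the sample sizes are illustrative assumptions, and `fit_fn` stands in for either the SM or RaT estimator):

```python
# Sketch of the covariate-shift evaluation: train on source draws X ~ Beta(alpha, 1),
# then report mean-squared error against the true regression function on
# held-out target draws X ~ Beta(1, 1), as in Figure 4.
import numpy as np


def target_mse(fit_fn, f_star, alpha_src=0.2, n=500, n_test=2000, seed=0):
    rng = np.random.default_rng(seed)
    X_src = rng.beta(alpha_src, 1.0, size=(n, 1))             # source covariates, P
    y_src = f_star(X_src[:, 0]) + 0.1 * rng.normal(size=n)    # noisy source labels
    X_tgt = rng.beta(1.0, 1.0, size=(n_test, 1))              # target covariates, Q
    model = fit_fn(X_src, y_src)                              # SM or RaT estimator
    return np.mean((model.predict(X_tgt) - f_star(X_tgt[:, 0])) ** 2)
```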

Theorems & Definitions (7)

  • Theorem 1
  • Proposition 1
  • Corollary 1
  • Proposition 2: Exact MSEs for SM/RaT
  • Theorem 2: Separation result for SM/RaT, with a matching lower bound
  • Theorem 3: Computational guarantees for RaT
  • Lemma 1: RaT bias and variance bounds