
On student-teacher deviations in distillation: does it pay to disobey?

Vaishnavh Nagarajan, Aditya Krishna Menon, Srinadh Bhojanapalli, Hossein Mobahi, Sanjiv Kumar

TL;DR

The paper investigates why knowledge distillation can produce a student that deviates from the teacher yet generalizes better. It identifies two core phenomena (exaggerated confidence in the student's predictions, and an exaggerated implicit bias of gradient descent toward the top data eigendirections) and shows that these arise from a regularization effect of distillation. A formal result in a linear gradient-descent setting demonstrates bias amplification toward the top eigendirections, and extensive neural-network experiments with cross-entropy loss validate the phenomenon and its link to improved generalization, while also highlighting conditions under which distillation can hurt. The work provides a cohesive theory connecting gradient-descent dynamics to practical distillation behavior, offering guidance on loss-switching and on when deliberate deviations may be advantageous.

Abstract

Knowledge distillation (KD) has been widely used to improve the test accuracy of a "student" network, by training it to mimic the soft probabilities of a trained "teacher" network. Yet, it has been shown in recent work that, despite being trained to fit the teacher's probabilities, the student may not only significantly deviate from the teacher probabilities, but may also outperform the teacher. Our work aims to reconcile this seemingly paradoxical observation. Specifically, we characterize the precise nature of the student-teacher deviations, and argue how they can co-occur with better generalization. First, through experiments on image and language data, we identify that these probability deviations correspond to the student systematically exaggerating the confidence levels of the teacher. Next, we theoretically and empirically establish another form of exaggeration in some simple settings: KD exaggerates the implicit bias of gradient descent in converging faster along the top eigendirections of the data. Finally, we tie these two observations together: we demonstrate that the exaggerated bias of KD can simultaneously result in both (a) the exaggeration of confidence and (b) the improved generalization of the student, thus offering a resolution to the apparent paradox. Our analysis brings existing theory and practice closer by considering the role of gradient descent in KD and by demonstrating the exaggerated bias effect in both theoretical and empirical settings.
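As a minimal sketch of the KD objective described above (training the student to fit the teacher's soft probabilities), the NumPy code below computes the cross-entropy of the student's softmax against the teacher's probabilities. It assumes a temperature of 1 and no mixing with the one-hot loss. The `switch_step` argument is a hypothetical illustration of the loss-switching idea mentioned in the TL;DR (switching to the one-hot loss partway through training), not the authors' implementation.

```python
import numpy as np


def log_softmax(logits):
    """Row-wise log-softmax, shifted by the row max for numerical stability."""
    z = logits - logits.max(axis=1, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=1, keepdims=True))


def soft_cross_entropy(student_logits, target_probs):
    """Batch mean of -sum_c target[c] * log(student_prob[c])."""
    return -(target_probs * log_softmax(student_logits)).sum(axis=1).mean()


def training_loss(student_logits, teacher_probs, labels, step, switch_step=None):
    """Distillation loss against the teacher's probabilities; if the hypothetical
    `switch_step` is set, use the one-hot loss on `labels` from that step onward."""
    if switch_step is not None and step >= switch_step:
        one_hot = np.eye(student_logits.shape[1])[labels]
        return soft_cross_entropy(student_logits, one_hot)
    return soft_cross_entropy(student_logits, teacher_probs)


# Usage: the distillation loss is smallest when the student reproduces the teacher.
teacher_probs = np.array([[0.7, 0.2, 0.1]])
labels = np.array([0])
matched = training_loss(np.log(teacher_probs), teacher_probs, labels, step=0)
mismatched = training_loss(np.array([[2.0, 0.0, 0.0]]), teacher_probs, labels, step=0)
print(matched < mismatched)  # True
```

With matched logits the loss equals the entropy of the teacher's distribution, which is the minimum of this cross-entropy over student outputs.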

Paper Structure (29 sections, 2 theorems, 19 equations, 26 figures, 7 tables)

Key Result

Theorem 4.1

(informal; see §app:eigenspace for the full version and proof) Let $\beta_k(t)$ and $\tilde{\beta}_k(t)$ denote the components of the teacher and student weights, respectively, along the $k$-th eigenvector of the Gram matrix $\mathbf{X}\mathbf{X}^\top$ at time $t$. Let $k_1 < k_2$ be two indices fo…
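The eigendirection picture behind Theorem 4.1 can be illustrated, for the teacher side only, with plain gradient descent on least squares. The sketch below is ours, not the paper's setup: it assumes the columns of `X` are the $n$ training inputs (so the weights live in the same space as the eigenvectors of $\mathbf{X}\mathbf{X}^\top$), the loss $\frac{1}{2}\|\mathbf{X}^\top w - y\|^2$, zero initialization, and a step size $\eta$ of at most $1/\lambda_{\max}$. Under these assumptions each component evolves independently as $\beta_k(t) = \beta_k^\star\,(1 - (1 - \eta\lambda_k)^t)$ with $\beta_k^\star = u_k^\top \mathbf{X} y / \lambda_k$, so directions with larger eigenvalues $\lambda_k$ approach their final values sooner. The code does not model the student and does not reproduce the theorem's distillation result.

```python
import numpy as np


def gd_eigen_trajectory(X, y, lr, steps):
    """Run GD on 0.5 * ||X.T @ w - y||^2 from w = 0 (columns of X are samples).

    Returns the eigenvalues of X @ X.T (descending), the matching eigenvectors,
    and the components of w along those eigenvectors at every step.
    """
    eigvals, eigvecs = np.linalg.eigh(X @ X.T)  # eigh returns ascending order
    order = np.argsort(eigvals)[::-1]
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    w = np.zeros(X.shape[0])
    traj = [eigvecs.T @ w]
    for _ in range(steps):
        w = w - lr * (X @ (X.T @ w - y))  # full-batch gradient step
        traj.append(eigvecs.T @ w)
    return eigvals, eigvecs, np.array(traj)


rng = np.random.default_rng(0)
X = np.diag([3.0, 1.0, 0.3]) @ rng.standard_normal((3, 50))  # d=3 features, n=50 samples
y = rng.standard_normal(50)
lam_max = np.linalg.eigvalsh(X @ X.T).max()
lr = 0.5 / lam_max  # keeps 1 - lr * lambda_k in [0.5, 1)

eigvals, eigvecs, traj = gd_eigen_trajectory(X, y, lr=lr, steps=20)
beta_star = (eigvecs.T @ X @ y) / eigvals  # converged components along each eigenvector
# Fraction of beta_star reached at each step t, per eigendirection (rows: t, columns: k).
progress = 1 - (1 - lr * eigvals) ** np.arange(traj.shape[0])[:, None]
print(np.round(progress[5], 3))  # the top-eigenvalue direction is furthest along
```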

Figures (26)

  • Figure 1: (a) Distilled student exaggerates the confidence of a teacher trained with one-hot loss. For each training sample $(x, y)$, we plot $X=\phi(p^{\textsf{te}}_{y^{\textsf{te}}}(x))$ versus $Y=\phi(p^{\textsf{st}}_{y^{\textsf{te}}}(x))$, which are the teacher and student probabilities on the teacher's predicted label $y^{\textsf{te}}$, transformed monotonically by $\phi(u) = \log\left[u / (1 - u)\right]$. This is a density plot: the brighter a region, the more datapoints have that $(X, Y)$ value. The distilled student's predictions deviate from the $X=Y$ line by underfitting the teacher's low-confidence points (i.e., $Y \leq X$ for small $X$) and/or by overfitting the teacher's high-confidence points (i.e., $Y \geq X$ for large $X$). See §sec:margin for details. (b) Distillation exaggerates the implicit bias of one-hot gradient descent training. We consider an MLP trained on an MNIST-based dataset. Each plot shows the time evolution of the $\ell_2$ norm of the first-layer parameters projected onto two randomly picked eigendirections; the $\star$'s correspond to the final parameters. First, observe that the one-hot-trained teacher moves faster towards its final $X$-axis value than towards its final $Y$-axis value; this corroborates the well-known implicit bias of standard GD training. Crucially, distillation exaggerates this bias: the student moves even faster towards its final $X$-axis value. In §sec:reconcile we argue how this exaggerated bias manifests as the exaggerated confidence in panel (a).
  • Figure 2: Exaggeration of confidence in other settings. The student also exaggerates the teacher's confidence (here, specifically on low-confidence points) even on test data and in cross-architecture distillation.
  • Figure 3: Reconciling the paradox: Distillation exaggerates the implicit bias of GD, which can both exaggerate confidence levels (thus causing deviations in probability) and help generalization. Note, however, that the improved generalization depends on other confounding factors, such as the teacher's training accuracy, as we discuss in §sec:confounding.
  • Figure 4: Left: Exaggeration of confidence under explicit label noise. While the teacher already has low confidence on points with wrong one-hot labels, the student has even lower confidence on these points, in both self-distillation (top) and cross-architecture distillation (bottom). Right: Effect of the teacher's interpolation level on CIFAR-100. For an interpolating teacher (left), switching to the one-hot loss midway through training hurts generalization, whereas for a non-interpolating teacher, the switch to one-hot loss helps.
  • Figure 5: Teacher-student logit plots for CIFAR-100 experiments: We report plots for various distillation settings involving ResNet56, ResNet20 and MobileNet-v2 (training data on top, test data on the bottom). We find underfitting of the low-confidence points in the training set in all but the MobileNet self-distillation setting, and significant underfitting of the low-confidence points in the test set in every setting.
  • ...and 21 more figures
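The coordinates plotted in Figure 1(a) can be computed as in the sketch below, given teacher and student probability matrices of shape (samples, classes). The clipping constant `eps` and the thresholds `low` and `high` used to summarize the two kinds of deviation are hypothetical choices for illustration; Figure 1(a) itself presents the deviations as a density plot rather than as threshold counts. The toy probabilities are made up for the example.

```python
import numpy as np


def phi(u, eps=1e-7):
    """phi(u) = log[u / (1 - u)], clipped away from 0 and 1 to stay finite."""
    u = np.clip(u, eps, 1 - eps)
    return np.log(u / (1 - u))


def confidence_coordinates(teacher_probs, student_probs):
    """(X, Y) as in Figure 1(a): phi of the teacher and student probabilities
    on the teacher's predicted label y_te."""
    y_te = teacher_probs.argmax(axis=1)
    rows = np.arange(len(y_te))
    return phi(teacher_probs[rows, y_te]), phi(student_probs[rows, y_te])


def deviation_summary(X, Y, low, high):
    """Fraction of low-confidence points (X < low) with Y <= X (underfitting)
    and of high-confidence points (X > high) with Y >= X (overfitting).
    `low` and `high` are hypothetical thresholds; NaN if a group is empty."""
    lo_mask, hi_mask = X < low, X > high
    under = float(np.mean(Y[lo_mask] <= X[lo_mask])) if lo_mask.any() else float("nan")
    over = float(np.mean(Y[hi_mask] >= X[hi_mask])) if hi_mask.any() else float("nan")
    return under, over


# Toy example: the student is less confident on the teacher's low-confidence
# point and more confident on its high-confidence point.
teacher = np.array([[0.4, 0.3, 0.3], [0.9, 0.05, 0.05]])
student = np.array([[0.3, 0.35, 0.35], [0.97, 0.02, 0.01]])
X, Y = confidence_coordinates(teacher, student)
print(deviation_summary(X, Y, low=0.0, high=1.0))  # (1.0, 1.0)
```

Note that $Y$ is always read off the teacher's predicted label, even where the student's own argmax differs (as in the first toy row).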

Theorems & Definitions (3)

  • Theorem 4.1
  • Theorem B.1
  • proof