
Knowledge Distillation in Wide Neural Networks: Risk Bound, Data Efficiency and Imperfect Teacher

Guangda Ji, Zhanxing Zhu

TL;DR

This paper theoretically analyzes knowledge distillation in a wide neural network, provides a transfer risk bound for the linearized model of the network, and proposes a metric of a task's training difficulty, called data inefficiency, which is used to show that for a perfect teacher, a high ratio of the teacher's soft labels can be beneficial.

Abstract

Knowledge distillation is a strategy for training a student network under the guidance of the soft output of a teacher network. It has been a successful method for model compression and knowledge transfer. However, knowledge distillation currently lacks a convincing theoretical understanding. On the other hand, recent findings on the neural tangent kernel enable us to approximate a wide neural network with a linear model of the network's random features. In this paper, we theoretically analyze the knowledge distillation of a wide neural network. First, we provide a transfer risk bound for the linearized model of the network. Then we propose a metric of the task's training difficulty, called data inefficiency. Based on this metric, we show that for a perfect teacher, a high ratio of the teacher's soft labels can be beneficial. Finally, for the case of an imperfect teacher, we find that hard labels can correct the teacher's wrong predictions, which explains the practice of mixing hard and soft labels.
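The exact formula for data inefficiency (Definition 1) is not reproduced on this page. As a rough illustration only, the sketch below measures a quantity in the same spirit as the Figure 3 discussion: how fast the norm of a minimum-norm kernel interpolant grows with the sample size $n$, reported as a log-log slope between $n$ and $2n$ samples. Everything here is an assumption rather than the paper's definition: a Laplace kernel $\exp(-|x - x'|)$ stands in for the NTK, inputs form an evenly spaced grid on $[0, 1]$, and `growth_rate`, `smooth_target` and `random_labels` are hypothetical helpers. A smooth, noise-free target should give a slope near 0 (the solution's norm converges), while random $\pm 1$ labels should keep a clearly positive slope (the norm keeps growing).

```python
from typing import Callable

import numpy as np


def laplace_kernel(x: np.ndarray) -> np.ndarray:
    """Gram matrix K[i, j] = exp(-|x_i - x_j|); an assumed stand-in for the NTK."""
    return np.exp(-np.abs(x[:, None] - x[None, :]))


def min_norm_sq(x: np.ndarray, y: np.ndarray) -> float:
    """Squared norm y^T K^{-1} y of the minimum-norm interpolant of (x, y)."""
    K = laplace_kernel(x)
    return float(y @ np.linalg.solve(K, y))


def growth_rate(target: Callable, n: int, rng: np.random.Generator) -> float:
    """Hypothetical proxy: log-log slope of the interpolant norm from n to 2n samples."""
    norms = []
    for m in (n, 2 * n):
        x = np.linspace(0.0, 1.0, m)
        y = target(x, rng)
        norms.append(np.sqrt(min_norm_sq(x, y)))
    return float(np.log(norms[1] / norms[0]) / np.log(2.0))


def smooth_target(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    # Noise-free smooth function ("easy task"): its norm should converge.
    return np.sin(2 * np.pi * x)


def random_labels(x: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    # Independent random signs ("hard task"): the norm keeps growing with n.
    return rng.choice([-1.0, 1.0], size=x.shape)


rng = np.random.default_rng(0)
for n in (16, 64, 256):
    print(n, round(growth_rate(smooth_target, n, rng), 3),
          round(growth_rate(random_labels, n, rng), 3))
```

Reading the two columns side by side mirrors the qualitative claim of Figure 3: the harder (noisier) task keeps a high growth rate, while the smooth task's rate decays toward zero.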

Paper Structure

This paper contains 29 sections, 3 theorems, 27 equations, 12 figures.

Key Result

Theorem 1

(Risk bound) Given input distribution $P(x)$, training samples $\mathbf{X} = [x_1, \cdots, x_n]$, oracle weight change $\Delta_{w_*}$, zero weight change $\Delta_{w_{\mathrm{z}}}$ and accumulative angle distribution $p(\beta)$, the transfer risk is bounded by, where $\bar{\alpha}_n = \bar{\alpha}(\Delta_{w_*} - \Delta_{w_{\mathrm{z}}}, \Delta_{\hat{w}} - \Delta_{w_{\mathrm{z}}})$ and $\Delta_{\ha
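Theorem 1 bounds the transfer risk through an angle between weight changes, and Figure 2 (right) tracks $\alpha(\Delta_{\hat{w}}, \Delta_{w_*})$ against $n$. A minimal sketch of that angle, assuming a linear model on fixed Gaussian random features (a stand-in for the linearized network), noise-free labels from a perfect teacher, and a hypothetical oracle weight change `w_star` (the zero-function term $\Delta_{w_{\mathrm{z}}}$ is ignored). The student's weight change is taken as the minimum-norm least-squares solution, which in this setting equals the projection of `w_star` onto the span of the observed feature vectors, so with nested training sets the angle can only shrink as $n$ grows.

```python
import numpy as np


def angle(u: np.ndarray, v: np.ndarray) -> float:
    """Angle alpha(u, v) in radians between two weight-change vectors."""
    cos = u @ v / (np.linalg.norm(u) * np.linalg.norm(v))
    return float(np.arccos(np.clip(cos, -1.0, 1.0)))


def student_weight_change(features: np.ndarray, targets: np.ndarray) -> np.ndarray:
    """Minimum-norm solution of features @ w = targets.

    Gradient descent on the squared loss started from w = 0 converges to this solution."""
    return np.linalg.lstsq(features, targets, rcond=None)[0]


rng = np.random.default_rng(0)
p = 200                                   # number of random features (assumed)
w_star = rng.standard_normal(p)           # hypothetical oracle weight change
phi_all = rng.standard_normal((p, p))     # rows play the role of feature vectors phi(x_i)

angles = []
for n in (10, 50, 100, 150, 200):
    phi = phi_all[:n]                     # nested training sets: the first n samples
    z = phi @ w_star                      # noise-free labels from a perfect teacher
    angles.append(angle(student_weight_change(phi, z), w_star))
print([round(a, 3) for a in angles])
```

Once $n$ reaches the feature dimension $p$, the system is square and generically invertible, so the student recovers `w_star` and the angle drops to (numerically) zero.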

Figures (12)

  • Figure 1: Effective student logits $z_{\mathrm{s,eff}}$ as a function of $z_{\mathrm{t}}$ and $y_{\mathrm{g}}$. The left and right panels show how the soft ratio $\rho$ (with $T=5.0$) and the temperature $T$ (with $\rho=0.05$) affect the shape of $z_{\mathrm{s,eff}}(z_{\mathrm{t}}, y_{\mathrm{g}})$. Each point is obtained by solving the effective-logits equation with a first-order gradient method. Solid lines show a correct teacher, $y_{\mathrm{g}} = \mathds{1}\{ z_{\mathrm{t}} > 0 \}$, and dashed lines a wrong teacher, $y_{\mathrm{g}} = \mathds{1}\{ z_{\mathrm{t}} < 0 \}$. The presence of the hard label produces a discontinuity in $z_{\mathrm{s,eff}}(z_{\mathrm{t}}, y_{\mathrm{g}})$.
  • Figure 2: Left: Experimental transfer risk plotted against sample size $n$. The curve follows a power law, with a faster rate for pure soft distillation. Middle: Accumulative angle distribution $p(\beta)$, which is part of our transfer risk bound. We split the curves into two subfigures because they change non-monotonically with $\rho$. Right: $\alpha_n'=\alpha (\Delta_{\hat{w}},\Delta_{w_*})$ plotted against $n$. See the experiment-details section of the Appendix for details.
  • Figure 3: Left: Difficulty controlled by the number of modes. The figure shows $\mathcal{I}(n)$ for learning different Gaussian mixture functions. The decreasing behavior of $\mathcal{I}(n)$ is typical for learning a noise-free smooth function. Right: Difficulty controlled by the flip probability. The figure plots $\mathcal{I}(n)$ for learning the same function at different noise levels; $p_{\mathrm{flip}}=0.5$ means a completely random sign. The noise makes these tasks so difficult to learn that $\mathcal{I}(n)\equiv 0.8$, which means $\Delta_{\hat{w}}$ will not converge. Together, the two figures demonstrate a positive correlation between $\mathcal{I}(n)$ and the task's difficulty. The dashed lines are references for a hard and an easy task: the upper dashed line shows the complexity of random labels $\Delta_{z} \sim \mathcal{N}(0,1)$, while the lower dashed line shows the complexity of the zero function $z \equiv 0$. The latter also demonstrates that the zero function is extremely easy to learn, so $\Delta_{w_{\mathrm{z}}}$ can be neglected. All results are averaged over 20 runs.
  • Figure 4: Left: A 1-D example of teacher and student outputs. Left top: ground-truth class boundary. Left middle: teacher's logits at different stopping epochs; the teacher's scale increases and its shape becomes more detailed during training. Left bottom: effective student logits $y_{\mathrm{s,eff}}$ at different soft ratios $\rho$. This panel further illustrates the discontinuity in $y_{\mathrm{s,eff}}$. For a small $\rho$, the student's shape clearly resembles label smoothing. All students share the same teacher network. Middle and Right: data inefficiency curves for different teacher stopping epochs and soft ratios. They show that adding hard labels to distillation increases sample complexity. See the experiment-details section of the Appendix for details.
  • Figure 5: Left and Middle: Imperfect distillation on a synthetic dataset and in a practical setting (CIFAR10 with ResNet). These plots show that pure soft-label distillation is not optimal in imperfect KD. Right: $\langle \delta_{\hat{w}_\mathrm{h}}, \Delta_{\hat{w}_{\mathrm{c}}} \rangle$ is proportional to $\partial \cos \alpha(\Delta_{\hat{w}}, \Delta_{w_{\mathrm{g}}})/ \partial(1-\rho)$. Its sign indicates whether adding hard labels reduces the angle between the student and the oracle. The teacher's stopping epoch is positively related to the teacher's generalization ability. The epoch at which $\langle \delta_{\hat{w}_\mathrm{h}}, \Delta_{\hat{w}_{\mathrm{c}}} \rangle$ switches sign is approximately when the teacher outperforms the best student network. See the experiment-details section of the Appendix for details.
  • ...and 7 more figures
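Figure 1 describes effective student logits: the student output that minimizes the per-sample distillation loss for a given teacher logit $z_{\mathrm{t}}$ and hard label $y_{\mathrm{g}}$. A minimal sketch for the binary case, assuming the per-sample loss $\rho\,\mathrm{BCE}(\sigma(z_{\mathrm{t}}/T), \sigma(z_{\mathrm{s}}/T)) + (1-\rho)\,\mathrm{BCE}(y_{\mathrm{g}}, \sigma(z_{\mathrm{s}}))$; the paper's exact equation may differ (for example by a $T^2$ factor on the soft term). Instead of the first-order gradient method mentioned in the caption, it solves the first-order condition by bisection, which is valid because this loss is strictly convex in $z_{\mathrm{s}}$. The function names and the bisection bracket are hypothetical choices.

```python
import math


def sigmoid(z: float) -> float:
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def loss_grad(z_s: float, z_t: float, y_g: float, rho: float, T: float) -> float:
    """d/dz_s of the assumed loss
    rho * BCE(sigmoid(z_t/T), sigmoid(z_s/T)) + (1 - rho) * BCE(y_g, sigmoid(z_s))."""
    soft = rho * (sigmoid(z_s / T) - sigmoid(z_t / T)) / T
    hard = (1.0 - rho) * (sigmoid(z_s) - y_g)
    return soft + hard


def effective_logit(z_t: float, y_g: float, rho: float = 0.05, T: float = 5.0,
                    lo: float = -100.0, hi: float = 100.0, iters: int = 200) -> float:
    """Student logit that minimizes the per-sample loss.

    The gradient is increasing in z_s, so the first-order condition has a single
    root; bisection finds it, assuming it lies inside [lo, hi] (true for moderate
    logits and rho > 0)."""
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if loss_grad(mid, z_t, y_g, rho, T) > 0:
            hi = mid
        else:
            lo = mid
    return 0.5 * (lo + hi)


# Correct teacher: the hard label agrees with the sign of the teacher logit.
for z_t in (-2.0, -0.01, 0.01, 2.0):
    y_g = 1.0 if z_t > 0 else 0.0
    print(z_t, round(effective_logit(z_t, y_g), 3))
```

With $\rho = 0.05$ and $T = 5$, moving $z_{\mathrm{t}}$ from just below 0 to just above 0 flips $y_{\mathrm{g}}$, and the effective logit jumps from roughly $-6$ to roughly $+6$, which is the discontinuity the caption describes. With $\rho = 1$ the hard term vanishes and the effective logit simply equals $z_{\mathrm{t}}$.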

Theorems & Definitions (5)

  • Theorem 1
  • Definition 1
  • Theorem 2
  • Theorem 3
  • proof