Progressive distillation induces an implicit curriculum

Abhishek Panigrahi, Bingbin Liu, Sadhika Malladi, Andrej Risteski, Surbhi Goel

TL;DR

This work identifies an implicit curriculum as one mechanism through which progressive distillation accelerates the student's learning, and extends this investigation to Transformers trained on probabilistic context-free grammars and real-world pre-training datasets.

Abstract

Knowledge distillation leverages a teacher model to improve the training of a student model. A persistent challenge is that a better teacher does not always yield a better student; a common mitigation is to use additional supervision from several "intermediate" teachers. One empirically validated variant of this principle is progressive distillation, where the student learns from successive intermediate checkpoints of the teacher. Using sparse parity as a sandbox, we identify an implicit curriculum as one mechanism through which progressive distillation accelerates the student's learning. This curriculum is available only through the intermediate checkpoints, not the final converged one, and imparts both empirical acceleration and a provable sample complexity benefit to the student. We then extend our investigation to Transformers trained on probabilistic context-free grammars (PCFGs) and real-world pre-training datasets (Wikipedia and Books). By probing the teacher model, we identify an analogous implicit curriculum where the model progressively learns features that capture longer context. Our theoretical and empirical findings on sparse parity, complemented by empirical observations on more complex tasks, highlight the benefit of progressive distillation via implicit curriculum across setups.
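
For readers unfamiliar with the sandbox, the sketch below samples $(d, k)$-sparse parity data. It assumes the standard formulation of Barak et al. (2022): inputs drawn uniformly from $\{\pm 1\}^d$ and a label equal to the product of the $k$ coordinates in a hidden support set. The function name `sample_sparse_parity` and the default support (the first $k$ coordinates, matching the in-support variables $x_1, \cdots, x_6$ mentioned for Figure 2) are illustrative choices, not the paper's code.

```python
import random
from math import prod


def sample_sparse_parity(n, d, k, support=None, seed=0):
    """Draw n labelled examples of (d, k)-sparse parity.

    Assumption: x is uniform over {-1, +1}^d and y is the product of the
    coordinates in `support` (0-indexed; defaults to the first k).
    """
    rng = random.Random(seed)
    if support is None:
        support = list(range(k))
    assert len(support) == k, "support must contain exactly k coordinates"
    xs, ys = [], []
    for _ in range(n):
        x = [rng.choice((-1, 1)) for _ in range(d)]
        xs.append(x)
        ys.append(prod(x[i] for i in support))
    return xs, ys


# Example: 5 samples of (100, 6)-sparse parity, the setting of Figures 1 and 2.
xs, ys = sample_sparse_parity(5, d=100, k=6)
print(ys)
```

Only the $k$ support coordinates determine the label; the other $d - k$ coordinates are pure noise, which is what makes the support hard to find from labels alone.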

Paper Structure (72 sections, 13 theorems, 49 equations, 32 figures, 1 algorithm)

Key Result

Theorem 3.2

Consider learning $(d, k)$-sparse parity with a student model of size $\tilde{m} = \tilde{\Theta}(2^k)$, where $\tilde{\cdot}$ hides polylog factors in $d, k$. Suppose the teacher has a loss $\mathcal{O}(\epsilon)$ for some small $\epsilon > 0$. Then, the total sample complexity needed for the student …

Figures (32)

  • Figure 1: Progressive distillation accelerates training. Left: MLP on $(100,6)$-sparse parity (Definition 3.1), with width-50k teachers and width-100 students. Progressive distillation uses checkpoints at 100k-step intervals, and one-shot distillation uses the final (20M-step) checkpoint. Middle: Transformer on $(100,6)$-sparse parity, with 32-head teachers and 4-head students. Progressive distillation uses checkpoints at 10k-step intervals, and one-shot distillation uses the checkpoint at 250k steps. Right: Transformers on a PCFG with BERT-style masked prediction, with 32-head teachers and 8-head students. Progressive distillation uses $8$ intermediate checkpoints.
  • Figure 2: Implicit curriculum for $(100, 6)$-sparse parity. We compare 3 candidate intermediate checkpoints, labeled ①, ②, ③, corresponding to 9.7M, 10.2M, and 10.8M steps, i.e., the beginning, middle, and end of the teacher's phase transition. Left: Teacher's accuracy throughout training. Middle: During the phase transition, $f_{\mathcal{T}}$ is much more strongly correlated with the in-support variables ($x_1, \cdots, x_6$ in this case) than with the off-support variables. Right: Only candidate ② (i.e., during the phase transition) enables $(2, 1M)$-progressive distillation to reach $100\%$ accuracy. We use width-50k teachers and width-100 students; a companion figure in the paper shows similar results for width-1000 students.
  • Figure 3: An example of a PCFG tree ${\rm T}(\mathbf{x})$ that generates $\mathbf{x}=$"The cat ran away". "The cat" is an example of a level-2 span, and "cat" is a boundary token for the spans of both the level-1 non-terminal Noun and the level-2 non-terminal Noun Phrase.
  • Figure 4: BERT on the PCFG cfg3b. Left: A 32-head teacher's loss exhibits three distinct phases: ① an initial phase with little change, ② a middle phase with a rapid drop, and ③ a final plateau that lasts until the end of training. The triangles mark the checkpoints selected for progressive distillation, with the first teacher checkpoint (denoted $C_1$) located in the middle of phase ②. Middle: $M_{\text{robust}}$ across training, which peaks at $C_1$. The model becomes more robust to perturbations of shorter $n$-grams as training progresses. The median is taken over input sequences. Right: An 8-head student's final accuracy with $(2, T)$-progressive distillation after $4000$ total training steps. The $x$-axis marks the choice of the first teacher checkpoint, and $T$ is grid-searched over $\{500, 1000, 2000\}$. The best performance is obtained by choosing $C_1$. The plots show a single teacher training run, but similar behavior occurs robustly across random seeds.
  • Figure 5: Comparisons on an 8-attention-head BERT model. (Left) $M_{\text{close}}$ for different $n$-grams. Progressive distillation has a lower $M_{\text{close}}$ for longer $n$-gram contexts. (Middle) $M_{\text{robust}}$ for different $n$-grams. Progressive distillation has a lower $M_{\text{robust}}$ for all $n$-gram contexts. (Right) Probe performance when predicting the non-terminals (NTs) (Definition 4.3). The progressively distilled student performs better when its contextual embeddings are probed for higher-level non-terminals.
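
Figure 2 (middle) contrasts the teacher's correlation with in-support versus off-support variables. A minimal Monte Carlo sketch of that kind of measurement follows, assuming the statistic is the empirical mean of $f(x)\,x_i$ over uniform inputs $x \in \{\pm 1\}^d$; the paper's exact estimator may differ. The function `coordinate_correlations` and the stand-in `toy_teacher` (a sum of in-support coordinates, mimicking a model that has picked up degree-1 signal) are hypothetical. The sketch also checks a known fact: for $k \ge 2$, the exact parity $\prod_{i \in S} x_i$ has zero correlation with every single coordinate, so a teacher that has converged to the parity exposes no such per-coordinate signal.

```python
import random
from math import prod


def coordinate_correlations(f, d, n=2000, seed=0):
    """Estimate E[f(x) * x_i] for each coordinate i, with x uniform on {-1, +1}^d."""
    rng = random.Random(seed)
    totals = [0.0] * d
    for _ in range(n):
        x = [rng.choice((-1, 1)) for _ in range(d)]
        fx = f(x)
        for i in range(d):
            totals[i] += fx * x[i]
    return [t / n for t in totals]


d, k = 100, 6
support = range(k)  # in-support variables x_1..x_6, 0-indexed here


def toy_teacher(x):
    # Hypothetical mid-transition teacher: linear in the support coordinates.
    return sum(x[i] for i in support)


def parity(x):
    # The target parity itself, i.e. what a converged teacher approximates.
    return prod(x[i] for i in support)


corr_toy = coordinate_correlations(toy_teacher, d)
corr_parity = coordinate_correlations(parity, d)
print("toy, in-support :", [round(c, 2) for c in corr_toy[:k]])
print("toy, max off    :", round(max(abs(c) for c in corr_toy[k:]), 2))
print("parity, max any :", round(max(abs(c) for c in corr_parity), 2))
```

The toy teacher's in-support correlations sit near 1 and its off-support correlations near 0, while every correlation of the exact parity is near 0 up to sampling noise.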

Theorems & Definitions (28)

  • Definition 2.1: $(C_{\mathcal{T}}, \mathcal{D})$-progressive distillation
  • Definition 2.2: $(N, T)$-progressive distillation
  • Definition 3.1: $(d, k)$-sparse parity task
  • Theorem 3.2: Informal version of Theorem B.1
  • Proof sketch
  • Definition 4.1: Masked prediction task with mask rate $p$
  • Definition 4.2: $n$-gram neighboring context
  • Definition 4.3: PCFG non-terminal prediction task
  • Theorem B.1: Sample complexity benefits with progressive distillation
  • Lemma B.2: Single-step gradient descent, adapted from Claims 1 and 2 in Barak et al. (2022)
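
Definitions 2.1 and 2.2 are stated in full only in the paper. The sketch below illustrates one plausible reading of $(N, T)$-progressive distillation that is consistent with the figure captions (e.g. $(2, T)$ with 4000 total steps in Figure 4): the student trains against each of the first $N-1$ teacher checkpoints for $T$ steps and against the last one for the rest of the budget. This reading, the helper names `teacher_for_step` and `soft_label_loss`, and the use of plain soft-label cross-entropy as the distillation loss are assumptions for illustration.

```python
import math


def teacher_for_step(step, checkpoints, T):
    """Pick the teacher checkpoint that supervises the student at `step`.

    Assumed schedule: checkpoints[j] is used for steps [j*T, (j+1)*T) for
    j < N-1, and the last checkpoint is used from step (N-1)*T onward.
    """
    index = min(step // T, len(checkpoints) - 1)
    return checkpoints[index]


def soft_label_loss(teacher_probs, student_logits):
    """Cross-entropy of the student's softmax against the teacher's probabilities."""
    m = max(student_logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in student_logits))
    return -sum(p * (z - log_z) for p, z in zip(teacher_probs, student_logits))


# (2, 1000)-progressive distillation over a 4000-step budget (hypothetical labels).
checkpoints = ["C1 (intermediate)", "final teacher"]
for step in (0, 999, 1000, 3999):
    print(step, "->", teacher_for_step(step, checkpoints, T=1000))
```

Under this reading, one-shot distillation is the special case with a single checkpoint, since `teacher_for_step` then always returns the final teacher.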