Efficient Knowledge Distillation via Curriculum Extraction

Shivam Gupta, Sushrut Karmalkar

TL;DR

This work introduces Curriculum Extraction, a practical method to distill knowledge from a fully trained teacher without storing intermediate checkpoints. By sequentially training the student on random projections of each teacher layer and then finetuning on the teacher’s final outputs, the approach creates a coarse-to-fine learning curriculum that reduces sample complexity and per-iteration cost. Theoretical guarantees are provided for learning sparse parity functions, demonstrating substantial gains over one-shot distillation and competitive performance with progressive distillation under certain conditions; empirical results extend to transformers and BERT-style language modeling, with improvements in both synthetic parity tasks and real data like Wikipedia. Overall, Curriculum Extraction offers a scalable, checkpoint-free pathway to efficient distillation across architectures and data regimes, with strong implications for deploying compact yet capable models in practice.

Abstract

Knowledge distillation is a technique used to train a small student network using the output generated by a large teacher network, and has many empirical advantages (Hinton et al., 2015). While the standard one-shot approach to distillation only uses the output of the final teacher network, recent work (Panigrahi et al., 2024) has shown that using intermediate checkpoints from the teacher's training process as an implicit "curriculum" for progressive distillation can significantly speed up training. However, such schemes require storing these checkpoints, and often require careful selection of the intermediate checkpoints to train on, which can be impractical for large-scale training. In this paper, we show that a curriculum can be extracted from just the fully trained teacher network, and that this extracted curriculum can give similar efficiency benefits to those of progressive distillation. Our extraction scheme is natural; we use a random projection of the hidden representations of the teacher network to progressively train the student network, before training using the output of the full network. We show that our scheme significantly outperforms one-shot distillation and achieves a performance similar to that of progressive distillation for learning sparse parities with two-layer networks, and provide theoretical guarantees for this setting. Additionally, we show that our method outperforms one-shot distillation even when using transformer-based architectures, both for sparse-parity learning and for language modeling tasks.
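
The two training signals described in the abstract can be sketched in plain Python: a mean squared error against a random projection of a teacher hidden vector, and a KL divergence against the teacher's output logits. This is a minimal illustration, not the authors' implementation. The helper names, the Gaussian projection with a `1/sqrt(m_teacher)` scale, and the toy dimensions are all assumptions made for the example.

```python
import math
import random

def random_projection(m_student, m_teacher, rng):
    """Gaussian matrix of shape (m_student, m_teacher).
    The 1/sqrt(m_teacher) scale is an assumption, not taken from the paper."""
    scale = 1.0 / math.sqrt(m_teacher)
    return [[rng.gauss(0.0, scale) for _ in range(m_teacher)]
            for _ in range(m_student)]

def project(A, hidden):
    """Compute the matrix-vector product A @ hidden for one teacher hidden vector."""
    return [sum(a * h for a, h in zip(row, hidden)) for row in A]

def mse(pred, target):
    """Mean squared error used while aligning a student layer."""
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)

def softmax(logits):
    """Numerically stable softmax."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def kl_divergence(teacher_logits, student_logits):
    """KL(teacher || student) on the output distributions."""
    p = softmax(teacher_logits)
    q = softmax(student_logits)
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0.0)

rng = random.Random(0)
teacher_hidden = [rng.uniform(-1, 1) for _ in range(8)]  # toy teacher activations
A = random_projection(3, 8, rng)                         # student width 3, teacher width 8
target = project(A, teacher_hidden)                      # layer-alignment target

student_hidden = [0.0, 0.0, 0.0]                         # untrained toy student layer
stage1_loss = mse(student_hidden, target)                # layer-alignment objective
stage2_loss = kl_divergence([2.0, 0.5, -1.0], [1.0, 1.0, 1.0])  # final-stage objective
```

In an actual training loop the projection would be drawn once per teacher layer and held fixed while the corresponding student layer is optimized.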

Paper Structure

This paper contains 44 sections, 18 theorems, 74 equations, 6 figures, 2 algorithms.

Key Result

Theorem 1.1

Let $d \geq \tilde{\Omega}(k^4)$. Consider learning $d$-dimensional $k$-sparse parity with a student model of size $\tilde{\Theta}(2^{O(k)})$, where $\tilde{O}, \tilde{\Theta}$ hides polylog factors in $d,k$. Suppose the teacher has a loss $\epsilon/C$ for some sufficiently large constant $C > 0$ an
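
For readers unfamiliar with the task in the theorem, the following sketch generates $d$-dimensional $k$-sparse parity data: inputs are uniform over $\{-1,+1\}^d$ and the label is the product of the coordinates in a hidden support of size $k$. The function name and the sample count are illustrative assumptions; the values $d = 100$ and $k = 6$ match the experiments described in the figure captions.

```python
import random

def sample_sparse_parity(d, support, n, rng):
    """Draw n points uniformly from {-1, +1}^d and label each one
    with the parity (product) of the coordinates in `support`."""
    xs, ys = [], []
    for _ in range(n):
        x = [rng.choice((-1, 1)) for _ in range(d)]
        y = 1
        for i in support:
            y *= x[i]
        xs.append(x)
        ys.append(y)
    return xs, ys

rng = random.Random(0)
d, k = 100, 6                                   # sizes used in the figure captions
support = sorted(rng.sample(range(d), k))       # hidden support, unknown to the learner
xs, ys = sample_sparse_parity(d, support, 1000, rng)
```

Flipping any single coordinate inside the support flips the label, while coordinates outside the support never affect it, which is what makes the support hard to identify from few samples.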

Figures (6)

  • Figure 1: Our curriculum extraction method trains the student model in a layer-wise fashion. Student layers are sequentially aligned to a random projection of the corresponding teacher layer’s hidden representation using the Mean Squared Error (MSE). After aligning layers, the student is trained on the teacher’s output logits via the KL Divergence loss.
  • Figure 2: In-support and out-of-support correlations. A two-layer MLP trained on 100-dimensional 6-sparse parity data exhibits distinct in-support (red) and out-of-support (blue) correlations of $(Af_t^{(1)})(\mathbf{x})$ with $x_j$ for the random projection $A \in \mathbb R^{1 \times m_t}$. When $j$ is in the support, the correlations show significantly larger standard deviations compared to when $j$ is outside the support.
  • Figure 3: Comparing Curriculum Extraction and One-Shot Distillation. We show three tasks for which curriculum extraction outperforms one-shot distillation: (a) A two-layer MLP trained on 100-dimensional 6-sparse parity, with a teacher hidden dimension of 50k and a student hidden dimension of 100. (b) A transformer trained on 100-dimensional 6-sparse parity, using 256-dimensional embeddings, where the teacher has 32 attention heads and the student has 4. (c) A BERT-large model fine-tuned on the Wikipedia dataset, with the teacher using 768-dimensional embeddings, 12 attention heads, and 12 transformer blocks, while the student reduces embeddings to 256 dimensions and attention heads to 4. The dashed vertical lines indicate the iterations where the layer being distilled is changed in the case of curriculum extraction, and a change in teacher checkpoint in the case of progressive distillation.
  • Figure 4: PCFG Experiments on BERT. The dashed vertical lines indicate iterations where the layer being distilled from is changed. (a) Soon after the final checkpoint, curriculum extraction achieves higher accuracy than one-shot distillation for the same number of FLOPs. (b) We compare curriculum extraction to one-shot distillation across three models trained with two-, three-, and four-stage curricula at 4000, 6000, and 8000 iterations, respectively. Curriculum extraction consistently outperforms one-shot distillation at all scales. (c) We compare curriculum extraction performance by varying the number of layers. With a fixed budget of 6000 iterations (2700 for extraction, 3300 for full network training), extracting from three layers outperforms both one-shot distillation and extraction from a single layer.
  • Figure 5: cfg3b from Allen-Zhu and Li (2023). Vocabulary is $\{1,2,3\}$. Indentation reflects production hierarchy.
  • ...and 1 more figure
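
The layer-wise schedule described in the Figure 1 and Figure 4 captions can be written down as a list of stages. The sketch below uses the budget quoted for Figure 4(c): 2700 extraction iterations over three teacher layers followed by 3300 iterations on the full network. Splitting the extraction budget evenly across layers is an assumption made here for illustration, and `curriculum_schedule` is a hypothetical helper, not a function from the paper.

```python
def curriculum_schedule(num_layers, extraction_iters, final_iters):
    """Return a list of (start, end, target, loss) stages.
    The even per-layer split of the extraction budget is an assumption."""
    per_layer, remainder = divmod(extraction_iters, num_layers)
    stages, start = [], 0
    for layer in range(1, num_layers + 1):
        # Hand any leftover iterations to the earliest layers.
        length = per_layer + (1 if layer <= remainder else 0)
        target = f"teacher layer {layer} (random projection)"
        stages.append((start, start + length, target, "mse"))
        start += length
    # Final stage: match the teacher's output logits.
    stages.append((start, start + final_iters, "teacher output logits", "kl"))
    return stages

stages = curriculum_schedule(num_layers=3, extraction_iters=2700, final_iters=3300)
# Iterations at which the distilled layer changes (the dashed lines in the plots).
boundaries = [end for _, end, _, _ in stages[:-1]]
print(boundaries)  # → [900, 1800, 2700]
```

Under this even split the final stage runs from iteration 2700 to 6000, so the total matches the 6000-iteration budget.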

Theorems & Definitions (33)

  • Theorem 1.1: Main, Informal
  • Definition 3.1: Curriculum Extraction Scheme
  • Definition 4.2: Symmetric Initialization
  • Theorem 4.3: Curriculum Extraction Requires Fewer Samples
  • Lemma 4.4: Correlation Gap (Informal)
  • Definition B.1: Fourier Expansion of a Boolean Function
  • Lemma B.2: Properties of $\phi_b(t)$
  • proof
  • Theorem B.3: Berry--Esseen
  • Theorem B.4: Anticoncentration for Rademacher Sums
  • ...and 23 more
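
Definition B.1 is listed above by name only. For reference, the standard Fourier expansion of a function on the Boolean hypercube is given below, together with how a sparse parity fits into it. The notation ($\chi_S$, $\hat{f}(S)$, support $S^\star$) is the conventional one and is an assumption here; the paper's own statement may use different symbols.

```latex
% Standard Fourier expansion of f : \{-1,+1\}^d \to \mathbb{R}
f(\mathbf{x}) \;=\; \sum_{S \subseteq [d]} \hat{f}(S)\, \chi_S(\mathbf{x}),
\qquad
\chi_S(\mathbf{x}) \;=\; \prod_{i \in S} x_i,
\qquad
\hat{f}(S) \;=\; \mathbb{E}_{\mathbf{x} \sim \mathrm{Unif}(\{-1,+1\}^d)}
  \bigl[ f(\mathbf{x})\, \chi_S(\mathbf{x}) \bigr].

% A k-sparse parity with support S^\star, |S^\star| = k, is a single character:
f(\mathbf{x}) = \chi_{S^\star}(\mathbf{x})
\;\;\Longrightarrow\;\;
\hat{f}(S) =
\begin{cases}
1 & \text{if } S = S^\star,\\
0 & \text{otherwise.}
\end{cases}
```

The second display follows from orthonormality of the characters, $\mathbb{E}[\chi_S(\mathbf{x})\chi_T(\mathbf{x})] = \mathbf{1}[S = T]$. In particular, for $k \geq 2$ the target parity has zero correlation with every individual coordinate $x_j$.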