Exploring and Enhancing the Transfer of Distribution in Knowledge Distillation for Autoregressive Language Models

Jun Rao, Xuebo Liu, Zepeng Lin, Liang Ding, Jing Li, Dacheng Tao, Min Zhang

TL;DR

The paper introduces Online Knowledge Distillation (OKD), in which the teacher network integrates small online modules that train concurrently with the student model, so that the teacher adapts dynamically to the student's distribution and distillation improves.

Abstract

Knowledge distillation (KD) is a technique that compresses large teacher models by training smaller student models to mimic them. The success of KD in autoregressive language models relies mainly on Reverse KL for mode-seeking and on student-generated output (SGO) to combat exposure bias. Our theoretical analyses and experimental validation reveal that while Reverse KL effectively mimics certain features of the teacher distribution, it fails to capture most of its behaviors. SGO, in turn, incurs higher computational costs and is harder to optimize, particularly when the student model is significantly smaller than the teacher model. These limitations stem primarily from the fixed distribution of the teacher model, which cannot adapt to students of varying sizes. We introduce Online Knowledge Distillation (OKD), where the teacher network integrates small online modules that train concurrently with the student model. This strategy removes the need for on-policy sampling and requires only minimal updates to the parameters of the teacher's online module during training, allowing the teacher to adapt dynamically to the student's distribution and thereby improving distillation. Extensive results across multiple generation datasets show that OKD matches or exceeds the performance of leading methods across various model architectures and sizes, while reducing training time by up to a factor of four.
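To make the mode-seeking behavior of Reverse KL concrete, the following minimal sketch compares forward KL, KL(teacher || student), with Reverse KL, KL(student || teacher), on a toy four-token vocabulary. The distributions and the names `teacher`, `student_mode`, and `student_cover` are illustrative assumptions, not values or code from the paper, and the sketch does not implement OKD itself.

```python
import math


def kl(a, b):
    """KL(a || b) for two discrete distributions over the same vocabulary."""
    return sum(ai * math.log(ai / bi) for ai, bi in zip(a, b) if ai > 0)


# Hypothetical bimodal teacher distribution over four tokens.
teacher = [0.48, 0.02, 0.02, 0.48]

# Two hypothetical students: one commits to a single teacher mode,
# the other spreads its mass evenly over the whole vocabulary.
student_mode = [0.94, 0.02, 0.02, 0.02]
student_cover = [0.25, 0.25, 0.25, 0.25]

students = {"mode": student_mode, "cover": student_cover}

# Forward KL: the expectation is taken under the teacher.
forward = {name: kl(teacher, q) for name, q in students.items()}

# Reverse KL: the expectation is taken under the student.
reverse = {name: kl(q, teacher) for name, q in students.items()}

for name in students:
    print(f"{name:>5}: forward={forward[name]:.3f}  reverse={reverse[name]:.3f}")
```

In this toy setting, Reverse KL assigns the lower loss to the student that locks onto one mode, while forward KL prefers the student that covers both modes. This is consistent with the abstract's observation that Reverse KL can match some features of the teacher distribution while missing much of the rest.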

Paper Structure (33 sections, 17 equations, 12 figures, 6 tables, 1 algorithm)

Figures (12)

  • Figure 1: An overview of the KD methods in auto-regressive LMs.
  • Figure 2: Comparison of metrics such as the standard deviation of logits (STD), teacher-student prediction similarity (KL), and top-1 prediction agreement for different optimization functions.
  • Figure 3: ROUGE-L scores for the validation set across the different methods.
  • Figure 4: Plot of validation loss values across each validation iteration.
  • Figure 5: Comparison of metrics such as the standard deviation of logits (STD), teacher-student prediction similarity (KL), and top-1 prediction agreement for different optimization functions. Our method achieves better results across samples of varying difficulty and across teacher-student combinations.
  • ...and 7 more figures