
Don't Ignore the Tail: Decoupling top-K Probabilities for Efficient Language Model Distillation

Sayantan Dasgupta, Trevor Cohn, Timothy Baldwin

TL;DR

A new tail-aware divergence is proposed that decouples the contribution of the teacher model's top-K predicted probabilities from that of lower-probability predictions, while maintaining the same computational profile as the KL divergence.

Abstract

The core learning signal used in language model distillation is the standard Kullback-Leibler (KL) divergence between the student and teacher distributions. Traditional KL divergence tends to be dominated by the next tokens with the highest probabilities, i.e., the teacher's modes, thereby diminishing the influence of less probable yet potentially informative components of the output distribution. We propose a new tail-aware divergence that decouples the contribution of the teacher model's top-K predicted probabilities from that of lower-probability predictions, while maintaining the same computational profile as the KL divergence. Our decoupled approach reduces the impact of the teacher modes and, consequently, increases the contribution of the tail of the distribution. Experimental results demonstrate that our modified distillation method yields competitive performance in both pre-training and supervised distillation of decoder models across various datasets. Furthermore, the distillation process is efficient and can be performed on large datasets with a modest academic budget, eliminating the need for industry-scale compute.
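Why the teacher's modes dominate plain KL can be illustrated with the chain rule of KL divergence. Partition the vocabulary by the teacher's top-K tokens, let $\alpha_p$ and $\alpha_q$ be the teacher's and student's probability mass outside that set, and renormalize each part ($\hat p, \hat q$ on the top-K set, $\tilde p, \tilde q$ on the tail). Then $\mathrm{KL}(p\|q) = \mathrm{KL}(\mathrm{Bern}(\alpha_p)\|\mathrm{Bern}(\alpha_q)) + (1-\alpha_p)\,\mathrm{KL}(\hat p\|\hat q) + \alpha_p\,\mathrm{KL}(\tilde p\|\tilde q)$, so the tail term is scaled by the (often small) tail mass $\alpha_p$. The NumPy sketch below checks this identity numerically. All function names are hypothetical, and `reweighted_divergence` with its `tail_weight` knob is only an illustration of up-weighting the tail, not the paper's actual tail-aware divergence formula.

```python
import numpy as np


def softmax(logits):
    """Numerically stable softmax over the last axis."""
    z = logits - np.max(logits, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=-1, keepdims=True)


def kl(p, q):
    """KL(p || q) for 1-D distributions; assumes strictly positive entries."""
    return float(np.sum(p * (np.log(p) - np.log(q))))


def split_kl(p, q, k):
    """Split KL(p || q) along the teacher's top-K tokens (requires k < vocab size).

    Returns (binary, head_kl, tail_kl, alpha_p) such that
    KL(p || q) == binary + (1 - alpha_p) * head_kl + alpha_p * tail_kl.
    """
    head = np.zeros(p.shape, dtype=bool)
    head[np.argpartition(p, -k)[-k:]] = True  # teacher top-K tokens, linear-time selection
    alpha_p = p[~head].sum()                  # teacher tail mass
    alpha_q = q[~head].sum()                  # student mass on the same tail tokens
    # KL between the two-outcome (top-K vs. tail) distributions.
    binary = ((1 - alpha_p) * np.log((1 - alpha_p) / (1 - alpha_q))
              + alpha_p * np.log(alpha_p / alpha_q))
    # KL within each part after renormalizing it to sum to one.
    head_kl = kl(p[head] / (1 - alpha_p), q[head] / (1 - alpha_q))
    tail_kl = kl(p[~head] / alpha_p, q[~head] / alpha_q)
    return float(binary), head_kl, tail_kl, float(alpha_p)


def reweighted_divergence(p, q, k, tail_weight):
    """Hypothetical tail-upweighted variant; tail_weight=1.0 recovers plain KL."""
    binary, head_kl, tail_kl, alpha_p = split_kl(p, q, k)
    return binary + (1 - alpha_p) * head_kl + tail_weight * alpha_p * tail_kl


rng = np.random.default_rng(0)
p = softmax(rng.normal(size=50))  # stand-in teacher distribution
q = softmax(rng.normal(size=50))  # stand-in student distribution
binary, head_kl, tail_kl, alpha_p = split_kl(p, q, k=5)
print(np.isclose(kl(p, q), binary + (1 - alpha_p) * head_kl + alpha_p * tail_kl))  # True
```

The top-K selection uses `np.argpartition`, which runs in linear time in the vocabulary size, so splitting the divergence this way costs only a constant factor more than computing KL directly.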

Paper Structure (18 sections, 23 equations, 2 figures, 8 tables)

Figures (2)

  • Figure 1: KL divergence on the validation set of Regmix for vanilla KD vs. TAD (tail-aware divergence). The $x$ axis shows training progress in terms of the number of tokens, and the $y$ axis shows held-out KL between the student and teacher.
  • Figure 2: First plot: tail probability mass ($\alpha^T_K$) against $K$ for different teachers. Second plot: mismatch rate (%) between the observed next token and the teacher's mode. Both are measured on the validation set of Regmix (see Section \ref{sec:scratch}).
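The diagnostics plotted in Figures 1 and 2 can be computed from stored teacher and student next-token probabilities. The sketch below rests on assumptions not stated in the captions: that the mismatch rate counts positions where the observed next token differs from the teacher's argmax, and that the held-out KL in Figure 1 is the forward KL(teacher || student) averaged over positions. The function names are hypothetical and this is not the paper's evaluation code.

```python
import numpy as np


def mean_tail_mass(teacher_probs, k):
    """Average teacher tail mass alpha^T_K: probability outside the top-K tokens.

    teacher_probs: shape (num_positions, vocab_size), each row sums to 1.
    """
    top_k = np.sort(teacher_probs, axis=-1)[:, -k:]
    return float(np.mean(1.0 - top_k.sum(axis=-1)))


def mode_mismatch_rate(teacher_probs, next_tokens):
    """Percentage of positions where the observed next token is not the teacher's mode."""
    modes = np.argmax(teacher_probs, axis=-1)
    return 100.0 * float(np.mean(modes != next_tokens))


def mean_kl(teacher_probs, student_probs):
    """Average forward KL(teacher || student) per position; assumes positive entries."""
    per_position = np.sum(
        teacher_probs * (np.log(teacher_probs) - np.log(student_probs)), axis=-1
    )
    return float(np.mean(per_position))


teacher = np.array([[0.5, 0.3, 0.2],
                    [0.1, 0.1, 0.8]])
next_tokens = np.array([0, 1])
print(mode_mismatch_rate(teacher, next_tokens))  # 50.0
print(mean_kl(teacher, teacher))                 # 0.0
```

Sweeping `k` in `mean_tail_mass` gives a curve of the kind shown in the first plot of Figure 2: the more mass a teacher leaves outside its top-K tokens, the more a tail-focused objective has to work with.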