
Multi-Stage Balanced Distillation: Addressing Long-Tail Challenges in Sequence-Level Knowledge Distillation

Yuhang Zhou, Jing Zhu, Paiheng Xu, Xiaoyu Liu, Xiyao Wang, Danai Koutra, Wei Ai, Furong Huang

TL;DR

This paper introduces the Multi-Stage Balanced Distillation (BalDistill) framework, which iteratively balances training data within a fixed computational budget. BalDistill achieves state-of-the-art performance across diverse long-tailed datasets, improving both the efficiency and the efficacy of the distilled models.

Abstract

Large language models (LLMs) have significantly advanced various natural language processing tasks, but deploying them remains computationally expensive. Knowledge distillation (KD) is a promising solution, enabling the transfer of capabilities from larger teacher LLMs to more compact student models. In particular, sequence-level KD, which distills rationale-based reasoning processes instead of merely final outcomes, shows great potential for enhancing students' reasoning capabilities. However, current methods struggle with sequence-level KD under long-tailed data distributions, which adversely affects generalization on sparsely represented domains. We introduce the Multi-Stage Balanced Distillation (BalDistill) framework, which iteratively balances training data within a fixed computational budget. By dynamically selecting representative head-domain examples and synthesizing tail-domain examples, BalDistill achieves state-of-the-art performance across diverse long-tailed datasets, enhancing both the efficiency and efficacy of the distilled models.

Paper Structure

This paper contains 26 sections, 1 equation, 5 figures, 10 tables, 1 algorithm.

Figures (5)

  • Figure 1: Overview of the proposed iterative BalDistill framework. The framework is composed of multiple stages. At each stage, we apply the balancing policy to decide the data distribution in the training batch. For head domains with sufficient data, we actively select examples with the student model according to IFD metrics. For tail domains, we call the teacher model to generate synthetic examples and the corresponding rationales. Finally, the teacher model annotates the balanced training batch, which is used to fine-tune the student model.
  • Figure 2: Performance of the proposed method and baselines on different domains. The x-axis shows the proportion of each domain, ranked from head to tail domains. Our proposed BalDistill method achieves comparable results on head domains and outperforms the baseline method on tail domains.
  • Figure 3: Performance of the proposed BalDistill method and ablated methods on head and tail domains. BalDistill (A) achieves better results on head domains and outperforms the Active FT CoT method on tail domains, which demonstrates the effectiveness of each component of the BalDistill (A) framework.
  • Figure 4: Influence of the number of stages on BalDistill across datasets. Our proposed method consistently obtains better results than the random fine-tuning baseline across varying numbers of stages.
  • Figure 5: Example dataset distribution. The datasets we use exhibit long-tail distributions.
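
The per-stage balancing step described in the Figure 1 caption can be illustrated with a small sketch. This is not the authors' implementation: the function name `plan_stage_batch`, the uniform per-domain quota, and the `score` callable (a stand-in for the IFD metric computed with the student model) are all assumptions made for illustration. The sketch only plans a batch; calling the teacher to synthesize and annotate examples is left out.

```python
from collections import defaultdict


def plan_stage_batch(pool, stage_budget, score):
    """Plan one balanced training batch (hypothetical helper).

    pool: list of (domain, example) pairs still available for selection.
    stage_budget: number of examples the stage may use in total.
    score: callable mapping an example to a difficulty score; a stand-in
        for the IFD metric, which the source does not define here.
    """
    if not pool:
        raise ValueError("pool must contain at least one example")

    by_domain = defaultdict(list)
    for domain, example in pool:
        by_domain[domain].append(example)

    # Assumption: every domain receives the same quota within the budget.
    quota = stage_budget // len(by_domain)

    selected, to_synthesize = {}, {}
    for domain, examples in by_domain.items():
        # Head domains: keep the highest-scoring examples up to the quota.
        ranked = sorted(examples, key=score, reverse=True)
        selected[domain] = ranked[:quota]
        # Tail domains: record how many synthetic examples would be
        # requested from the teacher to reach the quota.
        to_synthesize[domain] = max(0, quota - len(examples))
    return selected, to_synthesize


# Toy pool: ten head-domain examples and a single tail-domain example.
pool = [("head", f"h{i}") for i in range(10)] + [("tail", "t0")]
selected, to_synthesize = plan_stage_batch(
    pool, stage_budget=8, score=lambda ex: int(ex[1:])
)
print(selected["head"])   # the four highest-scoring head examples
print(to_synthesize)      # shortfall per domain to fill with synthetic data
```

With a budget of 8 and two domains, each domain gets a quota of 4, so the head domain is subsampled while the tail domain reports a shortfall of 3 examples to be synthesized. A non-uniform quota could be swapped in by changing how `quota` is computed.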