Enhancing Knowledge Distillation of Large Language Models through Efficient Multi-Modal Distribution Alignment

Tianyu Peng, Jiajun Zhang

TL;DR

The paper tackles the problem that large language models produce multi-modal probability distributions during inference, which standard knowledge distillation (KD) objectives struggle to teach to smaller student models. It introduces Ranking Loss based Knowledge Distillation (RLKD), a word-level ranking loss based on Spearman's rank correlation coefficient (SRCC) that aligns the order of peak predictions between teacher and student. The loss operates on the union of the two models' top-k predictions and is compatible with existing KD losses. Empirically, RLKD improves multi-modal distribution learning in pre-training and yields significant downstream gains on tasks such as GSM8K, Dolly, and XSum, while maintaining or modestly improving top-1 performance and adding negligible computational overhead. The approach is validated across multiple baselines, datasets, and model sizes, showing consistent improvements in peak-prediction alignment and downstream task performance.

Abstract

Knowledge distillation (KD) is an effective model compression method that can transfer the internal capabilities of large language models (LLMs) to smaller ones. However, the multi-modal probability distributions predicted by teacher LLMs are difficult for student models to learn. In this paper, we first demonstrate the importance of multi-modal distribution alignment with experiments and then highlight the inefficiency of existing KD approaches in learning multi-modal distributions. To address this problem, we propose Ranking Loss based Knowledge Distillation (RLKD), which encourages consistency in the ranking of peak predictions between the teacher and student models. By incorporating a word-level ranking loss, we ensure excellent compatibility with existing distillation objectives while fully leveraging the fine-grained information between different categories in the peaks of the two predicted distributions. Experimental results demonstrate that our method enables the student model to better learn the multi-modal distributions of the teacher model, leading to a significant performance improvement in various downstream tasks.
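
To make the idea of ranking consistency over peak predictions concrete, the following is a minimal sketch, not the authors' implementation. It assumes a single token position, takes the union of the teacher's and student's top-k vocabulary indices, and scores agreement with the standard no-ties Spearman formula. The function names (`topk_union`, `descending_ranks`, `spearman_ranking_loss`), the choice of `1 - SRCC` as the loss value, and the default `k` are illustrative assumptions. Hard ranks are not differentiable, so an actual training loss would need a differentiable surrogate; the exact formulation used in the paper is not given in this section.

```python
import numpy as np


def topk_union(teacher_logits, student_logits, k):
    """Vocabulary indices that are in the top-k of either model."""
    t_top = np.argsort(-teacher_logits)[:k]
    s_top = np.argsort(-student_logits)[:k]
    return np.union1d(t_top, s_top)


def descending_ranks(values):
    """Rank 1 for the largest value. Assumes there are no ties."""
    order = np.argsort(-values)
    ranks = np.empty(len(values), dtype=float)
    ranks[order] = np.arange(1, len(values) + 1)
    return ranks


def spearman_ranking_loss(teacher_logits, student_logits, k=5):
    """Illustrative loss: 1 - SRCC over the union of top-k predictions.

    Returns 0.0 when both models order the peak tokens identically and
    2.0 when the orders are exactly reversed.
    """
    idx = topk_union(teacher_logits, student_logits, k)
    n = len(idx)
    if n < 2:
        return 0.0  # a single shared peak carries no ordering information
    d = descending_ranks(teacher_logits[idx]) - descending_ranks(student_logits[idx])
    srcc = 1.0 - 6.0 * np.sum(d ** 2) / (n * (n ** 2 - 1))
    return float(1.0 - srcc)


teacher = np.array([5.0, 4.0, 3.0, 2.0, 1.0, 0.0])
same_order = teacher * 0.5          # same ranking, different scale
reversed_order = teacher[::-1].copy()

loss_same = spearman_ranking_loss(teacher, same_order, k=3)
loss_reversed = spearman_ranking_loss(teacher, reversed_order, k=3)
```

Because only the order of the selected tokens matters, rescaling the student's logits leaves this value unchanged, which is one reason such a term can sit alongside a probability-matching objective such as KL divergence rather than replace it.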

Paper Structure (25 sections, 2 equations, 6 figures, 11 tables)

This paper contains 25 sections, 2 equations, 6 figures, 11 tables.

Figures (6)

  • Figure 1: A theoretical example illustrating the situations that can arise when using KL or RKL as the distillation objective to fit a multi-modal distribution.
  • Figure 2: The degree of consistency between the peak predictions of different models and those of the 70B model. The horizontal axis represents the range of top-k predictions. For clearer presentation, the vertical axis shows the difference between the CR or MOR of the current model and the corresponding result of the 13B model.
  • Figure 3: Degree of agreement between student model and teacher model peak predictions after 20 epochs under existing KD objectives.
  • Figure 4: Comparison of computational objects on peak predictions. The black lines represent existing distillation objectives and the red lines represent our method.
  • Figure 5: Improvement in the learning ability of multi-modal distributions for existing distillation objectives by introducing ranking loss in the pre-training task. We average the results of the five distillation objectives before and after adding the ranking loss. The red area indicates the improved parts.
  • ...and 1 more figures