Enhancing Knowledge Distillation of Large Language Models through Efficient Multi-Modal Distribution Alignment
Tianyu Peng, Jiajun Zhang
TL;DR
The paper addresses the problem that large language models produce multi-modal probability distributions during inference, which standard knowledge distillation (KD) objectives struggle to teach to smaller student models. It introduces Ranking Loss based Knowledge Distillation (RLKD), a word-level ranking loss based on Spearman's rank correlation coefficient (SRCC) that aligns the order of peak predictions between teacher and student. The loss operates on the union of the two models' top-k predictions and is compatible with existing KD losses. Empirically, RLKD improves multi-modal distribution learning in pre-training and yields significant downstream gains on tasks such as GSM8K, Dolly, and XSum, while maintaining or modestly improving top-1 performance and adding negligible computational overhead. The approach is validated across multiple baselines, datasets, and model sizes, offering a practical path to more efficient KD for LLMs.
Abstract
Knowledge distillation (KD) is an effective model compression method that can transfer the internal capabilities of large language models (LLMs) to smaller ones. However, the multi-modal probability distributions predicted by teacher LLMs are difficult for student models to learn. In this paper, we first demonstrate the importance of multi-modal distribution alignment with experiments and then highlight the inefficiency of existing KD approaches in learning multi-modal distributions. To address this problem, we propose Ranking Loss based Knowledge Distillation (RLKD), which encourages the consistency of the ranking of peak predictions between the teacher and student models. By incorporating a word-level ranking loss, we ensure excellent compatibility with existing distillation objectives while fully leveraging the fine-grained information between different categories in the peaks of the two predicted distributions. Experimental results demonstrate that our method enables the student model to better learn the multi-modal distributions of the teacher model, leading to a significant performance improvement in various downstream tasks.
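To make the idea of ranking consistency over peak predictions concrete, the following is a minimal sketch of how Spearman's rank correlation could be measured on the union of the teacher's and student's top-k tokens for a single position. It is an illustration only, not the paper's training objective: the function names (`topk_union`, `ranks_desc`, `peak_srcc`, `rank_disagreement`) are hypothetical, the hard ranks computed here are not differentiable, and the closed-form SRCC formula assumes there are no tied probabilities.

```python
import numpy as np

def topk_union(teacher_probs, student_probs, k):
    """Indices in the union of the teacher's and the student's top-k tokens."""
    t_top = np.argsort(-teacher_probs, kind="stable")[:k]
    s_top = np.argsort(-student_probs, kind="stable")[:k]
    return np.union1d(t_top, s_top)

def ranks_desc(values):
    """Ordinal ranks, 1 = largest value (assumption: ties broken by position)."""
    order = np.argsort(-values, kind="stable")
    ranks = np.empty(len(values), dtype=float)
    ranks[order] = np.arange(1, len(values) + 1)
    return ranks

def peak_srcc(teacher_probs, student_probs, k):
    """Spearman's rank correlation restricted to the top-k union."""
    idx = topk_union(teacher_probs, student_probs, k)
    n = len(idx)
    if n < 2:
        return 1.0  # a single token is trivially in the same order
    d = ranks_desc(teacher_probs[idx]) - ranks_desc(student_probs[idx])
    # Closed form for untied ranks: 1 - 6 * sum(d^2) / (n * (n^2 - 1))
    return 1.0 - 6.0 * np.sum(d ** 2) / (n * (n ** 2 - 1))

def rank_disagreement(teacher_probs, student_probs, k):
    """Illustrative penalty: 0 when the peak orders agree, 2 when reversed."""
    return 1.0 - peak_srcc(teacher_probs, student_probs, k)

teacher = np.array([0.40, 0.30, 0.20, 0.06, 0.04])
student_same = np.array([0.50, 0.25, 0.15, 0.07, 0.03])   # same peak order
student_flip = np.array([0.20, 0.30, 0.40, 0.06, 0.04])   # top-3 order reversed

print(rank_disagreement(teacher, student_same, k=3))
print(rank_disagreement(teacher, student_flip, k=3))
```

Note that `student_same` assigns different probabilities than the teacher yet incurs no penalty, since only the order of the peaks is compared. A quantity of this kind would need a differentiable relaxation of the ranking step before it could be added to an existing KD loss during training.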
