BiLD: Bi-directional Logits Difference Loss for Large Language Model Distillation
Minchong Li, Feng Zhou, Xiaohui Song
TL;DR
This work introduces the Bi-directional Logits Difference (BiLD) loss for task-specific distillation of large language models at the logit level. By clipping to the top-$k$ logits and constructing bi-directional logit differences, BiLD filters out long-tail noise and exploits the internal ranking of the logits. The loss is the sum of a teacher-led and a student-led term, $\mathcal{L}_{\rm BiLD} = \mathcal{L}_{t-\rm LD} + \mathcal{L}_{s-\rm LD}$. Across 13 NLP datasets and multiple teacher-student pairs, BiLD with only the top-8 logits outperforms supervised fine-tuning, vanilla KL loss, and five other distillation methods from the NLP and CV fields, and the student imitates the teacher's logit-level behavior more closely. The approach is a practical, high-signal alternative for efficient LLM distillation, although the temperature and the clipping level $k$ must be chosen to balance performance against computation.
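One way to read this formula, written out for a single token position, is given below. It is a sketch under assumptions not stated in this section. Here $z^t, z^s \in \mathbb{R}^V$ are the teacher and student logits over a vocabulary of size $V$. $\mathcal{I}_t = (i_1,\dots,i_k)$ and $\mathcal{I}_s$ are the indices of the top-$k$ teacher and student logits, sorted in descending order. $T$ is the temperature. Each term compares temperature-scaled softmax distributions over pairwise logit differences with a KL divergence, taking the teacher as the target. The notation $D(z;\mathcal{I})$ is introduced here for illustration.

$$
D(z;\mathcal{I}) = \big(z_{i_m} - z_{i_n}\big)_{1 \le m < n \le k} \in \mathbb{R}^{k(k-1)/2}
$$

$$
\mathcal{L}_{t-\rm LD} = \mathrm{KL}\Big(\mathrm{softmax}\big(D(z^t;\mathcal{I}_t)/T\big) \,\Big\|\, \mathrm{softmax}\big(D(z^s;\mathcal{I}_t)/T\big)\Big)
$$

$$
\mathcal{L}_{s-\rm LD} = \mathrm{KL}\Big(\mathrm{softmax}\big(D(z^t;\mathcal{I}_s)/T\big) \,\Big\|\, \mathrm{softmax}\big(D(z^s;\mathcal{I}_s)/T\big)\Big)
$$

With $k = 8$, each index set yields $8 \cdot 7 / 2 = 28$ differences, so every term compares two distributions over 28 entries rather than over the full vocabulary. A difference $z_{i_m} - z_{i_n}$ is positive exactly when $z_{i_m} > z_{i_n}$. Matching these distributions therefore pushes the student toward the teacher's ordering of the selected logits.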
Abstract
In recent years, large language models (LLMs) have shown exceptional capabilities across various natural language processing (NLP) tasks. However, this performance often comes at the cost of larger parameter counts, which poses significant challenges for widespread deployment. Knowledge distillation (KD) addresses this by transferring knowledge from a large teacher model to a smaller student model. In this paper, we explore the task-specific distillation of LLMs at the logit level. Our investigation reveals that the logits of fine-tuned LLMs exhibit a more extreme long-tail distribution than those of vision models, and that hidden "noise" in the long tail degrades distillation performance. Furthermore, existing logits distillation methods often fail to make effective use of the internal ranking information in the logits. To address these issues, we propose the Bi-directional Logits Difference (BiLD) loss. The BiLD loss filters out the long-tail noise by using only the top-$k$ teacher and student logits, and it leverages the internal logits ranking information by constructing logits differences. To evaluate the BiLD loss, we conduct comprehensive experiments on 13 datasets using two types of LLMs. Our results show that the BiLD loss, with only the top-8 logits, outperforms supervised fine-tuning (SFT), vanilla KL loss, and five other distillation methods from both the NLP and CV fields.
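The top-$k$ clipping and bi-directional logits differences described above can be sketched for a single token position in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation. Differences are taken over all pairs $m < n$ of the top-$k$ indices in descending order. The differences are turned into distributions with a temperature-scaled softmax, and teacher and student are compared with KL divergence, with the teacher as the target. All helper names (`top_k_indices`, `pairwise_differences`, `logits_difference_loss`, `bild_loss`) are hypothetical. A training implementation would operate on batched tensors and average over token positions.

```python
import math
from itertools import combinations


def top_k_indices(logits, k):
    """Indices of the k largest logits, ordered from largest to smallest value."""
    return sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]


def pairwise_differences(logits, indices):
    """z[i_m] - z[i_n] for every pair m < n of the ordered index list."""
    return [logits[a] - logits[b] for a, b in combinations(indices, 2)]


def softmax(values, temperature):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [v / temperature for v in values]
    peak = max(scaled)
    exps = [math.exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for two discrete distributions of equal length."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def logits_difference_loss(teacher, student, indices, temperature):
    """Assumed LD term: KL between teacher and student difference
    distributions, both computed on the same fixed index set."""
    p = softmax(pairwise_differences(teacher, indices), temperature)
    q = softmax(pairwise_differences(student, indices), temperature)
    return kl_divergence(p, q)


def bild_loss(teacher, student, k=8, temperature=1.0):
    """Hypothetical BiLD sketch: teacher-led term (indices from the teacher's
    top-k) plus student-led term (indices from the student's top-k)."""
    if k < 2:
        raise ValueError("k must be at least 2 to form logit differences")
    t_ld = logits_difference_loss(teacher, student, top_k_indices(teacher, k), temperature)
    s_ld = logits_difference_loss(teacher, student, top_k_indices(student, k), temperature)
    return t_ld + s_ld


if __name__ == "__main__":
    # Toy logits over a 6-token vocabulary (made-up values for illustration).
    teacher = [4.0, 3.5, 1.0, 0.5, -2.0, -3.0]
    student = [3.0, 3.8, 0.2, 1.1, -1.5, -2.5]
    print(bild_loss(teacher, student, k=3, temperature=2.0))
```

Setting `k=8` matches the top-8 configuration reported in the abstract. The `temperature` default of 1.0 is an arbitrary choice made for this sketch. Because only differences between logits enter the loss, adding a constant to all student logits leaves it unchanged. The student is matched on the relative ordering and spacing of its top logits rather than on their absolute values.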
