Self-Data Distillation for Recovering Quality in Pruned Large Language Models

Vithursan Thangarasa, Ganesh Venkatesh, Mike Lasby, Nish Sinnadurai, Sean Lie

TL;DR

This work tackles the quality degradation that accompanies structured pruning of large language models by introducing self-data distilled fine-tuning, which uses the unpruned model to generate a distilled dataset that preserves the base model's knowledge and alignment. The method outperforms standard supervised fine-tuning across multiple pruned models and tasks, and its benefits extend to speculative decoding, where token acceptance rates improve and latency decreases. A key extension, merging self-data distilled models via SLERP, yields further quality retention, with recovery reaching the low 90s (percent of the unpruned model's average accuracy) for moderate pruning; the gains also grow with larger distillation datasets. Overall, the approach enables efficient pruning without sacrificing critical reasoning capabilities, making it practical to deploy more affordable yet capable LLMs.
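
The TL;DR mentions merging models with SLERP (spherical linear interpolation). The sketch below shows the standard SLERP formula applied tensor by tensor to two checkpoints of the same architecture. It uses plain Python lists so it has no dependencies; the flattened state-dict layout, the helper names `slerp` and `merge_state_dicts`, and the single global interpolation factor `t` are illustrative assumptions, not the paper's merging configuration.

```python
import math
from typing import Dict, List

Vector = List[float]


def _lerp(t: float, v0: Vector, v1: Vector) -> Vector:
    return [(1.0 - t) * a + t * b for a, b in zip(v0, v1)]


def slerp(t: float, v0: Vector, v1: Vector, eps: float = 1e-8) -> Vector:
    """Spherically interpolate two flattened weight tensors (t=0 -> v0, t=1 -> v1)."""
    if len(v0) != len(v1):
        raise ValueError("tensors must have the same number of elements")
    norm0 = math.sqrt(sum(a * a for a in v0))
    norm1 = math.sqrt(sum(b * b for b in v1))
    if norm0 < eps or norm1 < eps:
        return _lerp(t, v0, v1)  # angle is undefined for a zero tensor
    # Angle between the normalized tensors, clamped against rounding error.
    cos_theta = sum(a * b for a, b in zip(v0, v1)) / (norm0 * norm1)
    theta = math.acos(max(-1.0, min(1.0, cos_theta)))
    sin_theta = math.sin(theta)
    if sin_theta < eps:
        return _lerp(t, v0, v1)  # (anti)parallel tensors: fall back to linear
    w0 = math.sin((1.0 - t) * theta) / sin_theta
    w1 = math.sin(t * theta) / sin_theta
    # Weights come from the normalized angle but are applied to the raw tensors.
    return [w0 * a + w1 * b for a, b in zip(v0, v1)]


def merge_state_dicts(
    sd_a: Dict[str, Vector], sd_b: Dict[str, Vector], t: float = 0.5
) -> Dict[str, Vector]:
    """Merge two checkpoints of the same architecture, one tensor at a time."""
    if sd_a.keys() != sd_b.keys():
        raise ValueError("checkpoints must contain the same parameter names")
    return {name: slerp(t, sd_a[name], sd_b[name]) for name in sd_a}


# Two orthogonal unit "tensors": SLERP keeps the result on the unit circle.
merged = merge_state_dicts({"w": [1.0, 0.0]}, {"w": [0.0, 1.0]}, t=0.5)
print(merged["w"])
```

With `t = 0.5`, the two orthogonal unit vectors merge to roughly `[0.707, 0.707]`, which still has norm 1, whereas plain averaging gives `[0.5, 0.5]` with norm about 0.707. Preserving weight norms in this way is the usual motivation for SLERP over linear averaging.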

Abstract

Large language models have driven significant progress in natural language processing, but their deployment requires substantial compute and memory resources. As models scale, compression techniques become essential for balancing model quality with computational efficiency. Structured pruning, which removes less critical components of the model, is a promising strategy for reducing complexity. However, one-shot pruning often results in significant quality degradation, particularly in tasks requiring multi-step reasoning. To recover lost quality, supervised fine-tuning (SFT) is commonly applied, but it can lead to catastrophic forgetting by shifting the model's learned data distribution. Therefore, addressing the degradation from both pruning and SFT is essential to preserve the original model's quality. In this work, we utilize self-data distilled fine-tuning to address these challenges. Our approach leverages the original, unpruned model to generate a distilled dataset that preserves semantic richness and mitigates catastrophic forgetting by maintaining alignment with the base model's knowledge. Empirically, we demonstrate that self-data distillation consistently outperforms standard SFT, improving average accuracy by up to 8% on the HuggingFace OpenLLM Leaderboard v1. Specifically, when pruning six decoder blocks on Llama3.1-8B Instruct (i.e., 32 to 26 layers, reducing the model size from 8.03B to 6.72B parameters), our method retains 91.2% of the original model's accuracy compared to 81.7% with SFT, while reducing real-world FLOPs by 16.3%. Furthermore, combining self-data distilled models through model merging yields enhanced quality retention. Additionally, leveraging these pruned models in speculative decoding increases token acceptance rates, thereby improving inference efficiency in applied settings.
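
The abstract describes using the original, unpruned model to generate the fine-tuning data. The sketch below shows one way such a pipeline could look, assuming a rewrite-then-verify scheme: the unpruned model restates each reference response, and the rewrite replaces the label only when a simple check agrees with the original answer. The prompt template, the last-number `final_answer` extractor (suited only to math data such as GSM8k), and the `unpruned_generate` callable are hypothetical placeholders; in practice `unpruned_generate` would wrap the unpruned model's text generation.

```python
import re
from dataclasses import dataclass
from typing import Callable, List, Optional


@dataclass
class Example:
    instruction: str
    response: str


# Hypothetical prompt asking the unpruned model to restate the reference answer.
REWRITE_TEMPLATE = (
    "Rewrite the reference answer in your own words, keeping the final "
    "result unchanged.\n\nInstruction: {instruction}\n"
    "Reference answer: {response}\nRewritten answer:"
)


def final_answer(text: str) -> Optional[str]:
    """Hypothetical check for math data: the last number in the text, if any."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", text)
    return numbers[-1] if numbers else None


def self_distill(
    dataset: List[Example], unpruned_generate: Callable[[str], str]
) -> List[Example]:
    """Build a fine-tuning set whose responses come from the unpruned model when valid."""
    distilled = []
    for ex in dataset:
        prompt = REWRITE_TEMPLATE.format(instruction=ex.instruction, response=ex.response)
        rewrite = unpruned_generate(prompt).strip()
        # Accept the model's own phrasing only if it preserves the reference answer;
        # otherwise keep the original label. Texts with no number on either side
        # compare equal (None == None), so a non-empty rewrite of a number-free
        # answer is always accepted.
        if rewrite and final_answer(rewrite) == final_answer(ex.response):
            distilled.append(Example(ex.instruction, rewrite))
        else:
            distilled.append(ex)
    return distilled


# Toy stand-in for the unpruned model; a real one would run text generation.
def toy_unpruned_generate(prompt: str) -> str:
    return "Adding 2 and 3 gives 5."


data = [
    Example("What is 2 + 3?", "2 + 3 = 5"),
    Example("What is 4 + 4?", "The answer is 8"),
]
distilled = self_distill(data, toy_unpruned_generate)
for ex in distilled:
    print(ex.response)
```

In the toy run, the first rewrite keeps the answer 5 and replaces the label, while the second contradicts the reference answer (8) and is discarded in favor of the original response. The pruned model is then fine-tuned on `distilled` instead of the original labels.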

Paper Structure

This paper contains 37 sections, 7 equations, 4 figures, 8 tables, 1 algorithm.

Figures (4)

  • Figure 1: Average quality recovery (%) of pruned Llama3.1-8B Instruct models relative to the unpruned baseline, across varying prune block sizes on the HuggingFace OpenLLM Leaderboard v1. The plot compares no fine-tuning, supervised fine-tuning, and self-data distilled fine-tuning using the OpenMathInstruct dataset. While model quality declines as the prune block size increases, self-data distillation consistently achieves superior recovery.
  • Figure 2: Quality of pruned Llama3.1-8B Instruct models across fine-tuning datasets and prune block sizes. The plots show average accuracy across the MMLU, GSM8k, and ARC-C tasks for models fine-tuned on GSM8k, OpenMathInstruct, Dolly, and Alpaca under three strategies: Self-Data FT, SFT, and No FT. Self-Data FT consistently outperforms SFT and No FT, with the largest gains using OpenMathInstruct (50k samples).
  • Figure 3: Layer-wise redundancy in Llama3.1-8B Instruct measured with (left) the angular cosine metric and (right) the block influence (BI) score. Both metrics highlight inherent redundancy in the middle layers of Llama3.1-8B Instruct, suggesting that these layers can be pruned with minimal impact on overall model quality.
  • Figure 4: Distribution of embedding similarities on GSM8k test dataset. We present the distribution of embedding similarities after fine-tuning a structurally pruned variant of the Llama3.1-8B Instruct model on 50k samples of OpenMathInstruct. We compute the cosine similarity between the sentence embeddings of the pruned models and those generated by the original Llama3.1-8B Instruct model. The plots show: (left) 28 decoder layers (prune block size = 4), (center) 26 decoder layers (prune block size = 6), and (right) 24 decoder layers (prune block size = 8). Self-Data Distilled Fine-Tuning (Self-Data FT) achieves higher similarity to the original baseline model, indicating a reduced distribution shift compared to Supervised Fine-Tuning (SFT).
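
Figure 3 scores layer redundancy with an angular cosine metric and the block influence (BI) score. The sketch below uses the definitions commonly used in the layer-pruning literature, which are assumptions about how the figure was produced: the angular distance between the hidden states entering a block of n layers and those leaving it (arccos of cosine similarity, divided by π), and BI as one minus the cosine similarity between a single layer's input and output. The hidden states are assumed to come from a forward pass over calibration text, and averaging over all tokens (rather than, say, using only the last token) is a hypothetical choice. The contiguous block with the lowest angular distance is the natural removal candidate for a given prune block size. The same `cosine` helper is the quantity Figure 4 histograms, there applied to sentence embeddings of pruned versus original models.

```python
import math
from typing import List, Sequence

Vec = Sequence[float]


def cosine(a: Vec, b: Vec) -> float:
    """Cosine similarity between two hidden-state vectors."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    if na == 0.0 or nb == 0.0:
        raise ValueError("cosine similarity is undefined for zero vectors")
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def angular_distance(h_in: List[Vec], h_out: List[Vec]) -> float:
    """Token-averaged arccos(cosine) / pi: 0 = same direction, 1 = opposite."""
    total = sum(math.acos(max(-1.0, min(1.0, cosine(a, b)))) / math.pi
                for a, b in zip(h_in, h_out))
    return total / len(h_in)


def block_influence(h_in: List[Vec], h_out: List[Vec]) -> float:
    """BI score of one layer: 1 - token-averaged cosine(input, output)."""
    return 1.0 - sum(cosine(a, b) for a, b in zip(h_in, h_out)) / len(h_in)


def most_redundant_block(hidden: List[List[Vec]], n: int) -> int:
    """Start index l of the n consecutive layers whose removal changes h the least.

    hidden[k] holds the per-token hidden states entering decoder layer k, and
    hidden[-1] is the output of the last layer. Dropping layers l .. l+n-1
    means layer l+n receives hidden[l] instead of hidden[l+n].
    """
    num_layers = len(hidden) - 1
    scores = [angular_distance(hidden[l], hidden[l + n])
              for l in range(num_layers - n + 1)]
    return min(range(len(scores)), key=scores.__getitem__)


# Toy 4-layer "model" with 2 tokens: layers 1 and 2 leave the states unchanged.
hidden = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[0.0, 1.0], [1.0, 0.0]],
    [[0.0, 1.0], [1.0, 0.0]],
    [[0.0, 1.0], [1.0, 0.0]],
    [[1.0, 0.0], [0.0, 1.0]],
]
start = most_redundant_block(hidden, n=2)
bi_per_layer = [block_influence(hidden[k], hidden[k + 1]) for k in range(4)]
print(start, bi_per_layer)
```

On the toy input, the two-layer block starting at layer 1 is selected, and the BI scores are lowest for the two layers that leave their input unchanged. On a real model, the middle layers would be expected to score lowest, as Figure 3 reports.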