Mixed Distillation Helps Smaller Language Model Better Reasoning
Chenglin Li, Qianglong Chen, Liangyue Li, Caiyu Wang, Yicheng Li, Zulong Chen, Yin Zhang
TL;DR
Large language models (LLMs) offer strong reasoning but are costly to deploy. The paper introduces Mixed Distillation (MD), which distills both Chain of Thought (CoT) and Program of Thought (PoT) reasoning from LLMs into smaller models, using a multi-task learning framework for training and self-consistency for inference. MD improves both single-path and multi-path reasoning, outperforms the two models distilled individually, and lets 7B backbones surpass GPT-3.5-Turbo on SVAMP while remaining competitive on other reasoning benchmarks. The work shows that PoT is a valuable supervisory signal and offers a practical path toward small, efficient models that can reason.
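A minimal sketch of the multi-task idea, assuming each teacher output pairs a CoT rationale with a PoT program for the same question. The task prefixes `[CoT]` and `[PoT]`, the `TeacherOutput` type, and the weighting parameter `alpha` are hypothetical illustrations, not the paper's actual prompt format or loss.

```python
from dataclasses import dataclass

# Hypothetical task prefixes; the paper's actual input format may differ.
COT_PREFIX = "[CoT] "
POT_PREFIX = "[PoT] "


@dataclass
class TeacherOutput:
    question: str
    cot_rationale: str  # natural-language reasoning produced by the LLM teacher
    pot_program: str    # Python program produced by the LLM teacher


def build_mixed_examples(outputs):
    """Turn each teacher output into two (input, target) pairs, one per task."""
    examples = []
    for out in outputs:
        examples.append((COT_PREFIX + out.question, out.cot_rationale))
        examples.append((POT_PREFIX + out.question, out.pot_program))
    return examples


def multitask_loss(cot_loss, pot_loss, alpha=0.5):
    """Weighted sum of per-task losses; alpha is an assumed hyperparameter."""
    return alpha * cot_loss + (1.0 - alpha) * pot_loss
```

Sharing one student model across both prefixed tasks is what lets a single set of weights answer in either style at inference time; the loss weighting here is only one plausible way to balance the two tasks.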
Abstract
While large language models (LLMs) have demonstrated exceptional performance on recent natural language processing (NLP) tasks, deploying them in real-world applications is challenging because of their high computational and memory demands. Recent studies have focused on enhancing smaller models through knowledge distillation from LLMs, with promising results. However, these smaller models often fail to match LLM performance, especially on tasks that require reasoning. In this work, we introduce the Mixed Distillation (MD) framework, which capitalizes on the complementary strengths of the Program of Thought (PoT) and Chain of Thought (CoT) capabilities of LLMs, combining multiple prompting techniques and distilling these capabilities into smaller models. Our experimental results show that MD significantly enhances the single-path and multi-path reasoning ability of smaller models across various tasks. In terms of both accuracy and generality on reasoning tasks, the resulting model outperforms, overall, the two models distilled individually. Notably, LLaMA2-7B and CodeLlama-7B trained with MD reached accuracies of 84.5% and 85.5%, respectively, on the SVAMP benchmark, outperforming GPT-3.5-Turbo by 2.5% and 3.5%.
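Multi-path reasoning with self-consistency samples several reasoning paths and takes a majority vote over their final answers. The sketch below assumes, for illustration only, that votes are pooled across sampled CoT rationales and executed PoT programs; the helper names, the "last number in the rationale" answer extraction, and the convention that a PoT program stores its result in a variable named `ans` are all hypothetical.

```python
import re
from collections import Counter


def extract_cot_answer(rationale):
    """Take the last number in a CoT rationale as its answer (assumed convention)."""
    numbers = re.findall(r"-?\d+(?:\.\d+)?", rationale)
    return float(numbers[-1]) if numbers else None


def run_pot_program(program):
    """Execute a PoT program and return its `ans` variable, or None on failure.

    Builtins are stripped, but exec is NOT a real sandbox; isolate untrusted code.
    """
    namespace = {}
    try:
        exec(program, {"__builtins__": {}}, namespace)
    except Exception:
        return None
    return namespace.get("ans")


def normalize(value):
    """Map answers to rounded floats so CoT and PoT votes can be compared."""
    try:
        return round(float(value), 4)
    except (TypeError, ValueError):
        return None


def mixed_self_consistency(cot_rationales, pot_programs):
    """Majority vote over answers from both reasoning styles; ties go to first seen."""
    votes = Counter()
    for rationale in cot_rationales:
        answer = normalize(extract_cot_answer(rationale))
        if answer is not None:
            votes[answer] += 1
    for program in pot_programs:
        answer = normalize(run_pot_program(program))
        if answer is not None:
            votes[answer] += 1
    return votes.most_common(1)[0][0] if votes else None
```

Passing a single rationale or a single program reduces this to single-path reasoning; failed programs and unparseable rationales simply cast no vote.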
