Knowledge-Augmented Reasoning Distillation for Small Language Models in Knowledge-Intensive Tasks
Minki Kang, Seanie Lee, Jinheon Baek, Kenji Kawaguchi, Sung Ju Hwang
TL;DR
This work tackles the challenge of deploying knowledge-intensive reasoning in resource-constrained settings by letting small language models draw on domain knowledge they are too small to memorize. It introduces Knowledge-Augmented Reasoning Distillation (KARD), which distills the reasoning capabilities of LLMs into small LMs while injecting task-relevant knowledge from an external knowledge base (KB) as non-parametric memory. A neural reranker further improves passage retrieval for rationale generation at inference time. Empirically, KARD yields significant gains on MedQA-USMLE, StrategyQA, and OpenBookQA; 250M-parameter T5 models surpass fine-tuned 3B models on MedQA-USMLE and StrategyQA, demonstrating data and compute efficiency. The paper also provides theoretical support showing how external memory reduces memorization requirements, and it analyzes retrieval, diversity, and failure modes to guide future improvements.
Abstract
Large Language Models (LLMs) have shown promising performance in knowledge-intensive reasoning tasks that require a compound understanding of knowledge. However, deploying LLMs in real-world applications can be challenging due to their high computational requirements and concerns about data privacy. Previous studies have focused on building task-specific small Language Models (LMs) by fine-tuning them with labeled data or by distilling LLMs. However, these approaches are ill-suited for knowledge-intensive reasoning tasks because small LMs have limited capacity to memorize the required knowledge. Motivated by our theoretical analysis of memorization, we propose Knowledge-Augmented Reasoning Distillation (KARD), a novel method that fine-tunes small LMs to generate rationales obtained from LLMs, with augmented knowledge retrieved from an external knowledge base. We further propose a neural reranker to obtain documents relevant to rationale generation. We empirically show that KARD significantly improves the performance of small T5 and GPT models on the challenging knowledge-intensive reasoning datasets MedQA-USMLE, StrategyQA, and OpenBookQA. Notably, our method enables 250M T5 models to outperform fine-tuned 3B models, which have 12 times as many parameters, on both the MedQA-USMLE and StrategyQA benchmarks.
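The knowledge-augmented distillation setup described above can be made concrete with a minimal Python sketch of how training inputs and targets might be assembled. It assumes a knowledge base given as a list of passage strings and a teacher rationale already obtained from an LLM. The word-overlap scorer, the `[KNOWLEDGE]` separator, and every function name here are hypothetical illustrations rather than the paper's implementation; using the rationale as the training-time retrieval query is likewise an assumption, and fine-tuning the small LM itself is omitted.

```python
from collections import Counter
from dataclasses import dataclass
from typing import Callable


@dataclass
class DistillExample:
    source: str  # input to the small LM: question plus retrieved passages
    target: str  # teacher (LLM) rationale the small LM is trained to generate


def overlap_score(query: str, passage: str) -> int:
    """Toy lexical scorer (hypothetical stand-in for a real retriever):
    counts lowercase whitespace tokens shared by query and passage."""
    q = Counter(query.lower().split())
    p = Counter(passage.lower().split())
    return sum((q & p).values())


def retrieve(query: str, kb: list[str], k: int) -> list[str]:
    """Return the k passages in kb with the highest overlap score (ties keep KB order)."""
    return sorted(kb, key=lambda p: overlap_score(query, p), reverse=True)[:k]


def build_training_example(question: str, rationale: str, kb: list[str], k: int = 2) -> DistillExample:
    # Assumption: at training time the teacher rationale serves as the retrieval
    # query, since it names the knowledge the answer depends on.
    passages = retrieve(rationale, kb, k)
    source = question + " [KNOWLEDGE] " + " ".join(passages)
    return DistillExample(source=source, target=rationale)


def build_inference_input(
    question: str,
    kb: list[str],
    rerank_score: Callable[[str, str], float],
    n_candidates: int = 3,
    k: int = 2,
) -> str:
    # At inference no rationale exists, so candidates are retrieved with the
    # question alone and then reordered by a reranker (any (question, passage) scorer).
    candidates = retrieve(question, kb, n_candidates)
    reranked = sorted(candidates, key=lambda p: rerank_score(question, p), reverse=True)
    return question + " [KNOWLEDGE] " + " ".join(reranked[:k])
```

In this sketch `rerank_score` is the slot where a trained neural reranker would plug in; passing `overlap_score` there just reuses the toy scorer. Keeping retrieval and reranking as separate steps mirrors the idea that a cheap first-stage retriever proposes candidates and a stronger model decides which passages actually help rationale generation.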
