Knowledge-Augmented Reasoning Distillation for Small Language Models in Knowledge-Intensive Tasks

Minki Kang, Seanie Lee, Jinheon Baek, Kenji Kawaguchi, Sung Ju Hwang

TL;DR

This work tackles the challenge of deploying knowledge-intensive reasoning in resource-constrained settings by enabling small language models to draw on broad domain knowledge without having to memorize it. It introduces Knowledge-Augmented Reasoning Distillation (KARD), which distills the reasoning capabilities of LLMs into small LMs while injecting task-relevant knowledge from an external knowledge base (KB) as non-parametric memory. A neural reranker further improves passage retrieval for rationale generation at inference time. Empirically, KARD yields significant gains on MedQA-USMLE, StrategyQA, and OpenBookQA, with 250M-parameter models surpassing fine-tuned 3B models, and it proves efficient in both training data and compute. The paper also provides theoretical support showing that external memory reduces memorization requirements, and it analyzes retrieval quality, rationale diversity, and failure modes to guide future improvements.
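The training side of this pipeline can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation: it assumes a plain Okapi BM25 retriever over an in-memory list of passages, queries the KB with the LLM-generated rationale (as the Figure 2 caption describes for training), and packs the question and retrieved passages into one input string whose target is the rationale. The names `BM25` and `build_training_example`, the `[P]` passage separator, and the toy KB are all hypothetical.

```python
import math
import re
from collections import Counter


def tokenize(text):
    # Lowercase word tokens; punctuation is dropped
    return re.findall(r"\w+", text.lower())


class BM25:
    """Minimal Okapi BM25 over an in-memory list of passages (illustrative)."""

    def __init__(self, docs, k1=1.5, b=0.75):
        self.docs = [tokenize(d) for d in docs]
        self.k1, self.b = k1, b
        self.n = len(self.docs)
        self.avgdl = sum(len(d) for d in self.docs) / self.n
        # Document frequency: number of passages containing each term
        self.df = Counter(t for d in self.docs for t in set(d))

    def idf(self, term):
        df = self.df.get(term, 0)
        return math.log(1 + (self.n - df + 0.5) / (df + 0.5))

    def score(self, query, i):
        tf = Counter(self.docs[i])
        norm = 1 - self.b + self.b * len(self.docs[i]) / self.avgdl
        return sum(
            self.idf(t) * tf[t] * (self.k1 + 1) / (tf[t] + self.k1 * norm)
            for t in tokenize(query)
            if t in tf
        )

    def top_k(self, query, k):
        ranked = sorted(range(self.n), key=lambda i: self.score(query, i), reverse=True)
        return ranked[:k]


def build_training_example(question, rationale, kb, retriever, k=2):
    # Training: query the KB with the LLM rationale, which spells out the facts needed
    idx = retriever.top_k(rationale, k)
    passages = " ".join(f"[P] {kb[i]}" for i in idx)  # "[P]" is a hypothetical separator
    # The small LM reads question + passages and is trained to emit the rationale
    return {"input": f"{question} {passages}", "target": rationale}


# Toy usage (made-up KB, question, and rationale)
kb = [
    "Aspirin irreversibly inhibits cyclooxygenase.",
    "Mitochondria are the site of oxidative phosphorylation.",
    "Cyclooxygenase produces thromboxane A2 in platelets.",
]
question = "Why does aspirin reduce platelet aggregation?"
rationale = "Aspirin inhibits cyclooxygenase, which platelets need to make thromboxane A2."
example = build_training_example(question, rationale, kb, BM25(kb), k=2)
print(example["input"])
```

In this toy run the rationale pulls in the two cyclooxygenase passages and skips the unrelated mitochondria passage, which is the point of retrieving with the rationale rather than the question during training.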

Abstract

Large Language Models (LLMs) have shown promising performance in knowledge-intensive reasoning tasks that require a compound understanding of knowledge. However, deploying LLMs in real-world applications can be challenging due to their high computational requirements and concerns about data privacy. Previous studies have focused on building task-specific small Language Models (LMs) by fine-tuning them with labeled data or by distilling LLMs. However, these approaches are ill-suited for knowledge-intensive reasoning tasks because small LMs have limited capacity to memorize the required knowledge. Motivated by our theoretical analysis of memorization, we propose Knowledge-Augmented Reasoning Distillation (KARD), a novel method that fine-tunes small LMs, augmented with knowledge retrieved from an external knowledge base, to generate rationales obtained from LLMs. We further propose a neural reranker to obtain documents relevant to rationale generation. We empirically show that KARD significantly improves the performance of small T5 and GPT models on the challenging knowledge-intensive reasoning datasets MedQA-USMLE, StrategyQA, and OpenBookQA. Notably, our method enables 250M T5 models to outperform fine-tuned 3B models, which have 12 times more parameters, on both MedQA-USMLE and StrategyQA.
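At inference time no rationale exists yet, so passages have to be found from the question, and the reranker reorders that candidate pool before the small LM generates a rationale. The sketch below shows only this plumbing under stated assumptions: first-stage candidates are given, the trained neural reranker is replaced by a hypothetical `fake_reranker` that looks up fixed, made-up scores, and `rerank`, `build_inference_input`, and the `[P]` separator are illustrative names, not part of the paper.

```python
from typing import Callable, List, Sequence

# (question, passage) -> relevance score; stands in for a trained neural reranker
Scorer = Callable[[str, str], float]


def rerank(question: str, candidates: Sequence[str], scorer: Scorer, k: int) -> List[str]:
    """Reorder first-stage candidates by reranker score and keep the top k."""
    ranked = sorted(candidates, key=lambda p: scorer(question, p), reverse=True)
    return ranked[:k]


def build_inference_input(question: str, candidates: Sequence[str], scorer: Scorer, k: int = 2) -> str:
    # Same input layout as training, but the passages come from question-side retrieval
    passages = rerank(question, candidates, scorer, k)
    return question + " " + " ".join(f"[P] {p}" for p in passages)


# Hypothetical scores a trained reranker might assign (made up for illustration)
FAKE_SCORES = {
    "Aspirin is sold over the counter.": 0.1,
    "Platelets aggregate at sites of vascular injury.": 0.6,
    "Cyclooxygenase produces thromboxane A2 in platelets.": 0.9,
}


def fake_reranker(question: str, passage: str) -> float:
    # Ignores the question; a real reranker would score the (question, passage) pair
    return FAKE_SCORES[passage]


question = "Why does aspirin reduce platelet aggregation?"
# Candidates in the order a question-only first-stage retriever might return them
candidates = list(FAKE_SCORES)
top = rerank(question, candidates, fake_reranker, k=2)
print(top)
```

Keeping the reranker behind a plain `(question, passage) -> float` callable means the same code could wrap any scoring model; only the scorer changes.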

Paper Structure (48 sections, 2 theorems, 18 equations, 6 figures, 19 tables)

Key Result

Theorem 1

Let $N=n$. Then, any learning algorithm $\mathcal{A}$ that satisfies $\mathrm{err}_{q,n}(\mathcal{A})\le \mathrm{err}_{q,n}(\mathcal{A}_{\mathrm{OPT}})+\epsilon$ for $\epsilon=o(1)$ also satisfies $I(X; \mathcal{A}(X)|P)=\Omega(nd).$
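The bound is easier to read with the asymptotic notation unpacked. Assuming, following Brown et al. (2021), that $X$ is the training set of $n$ samples, each a $d$-bit string, and that $P$ is the randomly drawn data distribution (these definitions are not stated in this excerpt), the conclusion and a matching upper bound read:

```latex
I(X;\mathcal{A}(X)\mid P)=\Omega(nd)
\iff
\exists\, c>0:\quad I(X;\mathcal{A}(X)\mid P)\ \ge\ c\,n\,d \quad \text{for all sufficiently large } n,

\text{while, since } X\in(\{0,1\}^{d})^{n}:\quad
I(X;\mathcal{A}(X)\mid P)\ \le\ H(X\mid P)\ \le\ n\,d \ \text{bits}.
```

Taken together, a near-optimal learner must retain a constant fraction of all the bits in its training sample beyond what it needs to know about $P$; this is the memorization burden that an external knowledge base is meant to relieve.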

Figures (6)

  • Figure 1: Concept. Top: an example of a knowledge-intensive reasoning task (MedQA-USMLE). Bottom: a conceptual illustration of KARD compared with existing reasoning distillation. Right: examples of passages retrieved from the external KB using the rationale and using the question.
  • Figure 2: Overview of KARD. (Left; see the KARD section) Training (top) and inference (bottom) of knowledge-augmented reasoning distillation: during training, the small LM learns to generate rationales given the training data and the knowledge retrieved using the rationale. (Right; see the reranker section) Reranker training (top) and inference (bottom). The reranker learns to prioritize passages containing knowledge relevant to the rationale.
  • Figure 2: Experimental results on the StrategyQA and OpenBookQA datasets with T5 models. $\dagger$ indicates experiments with Flan-T5 models of the same size. Results are reported in the same format as the MedQA-USMLE table.
  • Figure 3: Analysis of rationale diversity.
  • Figure 4: (a) Efficiency with respect to training data and (b) model size. On MedQA-USMLE, we compare KARD against the fine-tuning baseline by varying either the amount of training data (with Flan-T5 Large) or the number of parameters, and include the few-shot in-context learning performance of Flan-T5 XXL (11B). (c) Treating silver documents as ground truth, we measure Hits@k for the documents retrieved by BM25 and by the reranker.
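The Hits@k measurement in Figure 4(c) can be sketched as below. This assumes the common definition (a query counts as a hit if any of its silver documents appears among the top k retrieved) and that each query has a set of silver passage IDs; the paper's exact averaging may differ, and the passage IDs are made up.

```python
from typing import List, Sequence, Set


def hits_at_k(retrieved: Sequence[List[str]], silver: Sequence[Set[str]], k: int) -> float:
    """Fraction of queries whose top-k retrieved passages contain a silver passage."""
    if not retrieved or len(retrieved) != len(silver):
        raise ValueError("retrieved and silver must be non-empty and of equal length")
    hits = sum(
        1 for ranked, gold in zip(retrieved, silver) if any(doc in gold for doc in ranked[:k])
    )
    return hits / len(retrieved)


# Toy example: one ranked list (best first) and one silver set per query
retrieved = [["d3", "d1", "d7"], ["d2", "d9", "d4"], ["d5", "d6", "d8"]]
silver = [{"d1"}, {"d4"}, {"d0"}]
for k in (1, 2, 3):
    print(k, hits_at_k(retrieved, silver, k))
```

On this toy data the score rises from 0.0 at k = 1 to 1/3 at k = 2 and 2/3 at k = 3; comparing such curves for BM25 and reranker rankings is what panel (c) reports.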

Theorems & Definitions (2)

  • Theorem 1 (Brown et al., 2021)
  • Theorem 2