MEG: Medical Knowledge-Augmented Large Language Models for Question Answering

Laura Cabello, Carmen Martin-Turrero, Uchenna Akujuobi, Anders Søgaard, Carlos Bobed

TL;DR

MEG introduces a parameter-efficient framework for medical knowledge augmentation of LLMs by injecting pretrained knowledge graph embeddings (KGEs) through a lightweight mapping network. A GraphSAGE-based knowledge graph (KG) encoder produces KGEs that are transformed into the LLM's embedding space and injected after the embedding layer, guided by a grounding module that links textual mentions to KG nodes. Training occurs in two phases: embedding transfer learning, which aligns KGEs with the LLM, followed by LoRA-based fine-tuning on downstream medical QA tasks. MEG achieves substantial accuracy gains over specialized baselines on four multiple-choice QA (MCQA) datasets (on average, +6.7% and +9.9% over BioMistral-7B and MediTron-7B, respectively). The approach performs robustly across base LLMs (Mistral and LLaMA-3) and shows that KGEs enrich factual grounding without full retraining of the base model, suggesting a practical path for domain-specific QA in medicine.

Abstract

Question answering is a natural language understanding task that involves reasoning over both explicit context and unstated relevant domain knowledge. Despite the high cost of training, large language models (LLMs) -- the backbone of most modern question-answering systems -- still struggle to reliably capture the nuanced relationships between concepts that are crucial for reasoning in specialized fields like medicine. In this work, we present MEG, a parameter-efficient approach for medical knowledge-augmented LLMs. MEG uses a lightweight mapping network to incorporate knowledge graph embeddings into the LLM, enabling it to leverage external knowledge in a cost-effective way. We evaluate our method on four popular medical multiple-choice datasets and show that LLMs i) can effectively interpret knowledge graph embeddings and ii) gain significant advantages from the factual grounding these embeddings provide. MEG attains an average of +6.7% and +9.9% accuracy over specialized models like BioMistral-7B and MediTron-7B, respectively. Finally, we show that MEG's performance remains robust to the choice of graph encoder.
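
As a rough illustration of the mapping-network idea described above, the following NumPy sketch projects a few knowledge graph embeddings into a toy LLM embedding space and combines them with token embeddings. Everything here is an assumption made for illustration: the two-layer MLP, the toy dimensions, the helper names (`init_mapping`, `map_kges`, `inject`), and in particular the choice to prepend the mapped embeddings to the token sequence. The summary only states that mapped KGEs are injected after the embedding layer, so this should not be read as the authors' implementation.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes chosen for readability; these are not the paper's dimensions.
KGE_DIM, HIDDEN_DIM, LLM_DIM = 8, 16, 32


def init_mapping(kge_dim, hidden_dim, llm_dim, rng):
    """Create parameters for a small two-layer MLP (hypothetical architecture)."""
    return {
        "W1": rng.normal(0.0, 0.02, (kge_dim, hidden_dim)),
        "b1": np.zeros(hidden_dim),
        "W2": rng.normal(0.0, 0.02, (hidden_dim, llm_dim)),
        "b2": np.zeros(llm_dim),
    }


def map_kges(params, kges):
    """Map KGEs of shape (n_entities, kge_dim) to (n_entities, llm_dim)."""
    hidden = np.maximum(kges @ params["W1"] + params["b1"], 0.0)  # ReLU
    return hidden @ params["W2"] + params["b2"]


def inject(token_embeddings, mapped_kges):
    """Prepend mapped KGEs to the token embeddings (an assumed injection scheme)."""
    return np.concatenate([mapped_kges, token_embeddings], axis=0)


params = init_mapping(KGE_DIM, HIDDEN_DIM, LLM_DIM, rng)
kges = rng.normal(size=(3, KGE_DIM))      # stand-in for 3 grounded KG entities
tokens = rng.normal(size=(10, LLM_DIM))   # stand-in for 10 token embeddings

llm_input = inject(tokens, map_kges(params, kges))
print(llm_input.shape)  # (13, 32)
```

In a setup like this, only the mapping parameters would be trained in the first phase, which is what makes the approach cheap compared with retraining the base model.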

Paper Structure

This paper contains 32 sections, 1 equation, 5 figures, 4 tables.

Figures (5)

  • Figure 1: MEG leverages a pretrained KG encoder and an LLM. During an initial phase of training, MEG learns a mapping network to convert relevant graph features (KGEs) retrieved by the grounding module into token embeddings. During downstream fine-tuning, only a subset of the LLM's weights is updated, while the embedding layer and mapping network remain frozen. At inference, the LLM takes the text and the mapped KGEs as input and generates a response.
  • Figure 2: $f_{\text{k}}$ and $g_{\text{k}}$ are embedding transfer functions. $f_{\text{k}}$ takes a set of KGEs $X$ (i.e., node entities) as input and outputs a mapping of $X$ to the LLM's vector space. $Y$ is the set of averaged token embeddings of entities in the LLM space. During training, $g_{\text{k}}$ prevents degenerate transformations of the graph embeddings. The dashed lines indicate the inputs to the objective losses.
  • Figure 3: Template used to generate instructions for all QA datasets. The context is optional, depending on the dataset. At inference time, the text after [/INST] is generated by the language model.
  • Figure 4: Ablation study on KG encoder choice. Plain bars show edge classification accuracy over UMLS; striped bars show MEG-Mistral1's zero-shot (//) and fine-tuned (\\) accuracy on MedQA; the dashed line represents accuracy with random embeddings; red dots mark the ratio of invalid answers (NA) in the zero-shot setting.
  • Figure 5: t-SNE visualization of the embeddings before and after the mapping network. After mapping, the relative structure of the KGEs along hierarchy levels is preserved, albeit slightly rotated (e.g., in Level 4, see the diagonal gap, which hints at the orientation of the blobs) and with reversed sparsity. Note the clustering effect over contextualized label embeddings: the mapped KGEs draw them to a specific region.
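
The instruction template mentioned in the Figure 3 caption can be sketched roughly as follows. The exact wording of the template is not reproduced in this summary, so the field labels (`Context:`, `Question:`), the option lettering, and the function name `build_instruction` are hypothetical; only the optional context and the `[/INST]` marker, after which the model generates its answer, come from the caption. The sample question is purely illustrative and not taken from any of the datasets.

```python
from typing import Optional, Sequence


def build_instruction(
    question: str,
    options: Sequence[str],
    context: Optional[str] = None,
) -> str:
    """Format a multiple-choice question as an instruction prompt.

    The layout is an illustrative assumption, not the paper's template.
    """
    letters = "ABCDEFGHIJ"
    lines = []
    if context:  # the context is optional, depending on the dataset
        lines.append(f"Context: {context}")
    lines.append(f"Question: {question}")
    lines.extend(f"{letters[i]}. {option}" for i, option in enumerate(options))
    body = "\n".join(lines)
    # The language model is expected to generate the text after [/INST].
    return f"[INST] {body} [/INST]"


prompt = build_instruction(
    "Which vitamin deficiency causes scurvy?",
    ["Vitamin A", "Vitamin C", "Vitamin D", "Vitamin K"],
)
print(prompt)
```

Passing a `context` string adds a leading context line, which matches datasets that supply a passage alongside the question.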