JMLR: Joint Medical LLM and Retrieval Training for Enhancing Reasoning and Professional Question Answering Capability

Junda Wang, Zhichao Yang, Zonghai Yao, Hong Yu

TL;DR

The paper tackles hallucinations and domain knowledge gaps in medical QA by introducing JMLR, a framework that jointly trains a domain-specific retriever and an LLM to fetch clinically relevant guidelines during fine-tuning. It introduces an LLM-Rank loss to align the retriever with improvements in LLM performance, enabling end-to-end optimization and efficiency gains. On multiple medical QA benchmarks, JMLR-13B achieves state-of-the-art results, surpassing Meditron and RAG variants while drastically reducing training time. The work demonstrates that tightly coupled retrieval and language-model training can boost medical reasoning and factual accuracy, with thoughtful consideration of privacy, bias, and broader societal impacts for real-world deployment.
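The TL;DR mentions an LLM-Rank loss that aligns the retriever with improvements in LLM performance, but this page does not give its formula. As a hedged illustration only, the sketch below uses a KL-divergence objective of the general kind used to train retrievers from language-model feedback. This is an assumption, not necessarily the paper's LLM-Rank loss. The retriever's softmax over candidate-document scores is pulled toward a target distribution that favours documents under which the LLM assigns a higher log-probability to the gold answer. The function names `rank_alignment_loss` and `retriever_score_grad` and their inputs are hypothetical.

```python
import math
from typing import List


def softmax(xs: List[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax over a list of scores."""
    scaled = [x / temperature for x in xs]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def rank_alignment_loss(retriever_scores: List[float],
                        llm_logprobs: List[float],
                        temperature: float = 1.0) -> float:
    """KL(target || retriever) over the candidate documents for one question.

    Hypothetical inputs (an assumed stand-in for the paper's LLM-Rank loss):
      retriever_scores[i]: retriever similarity s(q, d_i)
      llm_logprobs[i]:     log p_LLM(gold answer | q, d_i)
    The target puts more mass on documents that make the gold answer likelier.
    """
    p = softmax(retriever_scores, temperature)
    q = softmax(llm_logprobs, temperature)  # treated as a fixed target
    # Terms with q_i == 0 contribute 0 by the convention 0 * log 0 = 0.
    return sum(qi * (math.log(qi) - math.log(pi)) for qi, pi in zip(q, p) if qi > 0)


def retriever_score_grad(retriever_scores: List[float],
                         llm_logprobs: List[float],
                         temperature: float = 1.0) -> List[float]:
    """Gradient of the KL above w.r.t. each score: (p_i - q_i) / T."""
    p = softmax(retriever_scores, temperature)
    q = softmax(llm_logprobs, temperature)
    return [(pi - qi) / temperature for pi, qi in zip(p, q)]


if __name__ == "__main__":
    llm_logprobs = [-0.2, -1.5, -3.0]   # toy values: doc 0 helps the LLM most
    misaligned = [0.5, 1.0, 2.0]        # toy retriever that prefers doc 2
    before = rank_alignment_loss(misaligned, llm_logprobs)
    # One gradient step applied directly to the scores; in a real system the
    # gradient would flow into the retriever encoder's parameters instead.
    grad = retriever_score_grad(misaligned, llm_logprobs)
    updated = [s - 0.5 * g for s, g in zip(misaligned, grad)]
    after = rank_alignment_loss(updated, llm_logprobs)
    print(f"loss before={before:.3f} after={after:.3f}")
```

The loss is zero when the retriever already ranks documents the way the LLM's answer likelihoods do, and a gradient step moves retriever scores toward documents that help the LLM. That is the coupling the TL;DR describes, shown here on fixed toy numbers rather than a trained model.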

Abstract

Large Language Models (LLMs) have demonstrated remarkable potential in medical knowledge acquisition and question answering. However, LLMs can hallucinate and produce factually incorrect outputs, even with domain-specific pretraining. Retrieval-augmented generation (RAG) has so far had limited success in addressing hallucinations. Unlike previous RAG methods, in which the retrieval model is trained separately from the LLM, we introduce JMLR (Jointly trains LLM and information Retrieval), which trains the retriever and the LLM together during the fine-tuning phase. This synchronized training improves JMLR's ability to retrieve clinical guidelines and to draw on medical knowledge when reasoning about and answering questions, while reducing the demand for computational resources. We evaluated JMLR on medical question answering. Our experimental results show that, on a medical question-answering dataset, JMLR-13B (70.5%) outperforms both Meditron-70B (68.9%), the previous state-of-the-art open-source model built with conventional pre-training and fine-tuning, and Llama2-13B with RAG (67.7%). Comprehensive evaluations show that JMLR-13B improves reasoning quality and reduces hallucinations more effectively than Claude3-Opus. JMLR-13B also trains far faster than Meditron-70B (148 vs. 42,630 GPU hours). Through this work, we provide a new and efficient knowledge-enhancement method for healthcare, demonstrating the potential of integrating retrieval and LLM training for medical question-answering systems.
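The abstract's headline comparisons reduce to simple arithmetic. The sketch below only recomputes them from the reported figures; the helper names are illustrative. Two caveats from the abstract itself: the accuracies come from a single medical question-answering dataset, and the GPU-hour comparison sets JMLR's fine-tuning budget against Meditron-70B's conventional pre-training plus fine-tuning budget.

```python
# Headline numbers as reported in the abstract.
ACCURACY = {  # % accuracy on a medical question-answering dataset
    "JMLR-13B": 70.5,
    "Meditron-70B": 68.9,
    "Llama2-13B + RAG": 67.7,
}
GPU_HOURS = {"JMLR-13B": 148, "Meditron-70B": 42630}


def point_gain(model: str, baseline: str) -> float:
    """Accuracy difference in percentage points, rounded to one decimal."""
    return round(ACCURACY[model] - ACCURACY[baseline], 1)


def compute_ratio(slow: str, fast: str) -> float:
    """How many times more GPU hours `slow` used than `fast`."""
    return GPU_HOURS[slow] / GPU_HOURS[fast]


if __name__ == "__main__":
    for baseline in ("Meditron-70B", "Llama2-13B + RAG"):
        gain = point_gain("JMLR-13B", baseline)
        print(f"JMLR-13B vs {baseline}: {gain:+.1f} points")
    print(f"GPU-hour ratio: {compute_ratio('Meditron-70B', 'JMLR-13B'):.0f}x")
```

Running it prints gains of +1.6 and +2.8 points and a GPU-hour ratio of about 288x, which puts the efficiency claim in concrete terms.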

Paper Structure (32 sections, 12 equations, 3 figures, 9 tables)

Figures (3)

  • Figure 1: JMLR achieved the highest average accuracy across the MMLU-Medical, MedMCQA, MedQA, and Amboss datasets while using only 148 GPU hours.
  • Figure 2: Comparison of domain adaptation methods: traditional domain pretraining (left), RAG (middle), and JMLR (right). JMLR retrieves documents to reduce hallucination. The parameters of the retriever and the large language model (LLM) are updated simultaneously, so the retriever learns which domain-specific documents help the LLM give a reasonable answer.
  • Figure 3: JMLR's accuracy on the USMLE (vertical axis) as a function of the number of documents it retrieves (horizontal axis).
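Figures 2 and 3 both turn on the retrieval step: the retriever selects documents for the LLM, and Figure 3 varies how many are retrieved. The sketch below shows a generic top-k dense retrieval step under stated assumptions. Cosine similarity over toy 3-dimensional vectors stands in for JMLR's retriever, whose encoder is not described on this page. The document texts are placeholders, and the prompt layout in the hypothetical `build_prompt` is illustrative rather than the paper's format. In JMLR the retriever is also updated during fine-tuning; this sketch covers only the retrieval step.

```python
import math
from typing import Dict, List, Tuple


def cosine(u: List[float], v: List[float]) -> float:
    """Cosine similarity between two dense vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def retrieve_top_k(query_vec: List[float],
                   doc_vecs: Dict[str, List[float]],
                   k: int) -> List[Tuple[str, float]]:
    """Return the k highest-scoring (doc_id, score) pairs, best first."""
    scored = [(doc_id, cosine(query_vec, vec)) for doc_id, vec in doc_vecs.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


def build_prompt(question: str, doc_texts: Dict[str, str],
                 hits: List[Tuple[str, float]]) -> str:
    """Prepend retrieved documents to the question (hypothetical format)."""
    context = "\n".join(f"[{i + 1}] {doc_texts[doc_id]}"
                        for i, (doc_id, _) in enumerate(hits))
    return f"Context:\n{context}\n\nQuestion: {question}\nAnswer:"


if __name__ == "__main__":
    # Toy embeddings standing in for a trained retriever's encoder output.
    doc_vecs = {"g1": [1.0, 0.0, 0.0], "g2": [0.8, 0.6, 0.0], "g3": [0.0, 0.0, 1.0]}
    doc_texts = {"g1": "Guideline snippet A", "g2": "Guideline snippet B",
                 "g3": "Guideline snippet C"}  # placeholder text
    query_vec = [1.0, 0.1, 0.0]
    for k in (1, 2, 3):  # k is the quantity varied on Figure 3's horizontal axis
        hits = retrieve_top_k(query_vec, doc_vecs, k)
        print(k, [doc_id for doc_id, _ in hits])
    print(build_prompt("Example question?", doc_texts,
                       retrieve_top_k(query_vec, doc_vecs, 2)))
```

Raising k adds lower-scoring documents to the context, which is the trade-off Figure 3 explores empirically: more candidate evidence for the LLM at the cost of more potentially irrelevant text.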