MediSwift: Efficient Sparse Pre-trained Biomedical Language Models

Vithursan Thangarasa, Mahmoud Salem, Shreyas Saxena, Kevin Leong, Joel Hestness, Sean Lie

TL;DR

MediSwift tackles the high computational cost of domain-specific LLMs by inducing unstructured weight sparsity during pre-training on biomedical text, then applying dense fine-tuning and soft prompting to recover accuracy. Trained on PubMed Central and PubMed data, MediSwift comes in Med, Large, and XL sizes, with up to $104.86\times 10^{9}$ total tokens processed and a $2$–$2.5\times$ reduction in training FLOPs on the Cerebras CS-2. The dense XL model (1.21B parameters) reaches 76.8% accuracy on PubMedQA, a new state of the art at this size, while being 5.8x smaller than PMC-LLaMA. The 50% and 75% sparse variants further improve the efficiency-accuracy trade-off across PubMedQA and HoC. The results show that sparse pre-training, combined with dense fine-tuning and soft prompting, is a practical approach to building high-performing, compute-efficient biomedical LLMs, with potential for further gains via dynamic sparse training and continued hardware-software co-design.

Abstract

Large language models (LLMs) are typically trained on general source data for various domains, but a recent surge in domain-specific LLMs has shown their potential to outperform general-purpose models in domain-specific tasks (e.g., biomedicine). Although domain-specific pre-training enhances efficiency and leads to smaller models, the computational costs of training these LLMs remain high, posing budgeting challenges. We introduce MediSwift, a suite of biomedical LMs that leverage sparse pre-training on domain-specific biomedical text data. By inducing up to 75% weight sparsity during the pre-training phase, MediSwift achieves a 2-2.5x reduction in training FLOPs. Notably, all sparse pre-training was performed on the Cerebras CS-2 system, which is specifically designed to realize the acceleration benefits from unstructured weight sparsity, thereby significantly enhancing the efficiency of the MediSwift models. Through subsequent dense fine-tuning and strategic soft prompting, MediSwift models outperform existing LLMs up to 7B parameters on biomedical tasks, setting new benchmarks for the efficiency-accuracy trade-off on tasks such as PubMedQA. Our results show that sparse pre-training, along with dense fine-tuning and soft prompting, offers an effective method for creating high-performing, computationally efficient models in specialized domains.
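
The sparse-then-dense recipe described in the abstract can be illustrated with a toy sketch. This is not the authors' implementation: the random static mask, the function names (`make_mask`, `sgd_step`, `density`), and the constant toy gradients are all assumptions made for illustration. It only shows the mechanics: masked weights stay at zero during the sparse phase, and the mask is dropped for the dense phase so every weight can be updated.

```python
import random

def make_mask(rows, cols, sparsity, seed=0):
    """Random unstructured 0/1 mask with round(sparsity * rows * cols) zeros.
    (Assumption: a random static mask; the abstract does not specify the mask.)"""
    rng = random.Random(seed)
    n = rows * cols
    n_zero = round(sparsity * n)
    flat = [0.0] * n_zero + [1.0] * (n - n_zero)
    rng.shuffle(flat)
    return [flat[r * cols:(r + 1) * cols] for r in range(rows)]

def sgd_step(weights, grads, lr, mask=None):
    """One plain SGD update. If a mask is given, masked weights are forced to zero."""
    updated = []
    for r, (w_row, g_row) in enumerate(zip(weights, grads)):
        new_row = []
        for c, (w, g) in enumerate(zip(w_row, g_row)):
            v = w - lr * g
            if mask is not None:
                v *= mask[r][c]
            new_row.append(v)
        updated.append(new_row)
    return updated

def density(weights):
    """Fraction of weights that are nonzero."""
    total = sum(len(row) for row in weights)
    nonzero = sum(1 for row in weights for w in row if w != 0.0)
    return nonzero / total

# Toy 8x8 weight matrix at 75% sparsity with constant stand-in gradients.
mask = make_mask(8, 8, sparsity=0.75)
weights = [[1.0 * m for m in row] for row in mask]
grads = [[0.5] * 8 for _ in range(8)]

# Phase 1: sparse "pre-training" steps keep masked weights at zero.
for _ in range(3):
    weights = sgd_step(weights, grads, lr=0.1, mask=mask)
sparse_density = density(weights)

# Phase 2: dense "fine-tuning" steps drop the mask, so all weights train.
for _ in range(3):
    weights = sgd_step(weights, grads, lr=0.1)
dense_density = density(weights)

print(sparse_density, dense_density)  # 0.25 1.0
```

In this sketch only a quarter of the weights are active during the sparse phase, which is where the FLOP savings would come from on hardware that can skip zeros. The dense phase restores full capacity.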

Paper Structure

This paper contains 29 sections, 4 equations, 3 figures, and 7 tables.

Figures (3)

  • Figure 1: Comparison of Model Size vs. PubMedQA Accuracy in the Reasoning-Required Setting: Our dense and sparse MediSwift models noticeably outperform other fine-tuned language models $\leq$ 7B parameters, improving the efficiency-accuracy Pareto frontier. In particular, MediSwift-XL (1.21B) achieves new state-of-the-art 76.8% accuracy at this size (i.e., being 5.8x smaller than PMC-LLaMA). In addition, sparse pre-trained MediSwift-XL models at $s \in \{50\%, 75\%\}$ outperform other models at similar or larger size. Additional details are provided in Tables \ref{tab:mediswift_pretrain_losses} and \ref{tab:mediswift_pubmed}.
  • Figure 2: Comparison of pre-training loss curves for MediSwift models: MediSwift-XL's training loss reveals that at 50% sparsity, the model's performance closely mirrors that of its dense variant, with negligible effects on training loss. At 75% sparsity, although the gap in training loss widens, the sparse MediSwift-XL still outperforms the dense MediSwift-Med, showcasing efficient learning even at higher sparsity levels.
  • Figure 3: Comparison of measured speedup versus theoretical speedup for GPT-3 layer 12k $\times$ 12k matrix multiplication (MatMul) on the Cerebras CS-2 system at various sparsity levels. This graph illustrates the efficiency gains achieved through sparse computation, highlighting the real-world performance relative to theoretical predictions.
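
As a rough guide to the "theoretical speedup" axis in Figure 3, the sketch below assumes the usual FLOP-counting definition: if a fraction $s$ of the weights is zero and every zero is skipped, a MatMul needs $(1 - s)$ of the dense FLOPs, so the ideal speedup is $1/(1 - s)$. The caption does not state this definition, so treat it as an assumption. The function names are hypothetical, and the dimension `12 * 1024` is only a stand-in for the caption's "12k".

```python
def theoretical_speedup(sparsity):
    """Ideal speedup when every zero weight is skipped: 1 / (1 - s)."""
    if not 0.0 <= sparsity < 1.0:
        raise ValueError("sparsity must be in [0, 1)")
    return 1.0 / (1.0 - sparsity)

def matmul_flops(m, k, n, sparsity=0.0):
    """Multiply-add FLOPs for an (m x k) @ (k x n) product,
    counting only the nonzero fraction of the (k x n) weight matrix."""
    return 2 * m * k * n * (1.0 - sparsity)

# Stand-in for the caption's 12k x 12k weight matrix (assumed size).
d = 12 * 1024
dense_flops = matmul_flops(1, d, d)

for s in (0.5, 0.75, 0.9):
    ratio = dense_flops / matmul_flops(1, d, d, sparsity=s)
    print(f"sparsity={s:.2f}  theoretical speedup={theoretical_speedup(s):.2f}  FLOP ratio={ratio:.2f}")
```

Measured speedup on real hardware generally falls below this ideal because of overheads the FLOP count ignores, which is the gap Figure 3 is described as comparing.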