MediSwift: Efficient Sparse Pre-trained Biomedical Language Models
Vithursan Thangarasa, Mahmoud Salem, Shreyas Saxena, Kevin Leong, Joel Hestness, Sean Lie
TL;DR
MediSwift tackles the high computational cost of domain-specific LLMs by introducing unstructured weight sparsity during pre-training on biomedical text, followed by dense fine-tuning and soft prompting to recover accuracy. Trained on PubMed Central and PubMed data, MediSwift comes in Med, Large, and XL sizes, processes up to $104.86\times 10^{9}$ total training tokens, and reduces training FLOPs by $2$–$2.5\times$ on the Cerebras CS-2. The dense XL model reaches a new state of the art on PubMedQA at $76.8\%$ accuracy while being smaller than competing models, and the 50% and 75% sparse variants further improve efficiency-accuracy trade-offs across PubMedQA and HoC. The results show that sparse pre-training, combined with dense fine-tuning and soft prompting, is a practical approach to building high-performing, compute-efficient biomedical LLMs, with potential for further gains via dynamic sparse training and continued hardware-software co-design.
Abstract
Large language models (LLMs) are typically trained on general-source data spanning various domains, but a recent surge in domain-specific LLMs has shown their potential to outperform general-purpose models on domain-specific tasks (e.g., biomedicine). Although domain-specific pre-training enhances efficiency and leads to smaller models, the computational costs of training these LLMs remain high, posing budgeting challenges. We introduce MediSwift, a suite of biomedical LMs that leverage sparse pre-training on domain-specific biomedical text data. By inducing up to 75% weight sparsity during the pre-training phase, MediSwift achieves a 2-2.5x reduction in training FLOPs. Notably, all sparse pre-training was performed on the Cerebras CS-2 system, which is specifically designed to realize the acceleration benefits of unstructured weight sparsity, thereby significantly enhancing the efficiency of the MediSwift models. Through subsequent dense fine-tuning and strategic soft prompting, MediSwift models outperform existing LLMs of up to 7B parameters on biomedical tasks, setting new benchmarks with respect to the efficiency-accuracy trade-off on tasks such as PubMedQA. Our results show that sparse pre-training, along with dense fine-tuning and soft prompting, offers an effective method for creating high-performing, computationally efficient models in specialized domains.
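The two-phase recipe described above (sparse pre-training, then dense fine-tuning) can be illustrated with a toy sketch. This is not the authors' implementation: the function names (`make_mask`, `apply_mask`, `sgd_step`, `density`) are hypothetical, the random mask that stays fixed across steps is an assumption, and the constant weights and gradients are placeholders chosen only to make the behavior easy to follow.

```python
import random


def make_mask(n_rows, n_cols, sparsity, seed=0):
    """Random unstructured mask: 1 keeps a weight, 0 prunes it (assumed scheme)."""
    rng = random.Random(seed)
    n = n_rows * n_cols
    n_zero = round(sparsity * n)
    flat = [0] * n_zero + [1] * (n - n_zero)
    rng.shuffle(flat)  # zeros land anywhere in the matrix, with no block structure
    return [flat[r * n_cols:(r + 1) * n_cols] for r in range(n_rows)]


def apply_mask(weights, mask):
    """Zero out every weight whose mask entry is 0."""
    return [[w * m for w, m in zip(w_row, m_row)]
            for w_row, m_row in zip(weights, mask)]


def sgd_step(weights, grads, lr, mask=None):
    """One SGD update; if a mask is given, pruned weights are held at zero."""
    updated = [[w - lr * g for w, g in zip(w_row, g_row)]
               for w_row, g_row in zip(weights, grads)]
    return apply_mask(updated, mask) if mask is not None else updated


def density(matrix):
    """Fraction of nonzero entries."""
    total = sum(len(row) for row in matrix)
    nonzero = sum(1 for row in matrix for v in row if v != 0)
    return nonzero / total


# Toy 8x8 layer with placeholder weights and gradients.
rows, cols = 8, 8
weights = [[1.0] * cols for _ in range(rows)]
grads = [[0.5] * cols for _ in range(rows)]

# Phase 1: sparse pre-training at 75% sparsity (mask applied after every update).
mask = make_mask(rows, cols, sparsity=0.75)
sparse_weights = apply_mask(weights, mask)
sparse_weights = sgd_step(sparse_weights, grads, lr=0.1, mask=mask)

# Phase 2: dense fine-tuning (mask dropped, so pruned weights may become nonzero).
dense_weights = sgd_step(sparse_weights, grads, lr=0.1)

print(density(sparse_weights), density(dense_weights))
```

In this sketch only a quarter of the weights are active during the masked phase, and every weight becomes trainable again once the mask is dropped. The toy code still multiplies by the zeroed entries, so it shows the weight pattern rather than any speedup; turning unstructured zeros into FLOP savings depends on hardware support such as that of the CS-2.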
