Nirvana: A Specialized Generalist Model With Task-Aware Memory Mechanism
Yuhua Jiang, Shuang Cheng, Yihao Liu, Ermo Hua, Che Jiang, Weigao Sun, Yu Cheng, Feifei Gao, Biqing Qi, Bowen Zhou
TL;DR
Nirvana addresses a limitation of traditional LLMs in specialized domains by introducing a task-aware memory mechanism that adapts on a per-sample basis. The framework combines a Task-Aware Memory Trigger, which performs self-supervised test-time fine-tuning, with a Specialized Memory Updater that interpolates between Sliding Window Attention (SWA) and Linear Attention, giving the model linear time complexity. Empirically, Nirvana matches or surpasses existing LLM architectures on general language-modeling benchmarks. It also achieves superior MRI reconstruction from undersampled data while keeping the backbone frozen and training only lightweight codecs. This approach reduces the need for extensive domain-specific retraining of the backbone and enables rapid adaptation to specialized tasks, with potential practical impact in medical imaging and beyond.
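The TL;DR says the Updater interpolates between SWA and Linear Attention but gives no formula. As a rough illustration only, the pure-Python sketch below assumes the simplest reading: a causal convex combination of a sliding-window softmax attention output and a linear-attention output (with an elu+1 feature map), weighted by a fixed scalar `alpha`. The function name `hybrid_memory`, the window size, the feature map, and the fixed gate are all hypothetical. In Nirvana, the Updater is guided by Trigger, which this fixed gate does not model. What the sketch does show is why such a hybrid stays linear in sequence length: the linear-attention memory has a constant size, and the softmax part only looks at the last `window` positions.

```python
import math
from typing import List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def feature_map(x: Vector) -> Vector:
    # elu(x) + 1: strictly positive features keep the linear-attention denominator > 0.
    return [xi + 1.0 if xi > 0 else math.exp(xi) for xi in x]


def hybrid_memory(queries: List[Vector], keys: List[Vector], values: List[Vector],
                  window: int, alpha: float) -> List[Vector]:
    """Causal mix of sliding-window softmax attention and linear attention.

    Hypothetical sketch: `window` and the fixed scalar gate `alpha` are
    illustrative assumptions, not parameters taken from the Nirvana paper.
    """
    d_k, d_v = len(keys[0]), len(values[0])
    # Constant-size linear-attention memory: S = sum_s phi(k_s) v_s^T, z = sum_s phi(k_s).
    state = [[0.0] * d_v for _ in range(d_k)]
    norm = [0.0] * d_k
    outputs: List[Vector] = []
    for t, (q, k, v) in enumerate(zip(queries, keys, values)):
        # 1) Fold the new key/value into the memory: O(d_k * d_v) work per step.
        fk = feature_map(k)
        for i in range(d_k):
            norm[i] += fk[i]
            for j in range(d_v):
                state[i][j] += fk[i] * v[j]
        fq = feature_map(q)
        denom = dot(fq, norm)
        lin = [sum(fq[i] * state[i][j] for i in range(d_k)) / denom for j in range(d_v)]

        # 2) Exact softmax attention over the last `window` positions only: O(window * d) per step.
        positions = range(max(0, t - window + 1), t + 1)
        scores = [dot(q, keys[s]) / math.sqrt(d_k) for s in positions]
        peak = max(scores)  # subtract the max for numerical stability
        weights = [math.exp(sc - peak) for sc in scores]
        total = sum(weights)
        swa = [sum(w * values[s][j] for w, s in zip(weights, positions)) / total
               for j in range(d_v)]

        # 3) Interpolate: alpha = 1 gives pure SWA, alpha = 0 gives pure linear attention.
        outputs.append([alpha * a + (1.0 - alpha) * b for a, b in zip(swa, lin)])
    return outputs
```

Because neither step's cost depends on `t`, a sequence of length T costs O(T) in total. With `window=1` and `alpha=1.0` the output reduces to the current value vector, and with `alpha=0.0` it is pure causal linear attention. Both cases are convenient sanity checks.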
Abstract
Specialized Generalist Models (SGMs) aim to preserve broad capabilities while achieving expert-level performance in target domains. However, traditional LLM architectures, including Transformers, Linear Attention, and hybrid models, do not employ a specialized memory mechanism guided by task information. In this paper, we present Nirvana, an SGM with a specialized memory mechanism, linear time complexity, and test-time task information extraction. In addition, we propose the Task-Aware Memory Trigger ($\textit{Trigger}$), which flexibly adjusts the memory mechanism to the current task's requirements. In Trigger, each incoming sample is treated as a self-supervised fine-tuning task, enabling Nirvana to adapt its task-related parameters to domain shifts on the fly. We also design the Specialized Memory Updater ($\textit{Updater}$), which dynamically memorizes the context under the guidance of Trigger. We conduct experiments on both general language tasks and specialized medical tasks. On a variety of natural language modeling benchmarks, Nirvana achieves competitive or superior results compared to existing LLM architectures. To demonstrate the effectiveness of Trigger on specialized tasks, we evaluate Nirvana on a challenging medical task, Magnetic Resonance Imaging (MRI) reconstruction. We post-train the frozen Nirvana backbone with lightweight codecs on paired electromagnetic signals and MRI images. Although the backbone remains frozen, Trigger guides the model to adapt to the MRI domain by changing its task-related parameters. Nirvana achieves higher-quality MRI reconstruction than both conventional MRI models and models built on traditional LLM backbones, and it can also generate accurate preliminary clinical reports.
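The abstract describes Trigger as treating each incoming sample as a self-supervised fine-tuning task, but it specifies neither the objective nor which parameters count as task-related. The toy below is a hypothetical sketch of that general pattern (per-sample test-time training) and is not Nirvana's method. A single scalar `w` stands in for the task-related parameters, a next-step prediction loss stands in for the self-supervised objective, and the hypothetical helper `adapt_per_sample` runs plain gradient descent from a shared initialization that is never modified, so every sample adapts independently.

```python
from typing import Sequence, Tuple


def self_supervised_loss_and_grad(w: float, seq: Sequence[float]) -> Tuple[float, float]:
    """Toy next-step objective L(w) = mean_t (w * x_t - x_{t+1})^2 and its derivative dL/dw."""
    pairs = list(zip(seq[:-1], seq[1:]))  # the sample supplies its own (input, target) pairs
    n = len(pairs)
    loss = sum((w * x - y) ** 2 for x, y in pairs) / n
    grad = sum(2.0 * (w * x - y) * x for x, y in pairs) / n
    return loss, grad


def adapt_per_sample(w_init: float, seq: Sequence[float],
                     steps: int = 100, lr: float = 0.5) -> float:
    """Hypothetical test-time adaptation of one task parameter on one unlabeled sample.

    The shared initialization `w_init` is never overwritten, so the next
    sample starts again from the same point.
    """
    w = w_init
    for _ in range(steps):
        _, g = self_supervised_loss_and_grad(w, seq)
        w -= lr * g  # plain gradient descent on this sample's self-supervised loss
    return w
```

For example, `adapt_per_sample(0.0, [1.0, 0.5, 0.25, 0.125])` converges to about 0.5, while `[1.0, -0.8, 0.64, -0.512]` drives the same starting point to about -0.8. The default step size `lr=0.5` suits values of order one; larger inputs would need a smaller step to avoid divergence.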
