Memorizing SAM: 3D Medical Segment Anything Model with Memorizing Transformer
Xinyuan Shao, Yiqing Shen, Mathias Unberath
TL;DR
Memorizing SAM addresses the challenge of applying Segment Anything Models to 3D medical segmentation by introducing a memorizing Transformer plug-in that uses pre-computed, high-quality internal representations stored as external memory. At inference, memories are retrieved via k-nearest-neighbor (kNN) search, and memory attention is fused with local attention using a preset local ratio that balances their contributions. On a 33-class subset of TotalSegmentator, Memorizing SAM achieves an average Dice improvement of 11.36% over FastSAM3D, particularly without fine-tuning, while adding only a modest inference-time overhead (4.38 ms). This approach offers a practical way to improve foundation-model segmentation in medical imaging when fine-tuning data are limited, and the code is publicly available.
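The retrieval-and-fusion step can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: it assumes exact dot-product kNN over the memory keys (a real system may use an approximate index), scaled dot-product attention over the k retrieved entries, and a fixed convex combination in which `local_ratio` weights local attention and `1 - local_ratio` weights memory attention. All function names and toy dimensions are hypothetical.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Numerically stable softmax along the given axis.
    z = x - x.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def memory_attention(q, mem_k, mem_v, k):
    """Attend over the top-k memory entries found by exact dot-product kNN (assumption)."""
    scores = q @ mem_k.T                                  # (n_q, n_mem) similarity to every memory key
    idx = np.argsort(-scores, axis=1)[:, :k]              # indices of the k most similar keys per query
    top_scores = np.take_along_axis(scores, idx, axis=1)  # (n_q, k)
    weights = softmax(top_scores / np.sqrt(q.shape[-1]), axis=1)
    retrieved_v = mem_v[idx]                              # (n_q, k, d) values of the retrieved entries
    return np.einsum("qk,qkd->qd", weights, retrieved_v)


def local_attention(q, k_loc, v_loc):
    """Standard scaled dot-product attention over the current input's tokens."""
    weights = softmax(q @ k_loc.T / np.sqrt(q.shape[-1]), axis=1)
    return weights @ v_loc


def fused_attention(q, k_loc, v_loc, mem_k, mem_v, k=4, local_ratio=0.5):
    """Blend local and memory attention with a preset (not learned) local ratio."""
    local_out = local_attention(q, k_loc, v_loc)
    mem_out = memory_attention(q, mem_k, mem_v, k)
    # Assumption: local_ratio weights the local branch; the remainder goes to memory.
    return local_ratio * local_out + (1.0 - local_ratio) * mem_out


# Toy usage with random tensors standing in for real token features.
rng = np.random.default_rng(0)
d = 8
q = rng.standard_normal((5, d))
k_loc, v_loc = rng.standard_normal((5, d)), rng.standard_normal((5, d))
mem_k, mem_v = rng.standard_normal((100, d)), rng.standard_normal((100, d))
out = fused_attention(q, k_loc, v_loc, mem_k, mem_v, k=4, local_ratio=0.7)
print(out.shape)  # (5, 8)
```

Setting `local_ratio=1.0` recovers plain local attention, and `0.0` relies entirely on retrieved memories; a fixed ratio keeps the plug-in free of extra trainable parameters, in contrast to a learned gate.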
Abstract
Segment Anything Models (SAMs) have gained increasing attention in medical image analysis due to their zero-shot generalization capability in segmenting objects of unseen classes and domains when provided with appropriate user prompts. Nevertheless, a performance gap remains when SAMs are applied to medical images without sufficient fine-tuning. Addressing this gap is important to fully leverage the pre-trained weights of SAMs, particularly in volumetric medical image segmentation, where accuracy is critical but well-annotated 3D medical data for fine-tuning are limited. In this work, we investigate whether introducing a memory mechanism as a plug-in, specifically the ability to memorize and recall internal representations of past inputs, can improve the performance of SAM at limited computational cost. To this end, we propose Memorizing SAM, a novel 3D SAM architecture that incorporates a memorizing Transformer as a plug-in. Unlike conventional memorizing Transformers, which save internal representations during training or inference, Memorizing SAM uses existing, highly accurate internal representations as the memory source to ensure memory quality. We evaluate Memorizing SAM on 33 categories from the TotalSegmentator dataset; the results indicate that Memorizing SAM outperforms the state-of-the-art 3D SAM variant, FastSAM3D, with an average Dice increase of 11.36% at the cost of only a 4.38 ms increase in inference time. The source code is publicly available at https://github.com/swedfr/memorizingSAM
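The contrast between a conventional memorizing Transformer, which writes its own keys and values into memory as it runs, and a memory populated from existing high-quality representations can be sketched with a hypothetical `FixedMemoryBank` class. This is an illustrative assumption, not the released code: the bank is filled once from pre-computed key/value arrays and then made read-only, so inference only retrieves entries and never appends new ones. The random arrays below merely stand in for such pre-computed representations.

```python
import numpy as np


class FixedMemoryBank:
    """Hypothetical external memory filled once from trusted, pre-computed representations.

    A conventional memorizing Transformer would append keys/values from whatever it
    processes; here (by assumption) the bank is frozen after construction, so only
    the vetted entries supplied up front can ever be retrieved.
    """

    def __init__(self, keys: np.ndarray, values: np.ndarray):
        if keys.shape[0] != values.shape[0]:
            raise ValueError("keys and values must have the same number of entries")
        # Copy the inputs and mark them read-only so inference cannot modify the memory.
        self.keys = np.array(keys, dtype=np.float64)
        self.values = np.array(values, dtype=np.float64)
        self.keys.setflags(write=False)
        self.values.setflags(write=False)

    def __len__(self) -> int:
        return self.keys.shape[0]

    def knn(self, queries: np.ndarray, k: int):
        """Return indices and scores of the k keys with the largest dot product per query."""
        k = min(k, len(self))
        scores = queries @ self.keys.T
        idx = np.argsort(-scores, axis=1)[:, :k]  # sorted from most to least similar
        return idx, np.take_along_axis(scores, idx, axis=1)


# Random stand-ins for representations pre-computed offline (hypothetical data).
rng = np.random.default_rng(1)
precomputed_k = rng.standard_normal((50, 16))
precomputed_v = rng.standard_normal((50, 16))
bank = FixedMemoryBank(precomputed_k, precomputed_v)
idx, scores = bank.knn(rng.standard_normal((3, 16)), k=5)
print(idx.shape, scores.shape)  # (3, 5) (3, 5)
```

Because the bank is never written during inference, its contents cannot degrade from the model's own possibly inaccurate predictions, which is the memory-quality argument the abstract makes.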
