MaskMed: Decoupled Mask and Class Prediction for Medical Image Segmentation
Bin Xie, Gady Agam
TL;DR
MaskMed addresses the rigidity of traditional point-wise segmentation heads in 3D medical image segmentation by decoupling mask generation from class prediction and supervising the predictions through bipartite matching. It introduces a Masked Multi-Scale Segmentation Head with a shared query mechanism and a Full-Scale Aware Deformable Transformer (FSAD-Transformer) that fuses full-scale encoder features efficiently. The approach achieves state-of-the-art results on AMOS 2022 and BTCV, reaching a Dice of $0.913$ on AMOS 2022 versus $0.893$ for nnU-Net and improving on nnU-Net by 6.9% Dice on BTCV. The work offers improved flexibility, interpretability, and generalization, and suggests future extensions such as increasing the number of object queries and exploring multi-instance segmentation in 3D medical imaging.
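The decoupled head and its bipartite-matching supervision can be illustrated with a minimal NumPy sketch. It assumes a MaskFormer-style formulation: the same N object queries feed a mask branch (query dotted with per-voxel embeddings gives a class-agnostic mask) and a class branch (linear classifier with an extra "no object" class), and during training each ground-truth object is assigned to exactly one query by minimizing a matching cost. The function names `decoupled_head`, `semantic_inference`, `match_cost`, and `bipartite_match`, the simplified cost (negative class probability plus Dice loss, equally weighted), and the exhaustive-search matcher are all hypothetical illustrations, not MaskMed's actual implementation.

```python
import itertools
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decoupled_head(queries, voxel_feats, w_cls):
    """queries: (N, C) shared object queries; voxel_feats: (C, D, H, W);
    w_cls: (C, K + 1) hypothetical linear classifier, last column = "no object"."""
    # Mask branch: each query is dotted with every voxel embedding -> class-agnostic mask logits.
    mask_logits = np.einsum("nc,cdhw->ndhw", queries, voxel_feats)
    # Class branch: the same queries each predict one label distribution.
    class_logits = queries @ w_cls
    return mask_logits, class_logits

def semantic_inference(mask_logits, class_logits):
    """Recombine per-query masks and labels into a dense label map (MaskFormer-style)."""
    probs = softmax(class_logits)[:, :-1]                 # drop the "no object" column
    scores = np.einsum("nk,ndhw->kdhw", probs, sigmoid(mask_logits))
    return scores.argmax(axis=0)                          # (D, H, W) class indices

def match_cost(mask_logits, class_logits, gt_masks, gt_labels):
    """Cost[i, j] of assigning query i to ground-truth object j (lower is better).
    Simplified assumption: -p(gt class) + Dice loss, equally weighted."""
    probs = softmax(class_logits)
    pred = sigmoid(mask_logits).reshape(len(mask_logits), -1)
    gt = gt_masks.reshape(len(gt_masks), -1).astype(float)
    dice = (2 * pred @ gt.T + 1) / (pred.sum(1)[:, None] + gt.sum(1)[None, :] + 1)
    return -probs[:, gt_labels] + (1 - dice)

def bipartite_match(cost):
    """Optimal one-to-one assignment by exhaustive search; assumes N >= M.
    Only viable for tiny N; scipy.optimize.linear_sum_assignment scales better."""
    n, m = cost.shape
    best, best_perm = np.inf, None
    for perm in itertools.permutations(range(n), m):    # perm[j] = query assigned to object j
        total = sum(cost[perm[j], j] for j in range(m))
        if total < best:
            best, best_perm = total, perm
    return [(q, j) for j, q in enumerate(best_perm)]    # (query, gt object) pairs
```

Because the mask branch never sees a class index, any query can represent any organ; the class branch and the matching decide which label it carries, which is the flexibility a fixed one-channel-per-class head lacks.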
Abstract
Medical image segmentation typically adopts a point-wise convolutional segmentation head to predict dense labels, where each output channel is heuristically tied to a specific class. This rigid design limits both feature sharing and semantic generalization. In this work, we propose a unified decoupled segmentation head that separates multi-class prediction into class-agnostic mask prediction and class label prediction using shared object queries. Furthermore, we introduce a Full-Scale Aware Deformable Transformer module that enables low-resolution encoder features to attend to full-resolution encoder features via deformable attention, achieving memory-efficient and spatially aligned full-scale fusion. Our proposed method, named MaskMed, achieves state-of-the-art performance, surpassing nnU-Net by 2.0% Dice on AMOS 2022 and 6.9% Dice on BTCV.
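Deformable attention is what keeps full-scale fusion affordable. The sketch below follows the general multi-scale deformable attention formulation of Deformable DETR, adapted to 3D volumes; it is an assumption about how such a module can work, not the FSAD-Transformer itself. A single query (for instance, one token of the lowest-resolution encoder map, whose reference point is its own normalized position) predicts K sampling offsets and K attention weights per encoder level, reads values by trilinear interpolation at the same physical location on every level, and returns their weighted sum. The single attention head, the projection matrices `w_off`, `w_att`, `w_val`, and the `ref * (size - 1)` coordinate mapping are hypothetical choices for illustration.

```python
import numpy as np

def trilinear_sample(vol, p):
    """Sample vol (C, D, H, W) at continuous voxel coordinates p = (z, y, x).
    Neighbours outside the volume contribute zero (zero padding)."""
    C, D, H, W = vol.shape
    z0, y0, x0 = (int(np.floor(c)) for c in p)
    out = np.zeros(C)
    for zi in (z0, z0 + 1):
        for yi in (y0, y0 + 1):
            for xi in (x0, x0 + 1):
                if 0 <= zi < D and 0 <= yi < H and 0 <= xi < W:
                    w = (1 - abs(p[0] - zi)) * (1 - abs(p[1] - yi)) * (1 - abs(p[2] - xi))
                    out += w * vol[:, zi, yi, xi]
    return out

def deformable_attention(query, ref, levels, w_off, w_att, w_val):
    """One query attending to L feature levels with K sampling points per level.
    query: (C,); ref: normalized reference point in [0, 1]^3;
    levels: list of (C, D_l, H_l, W_l) encoder volumes at different resolutions;
    w_off: (C, L*K*3) offset projection; w_att: (C, L*K); w_val: (C, C_out)."""
    L = len(levels)
    K = w_att.shape[1] // L
    offsets = (query @ w_off).reshape(L, K, 3)        # predicted offsets, in voxels
    a = query @ w_att
    attn = np.exp(a - a.max())
    attn = (attn / attn.sum()).reshape(L, K)          # softmax over all L*K samples
    out = np.zeros(w_val.shape[1])
    for l, vol in enumerate(levels):
        size = np.array(vol.shape[1:], dtype=float)
        center = ref * (size - 1)                     # same physical location on level l
        for k in range(K):
            sample = trilinear_sample(vol, center + offsets[l, k])
            out += attn[l, k] * (sample @ w_val)
    return out
```

Each query reads only L×K interpolated samples instead of every voxel at every scale, so cost grows with the number of queries rather than with the product of query and key counts; sampling at the reference point's physical location on each level is what keeps the fused features spatially aligned.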
