
MaskMed: Decoupled Mask and Class Prediction for Medical Image Segmentation

Bin Xie, Gady Agam

TL;DR

MaskMed addresses the rigidity of traditional point-wise segmentation heads in 3D medical image segmentation by decoupling mask generation from class prediction, which enables supervision through bipartite matching. It introduces a Masked Multi-Scale Segmentation Head with a shared query mechanism and a Full-Scale Aware Deformable Transformer (FSAD-Transformer) that fuses full-scale encoder features efficiently. The approach achieves state-of-the-art results on AMOS 2022 and BTCV, reaching a Dice of $0.913$ on AMOS 2022 versus $0.893$ for nnU-Net and improving on nnU-Net by +6.9% Dice on BTCV. This work offers improved flexibility, interpretability, and generalization, and suggests future extensions such as increasing the number of object queries and exploring multi-instance segmentation in 3D medical imaging.
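Bipartite matching supervision pairs each ground-truth segment with exactly one predicted (class, mask) pair before losses are computed. The summary does not state the matching cost, so the sketch below assumes a MaskFormer-style cost: the negative predicted probability of the ground-truth class plus a soft Dice cost, weighted by hypothetical parameters `w_cls` and `w_dice`. The function names and toy data are illustrative, not taken from the paper. For readability it searches all injective assignments exhaustively, which is exact but only viable for a handful of queries; a practical implementation would use the Hungarian algorithm (for example SciPy's `linear_sum_assignment`).

```python
import itertools
import numpy as np

def dice_cost(pred_mask, gt_mask, eps=1e-6):
    """1 - soft Dice between a probability mask and a binary ground-truth mask."""
    inter = (pred_mask * gt_mask).sum()
    return 1.0 - (2.0 * inter + eps) / (pred_mask.sum() + gt_mask.sum() + eps)

def match_queries(class_probs, mask_probs, gt_labels, gt_masks, w_cls=1.0, w_dice=1.0):
    """Minimum-cost one-to-one assignment of M ground-truth segments to N >= M queries.

    class_probs: (N, K) per-query class probabilities
    mask_probs:  (N, D, H, W) per-query mask probabilities
    gt_labels:   (M,) class index of each ground-truth segment
    gt_masks:    (M, D, H, W) binary ground-truth masks
    Returns a list of (query_index, gt_index) pairs.
    """
    n, m = class_probs.shape[0], len(gt_labels)
    cost = np.zeros((n, m))
    for i in range(n):
        for j in range(m):
            cost[i, j] = (-w_cls * class_probs[i, gt_labels[j]]
                          + w_dice * dice_cost(mask_probs[i], gt_masks[j]))
    # Exhaustive search over injective query choices: exact, but only viable for tiny N.
    best = min(itertools.permutations(range(n), m),
               key=lambda qs: sum(cost[q, j] for j, q in enumerate(qs)))
    return [(q, j) for j, q in enumerate(best)]

# Toy example: 3 queries, 2 ground-truth organs (classes 1 and 2) in a 2x2x2 volume.
gt_masks = np.zeros((2, 2, 2, 2))
gt_masks[0, :, :, 0] = 1.0   # organ of class 1 occupies x = 0
gt_masks[1, :, :, 1] = 1.0   # organ of class 2 occupies x = 1
gt_labels = np.array([1, 2])
class_probs = np.array([[0.05, 0.05, 0.85, 0.05],    # query 0 favors class 2
                        [0.70, 0.10, 0.10, 0.10],    # query 1 favors class 0
                        [0.05, 0.85, 0.05, 0.05]])   # query 2 favors class 1
mask_probs = np.stack([0.9 * gt_masks[1] + 0.05,     # query 0 covers organ 2
                       np.full((2, 2, 2), 0.5),       # query 1 is uncertain everywhere
                       0.9 * gt_masks[0] + 0.05])     # query 2 covers organ 1
pairs = match_queries(class_probs, mask_probs, gt_labels, gt_masks)
print(pairs)  # [(2, 0), (0, 1)]
```

Losses are then computed only on the matched pairs; in MaskFormer-style training, the unmatched queries (here, query 1) are usually supervised toward an extra "no object" class.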

Abstract

Medical image segmentation typically adopts a point-wise convolutional segmentation head to predict dense labels, where each output channel is heuristically tied to a specific class. This rigid design limits both feature sharing and semantic generalization. In this work, we propose a unified decoupled segmentation head that separates multi-class prediction into class-agnostic mask prediction and class label prediction using shared object queries. Furthermore, we introduce a Full-Scale Aware Deformable Transformer module that enables low-resolution encoder features to attend across full-resolution encoder features via deformable attention, achieving memory-efficient and spatially aligned full-scale fusion. Our proposed method, named MaskMed, achieves state-of-the-art performance, surpassing nnUNet by +2.0% Dice on AMOS 2022 and +6.9% Dice on BTCV.
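To make the decoupling concrete, the sketch below assumes a MaskFormer-style head: each shared object query is projected to a mask embedding whose dot product with per-voxel decoder features yields a class-agnostic mask, while a linear classifier on the same query predicts one of K classes plus a "no object" label. The projection `w_mask`, the classifier `w_cls`/`b_cls`, and the inference rule (summing class probability times mask probability over queries, then taking the argmax) are assumptions for illustration. The paper's head is multi-scale and decodes embeddings with transformer layers, which are omitted here.

```python
import numpy as np

def softmax(x, axis=-1):
    z = x - x.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def decoupled_head(queries, voxel_feats, w_mask, w_cls, b_cls):
    """Class-agnostic masks and class logits from one shared query set.

    queries:      (N, C) object query embeddings shared by both branches
    voxel_feats:  (C, D, H, W) per-voxel decoder features
    w_mask:       (C, C) hypothetical projection from query to mask embedding
    w_cls, b_cls: (C, K + 1) and (K + 1,) classifier; last column is "no object"
    """
    mask_embed = queries @ w_mask                                       # (N, C)
    mask_logits = np.einsum("nc,cdhw->ndhw", mask_embed, voxel_feats)  # (N, D, H, W)
    class_logits = queries @ w_cls + b_cls                             # (N, K + 1)
    return mask_logits, class_logits

def semantic_labels(mask_logits, class_logits):
    """Merge per-query predictions into a dense K-class label map."""
    class_probs = softmax(class_logits)[:, :-1]                    # drop "no object": (N, K)
    mask_probs = sigmoid(mask_logits)                              # (N, D, H, W)
    scores = np.einsum("nk,ndhw->kdhw", class_probs, mask_probs)   # (K, D, H, W)
    return scores.argmax(axis=0)                                   # (D, H, W) class indices

# Toy example: C = 2 channels, N = 2 queries, K = 2 classes, a 1x1x2 volume.
queries = np.eye(2)
voxel_feats = np.array([[[[4.0, -4.0]]], [[[-4.0, 4.0]]]])   # (2, 1, 1, 2)
w_mask = np.eye(2)
w_cls = np.array([[0.0, 5.0, 0.0], [5.0, 0.0, 0.0]])          # query 0 -> class 1, query 1 -> class 0
b_cls = np.zeros(3)

mask_logits, class_logits = decoupled_head(queries, voxel_feats, w_mask, w_cls, b_cls)
labels = semantic_labels(mask_logits, class_logits)
print(labels.tolist())  # [[[1, 0]]]
```

Because no output channel is hard-wired to a class in this design, the number of queries N is independent of the number of classes K, whereas a point-wise head needs exactly K output channels.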

Paper Structure

This paper contains 22 sections, 3 equations, 7 figures, 7 tables.

Figures (7)

  • Figure 1: Illustration of different segmentation head architectures evolving from the conventional UNet-based design (a, b) to our proposed decoupled mask and class embedding framework (f).
  • Figure 2: Model Architecture Overview. (a) Our full model adopts an encoder-decoder framework with a Full-Scale Aware Deformable Transformer (FSAD-Transformer) module bridging multi-scale encoder features and decoder inputs. (b) The Masked Multi-Scale Segmentation Head uses a shared query set to decode both mask and class embeddings via transformer layers. (c) The FSAD-Transformer allows deformable attention across the full feature hierarchy, using multi-scale queries and full-resolution value features.
  • Figure 3: Visualization of Class Embedding for different segmentation heads.
  • Figure 4: Visualization of Class Emb
  • Figure 5: Visualization of Mask Emb
  • ...and 2 more figures
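Figure 2(c) describes the FSAD-Transformer as deformable attention across the full feature hierarchy, with multi-scale queries and full-resolution value features. The single-head sketch below assumes a Deformable DETR-style formulation adapted to 3D: a query feature predicts K sampling offsets per level around a normalized reference point, plus attention weights softmax-normalized over all `L * K` samples, and the output is the weighted sum of trilinearly interpolated value features. The linear maps `w_off` and `w_attn`, the voxel-unit offsets, and the helper names are hypothetical; the paper's exact parameterization is not given in this summary. The memory saving comes from each query reading only `L * K` sampled points instead of attending densely to every voxel.

```python
import numpy as np

def trilinear_sample(vol, z, y, x):
    """Sample vol (C, D, H, W) at continuous voxel coordinates; zero outside."""
    _, d, h, w = vol.shape
    z0, y0, x0 = int(np.floor(z)), int(np.floor(y)), int(np.floor(x))
    out = np.zeros(vol.shape[0])
    for zi in (z0, z0 + 1):
        for yi in (y0, y0 + 1):
            for xi in (x0, x0 + 1):
                if 0 <= zi < d and 0 <= yi < h and 0 <= xi < w:
                    weight = (1 - abs(z - zi)) * (1 - abs(y - yi)) * (1 - abs(x - xi))
                    out += weight * vol[:, zi, yi, xi]
    return out

def deformable_attention(query, ref_point, values, w_off, w_attn):
    """Single-head multi-scale deformable attention for one query (3D sketch).

    query:     (C,) query feature, e.g. from a low-resolution encoder level
    ref_point: (3,) reference location in normalized [0, 1] (z, y, x) coordinates
    values:    list of L value maps, each (C, D_l, H_l, W_l)
    w_off:     (C, L * K * 3) hypothetical offset projection (voxel units)
    w_attn:    (C, L * K) hypothetical attention-weight projection
    """
    n_levels = len(values)
    offsets = (query @ w_off).reshape(n_levels, -1, 3)   # (L, K, 3)
    logits = (query @ w_attn).reshape(n_levels, -1)      # (L, K)
    attn = np.exp(logits - logits.max())
    attn /= attn.sum()                                   # softmax over all L * K samples
    out = np.zeros(values[0].shape[0])
    for lvl, vol in enumerate(values):
        center = ref_point * (np.array(vol.shape[1:]) - 1)  # normalized -> voxel coords
        for k in range(offsets.shape[1]):
            z, y, x = center + offsets[lvl, k]
            out += attn[lvl, k] * trilinear_sample(vol, z, y, x)
    return out

# Toy check: two levels with constant features; zero projections give zero
# offsets and uniform weights, so the output is the mean of 1.0 and 3.0.
values = [np.full((2, 4, 4, 4), 1.0), np.full((2, 2, 2, 2), 3.0)]
C, L, K = 2, 2, 2
out = deformable_attention(np.ones(C), np.array([0.5, 0.5, 0.5]), values,
                           np.zeros((C, L * K * 3)), np.zeros((C, L * K)))
print(out)  # [2. 2.]
```

With learned `w_off` and `w_attn`, the offsets and weights become query-dependent, so each low-resolution query can gather detail from a few well-chosen locations in every level of the hierarchy.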