
Primus: Enforcing Attention Usage for 3D Medical Image Segmentation

Tassilo Wald, Saikat Roy, Fabian Isensee, Constantin Ulrich, Sebastian Ziegler, Dasha Trofimova, Raphael Stock, Michael Baumgartner, Gregor Köhler, Klaus Maier-Hein

TL;DR

This work investigates why Transformer-based architectures underperform CNNs in 3D medical image segmentation and introduces Primus, the first pure Transformer model for this domain. By systematically deconstructing nine hybrid architectures, the authors show that non-Transformer parameters and architectural choices often drive performance, limiting Transformer effectiveness. Primus mitigates these issues by using high-resolution $8^3$ voxel tokens, 3D Rotary Positional Embeddings, SwiGLU MLP blocks, LayerScale, and a lightweight decoder to maximize attention-based learning with minimal convolution. Across multiple public datasets, Primus achieves competitive results with CNN baselines and outperforms several Transformer hybrids, marking a significant step toward Transformer-dominated 3D medical image segmentation and opening avenues for multi-modal integration and self-supervised pre-training.
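To see why $8^3$ tokens count as high-resolution, and what they cost, the arithmetic below counts tokens and self-attention entries for a tokenizer whose convolution kernel size equals its stride, as in the Primus tokenizer described in Figure 3. The $128^3$ input patch and the $16^3$ comparison token size are hypothetical values chosen for illustration, not settings reported by the authors.

```python
# Minimal sketch: how token size k controls sequence length and attention cost
# for a 3D input patch. The 128^3 patch size is a hypothetical example, not a
# configuration taken from the Primus paper.

def num_tokens(patch_shape, k):
    """Number of k x k x k tokens from a conv with kernel size = stride = k."""
    if any(s % k for s in patch_shape):
        raise ValueError("each patch dimension must be divisible by k")
    d, h, w = (s // k for s in patch_shape)
    return d * h * w

def attention_entries(n_tokens):
    """Entries in one (n x n) self-attention matrix, per head."""
    return n_tokens * n_tokens

patch = (128, 128, 128)  # hypothetical input patch size
for k in (16, 8):
    n = num_tokens(patch, k)
    print(f"k={k:2d}: {n} tokens, {attention_entries(n):,} attention entries per head")
```

Going from $k=16$ to $k=8$ multiplies the sequence length by 8 and the size of each attention matrix by 64, which is the price of letting attention operate on finer spatial detail.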

Abstract

Transformers have achieved remarkable success across multiple fields, yet their impact on 3D medical image segmentation remains limited, with convolutional networks still dominating major benchmarks. In this work, we a) analyze current Transformer-based segmentation models and identify critical shortcomings, particularly their over-reliance on convolutional blocks. Further, we show that in some architectures, performance is unaffected by the absence of the Transformer, which exposes its limited effectiveness. To address these challenges, we move away from hybrid architectures and b) introduce a fully Transformer-based segmentation architecture, termed Primus. Primus combines high-resolution tokens with advances in positional embeddings and block design to make maximal use of its Transformer blocks. Through these adaptations, Primus surpasses current Transformer-based methods and competes with state-of-the-art convolutional models on multiple public datasets. In doing so, we create the first pure Transformer architecture for this task and take a significant step towards making Transformers state-of-the-art for 3D medical image segmentation.
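Rotary Position Embeddings encode position by rotating query and key vectors so that their dot product depends only on the relative offset between tokens. The sketch below shows one common way to extend this to 3D, assuming the head dimension is split into three equal chunks, one per axis (z, y, x), each rotated with standard 1D RoPE frequencies. The exact axis split, pairing convention, and frequency schedule used in Primus or Eva-02 may differ; the vectors and positions are arbitrary illustrative values.

```python
import math

def rope_1d(vec, pos, base=10000.0):
    """Rotate consecutive pairs (x0, x1), (x2, x3), ... by angle pos * theta_i."""
    d = len(vec)
    assert d % 2 == 0, "1D RoPE needs an even dimension"
    out = []
    for i in range(0, d, 2):
        theta = base ** (-i / d)  # frequency for pair i // 2: base^(-2*(i//2)/d)
        a = pos * theta
        c, s = math.cos(a), math.sin(a)
        x0, x1 = vec[i], vec[i + 1]
        out += [x0 * c - x1 * s, x0 * s + x1 * c]
    return out

def rope_3d(vec, pos_zyx, base=10000.0):
    """Assumed 3D variant: one equal-sized chunk of the head dim per spatial axis."""
    d = len(vec)
    assert d % 6 == 0, "head dimension must split into three even chunks"
    c = d // 3
    out = []
    for axis, p in enumerate(pos_zyx):
        out += rope_1d(vec[axis * c:(axis + 1) * c], p, base)
    return out

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))

q = [0.3, -1.2, 0.5, 0.8, -0.1, 0.4, 1.0, 0.2, -0.7, 0.9, 0.6, -0.3]
k = [0.9, 0.1, -0.4, 0.7, 0.2, -0.5, 0.3, 0.8, 0.1, -0.2, 0.4, 0.6]
# Both pairs of positions have the same offset (2, 2, -3) along (z, y, x).
s1 = dot(rope_3d(q, (2, 5, 1)), rope_3d(k, (0, 3, 4)))
s2 = dot(rope_3d(q, (7, 9, 3)), rope_3d(k, (5, 7, 6)))
print(abs(s1 - s2) < 1e-9)  # True: only relative offsets matter
```

Because each chunk is rotated independently, the attention score between two tokens depends on their offset along every axis separately, which is what makes this a natural fit for volumetric token grids.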

Paper Structure

This paper contains 58 sections, 2 equations, 8 figures, and 14 tables.

Figures (8)

  • Figure 1: Effective Transformer-based networks have a low UNet index and high performance. In the first panel, we observe that most existing architectures do not outperform a similarly trained UNet on two datasets: 8 out of 9 on TotalSegmentator-BTCV and all 9 on KiTS19. Further, the second panel shows on both datasets that 6 out of 9 architectures do not even lose 3% performance ($\delta_\text{TR}$) when all Transformers are completely removed. Primus is the only network that is competitive with nnU-Net while having a low UNet index.
  • Figure 2: Scaling dataset size does not fix the challenges of Transformer-based representation learning. Increasing the training data on TotalSegmentator-BTCV (1000 3D volumes) appears to widen the gap between the Transformer and no-Transformer variants in only 4 out of 9 architectures (UNETR, SwinUNETR, SwinUNet, TransFuse). As a reference, we include a default nnU-Net.
  • Figure 3: Primus is a Transformer-heavy architecture with few convolutional layers. It extracts high-resolution 3D visual tokens with a single convolutional layer whose kernel size and stride are both $k \times k \times k$ for a small $k$. Once the input is in sequence format, Primus uses the Eva-02 [fang2024eva] Transformer architecture, featuring Rotary Position Embeddings (RoPE) adapted to 3D and the Eva-02 MLP block. The lightweight decoder is a sequence of transposed convolutions that reverts the tokenization and represents the convolutional part of the network.
  • Figure 4: Segmentation performance before and after identity replacement of a Transformer module quantifies its importance. By replacing the entire Transformer block, including LayerNorm and Multi-Head Self-Attention or Shifted-Window Multi-Head Self-Attention, with an identity mapping, the influence of the Transformer within an architecture can be evaluated.
  • Figure 5: MICCAI challenges categorized by task. For a long time, at least 50% of challenges have focused solely on semantic segmentation, with other tasks significantly less represented.
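Figures 1 and 4 rest on a simple ablation: swap every Transformer block for an identity mapping, then compare segmentation performance with and without it ($\delta_\text{TR}$). The sketch below illustrates that bookkeeping on a hypothetical toy model whose stages are scalar functions rather than network modules. It assumes $\delta_\text{TR}$ is an absolute difference in Dice points and uses a 3-point threshold mirroring the 3% in Figure 1; the function names and scores are illustrative, not the authors' code or results.

```python
# Minimal sketch of an identity-replacement ablation. The toy "model" is a
# hypothetical list of (kind, fn) stages operating on scalars; in practice the
# stages would be network modules and the scores would come from evaluating
# both variants on a segmentation benchmark.

def identity(x):
    return x

def ablate_transformers(stages):
    """Copy the model, replacing every Transformer stage with the identity."""
    return [(kind, identity if kind == "transformer" else fn) for kind, fn in stages]

def run(stages, x):
    """Apply the stages in order."""
    for _, fn in stages:
        x = fn(x)
    return x

def delta_tr(dice_full, dice_ablated):
    """Dice points lost when all Transformers are removed (assumed definition)."""
    return dice_full - dice_ablated

def transformer_barely_matters(dice_full, dice_ablated, threshold=3.0):
    """True if removing the Transformers costs less than `threshold` Dice points."""
    return delta_tr(dice_full, dice_ablated) < threshold

toy = [("conv", lambda x: 2 * x), ("transformer", lambda x: x + 10), ("conv", lambda x: x - 1)]
print(run(toy, 3), run(ablate_transformers(toy), 3))  # 15 5
print(transformer_barely_matters(85.0, 84.2))  # made-up scores -> True
```

A small $\delta_\text{TR}$ means the non-Transformer stages carry most of the work, which is the failure mode the figures highlight for hybrid architectures.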