SegStitch: Multidimensional Transformer for Robust and Efficient Medical Imaging Segmentation

Shengbo Tan, Zeyu Zhang, Ying Cai, Daji Ergu, Lin Wu, Binbin Hu, Pengzhang Yu, Yang Zhao

TL;DR

SegStitch addresses the challenge of efficient and robust 3D medical image segmentation by integrating transformers with neural memory ordinary differential equations (nmODE) and by using axial patches with position-aware tokenization. The architecture uses shared queries for dual fine- and coarse-grained self-attention, capturing long-range dependencies while maintaining linear complexity, and it embeds an nmODE block in the decoder to improve stability and feature integration. Experiments on public benchmarks (BTCV, i.e., the Synapse multi-organ CT dataset, and ACDC) show significant accuracy gains over state-of-the-art methods, together with substantial reductions in parameters and FLOPs relative to UNETR. Supported by comprehensive ablations and robustness analyses, the improved segmentation accuracy and computational efficiency suggest strong potential for real-world clinical deployment.
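The TL;DR credits an nmODE block in the decoder with improving stability. As a minimal sketch, assuming the published nmODE dynamics $\dot{y} = -y + \sin^2(y + \gamma)$, where $\gamma$ is an input-dependent term, the pure-Python code below integrates the ODE element-wise with forward Euler. The function name `nmode_euler`, the zero initial state, the step count, and the toy $\gamma$ values are illustrative assumptions; how SegStitch's decoder computes $\gamma$ and integrates the ODE is not described in this section.

```python
import math
from typing import List, Optional


def nmode_euler(gamma: List[float], t_end: float = 1.0, steps: int = 10,
                y0: Optional[List[float]] = None) -> List[float]:
    """Integrate the (assumed) nmODE dy/dt = -y + sin^2(y + gamma) element-wise.

    gamma is the external input; in nmODE it is derived from input features
    (e.g. a linear projection), which is an assumption for SegStitch here.
    Uses forward Euler with `steps` equal steps on [0, t_end].
    """
    y = [0.0] * len(gamma) if y0 is None else list(y0)
    dt = t_end / steps
    for _ in range(steps):
        # Each neuron evolves independently: decay term -y plus bounded drive sin^2(y + gamma).
        y = [yi + dt * (-yi + math.sin(yi + gi) ** 2) for yi, gi in zip(y, gamma)]
    return y


if __name__ == "__main__":
    # Hypothetical gamma values standing in for projected decoder features.
    gamma = [-1.0, 0.0, 0.5, 2.0]
    print([round(v, 4) for v in nmode_euler(gamma, t_end=1.0, steps=20)])
```

Because $\sin^2$ lies in $[0, 1]$, each Euler step with $dt \le 1$ is a convex combination of the current state and a value in $[0, 1]$, so a state that starts in $[0, 1]$ stays there. This bounded behavior offers one intuition for how an nmODE block can stabilize features.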

Abstract

Medical imaging segmentation plays a significant role in the automatic recognition and analysis of lesions. State-of-the-art methods, particularly those based on transformers, have been widely adopted in 3D semantic segmentation owing to their superior scalability and generalizability. However, plain vision transformers neglect local features and incur high computational complexity. To address these challenges, we make three key contributions. First, we propose SegStitch, an innovative architecture that integrates transformers with denoising ODE blocks. Instead of taking whole 3D volumes as inputs, we adopt axial patches and customize patch-wise queries to ensure semantic consistency. Second, we conduct extensive experiments on the BTCV and ACDC datasets, achieving mDSC improvements of up to 11.48% and 6.71%, respectively, over state-of-the-art methods. Third, our method is highly efficient, reducing the number of parameters by 36.7% and the number of FLOPs by 10.7% compared with UNETR. These advances hold promise for adapting our method to real-world clinical practice. The code will be available at https://github.com/goblin327/SegStitch
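The abstract reports gains in mDSC, the mean Dice similarity coefficient. The minimal sketch below computes per-class Dice, $2|P \cap T| / (|P| + |T|)$, on flattened integer label maps and averages it over foreground labels. Scoring a class that is absent from both maps as 1.0 and averaging over a caller-supplied label list are assumptions; the paper's exact evaluation protocol (for example, per-volume averaging) is not described in this section.

```python
from typing import Sequence


def dice_coefficient(pred: Sequence[int], target: Sequence[int], label: int) -> float:
    """Dice similarity coefficient for one class: 2|P ∩ T| / (|P| + |T|)."""
    p = [v == label for v in pred]
    t = [v == label for v in target]
    intersection = sum(a and b for a, b in zip(p, t))
    denom = sum(p) + sum(t)
    # Convention (assumption): a class absent from both maps scores 1.0.
    return 1.0 if denom == 0 else 2.0 * intersection / denom


def mean_dice(pred: Sequence[int], target: Sequence[int], labels: Sequence[int]) -> float:
    """mDSC: average of the per-class Dice scores over the given foreground labels."""
    scores = [dice_coefficient(pred, target, c) for c in labels]
    return sum(scores) / len(scores)


if __name__ == "__main__":
    # Flattened toy label maps (0 = background, 1 and 2 = two organs).
    pred = [0, 1, 1, 2, 2, 0]
    target = [0, 1, 2, 2, 2, 0]
    print(round(mean_dice(pred, target, labels=[1, 2]), 4))  # prints 0.7333
```

Here class 1 scores 2/3 and class 2 scores 0.8, so the mean is about 0.7333.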

Paper Structure

This paper contains 25 sections, 8 equations, 11 figures, 7 tables.

Figures (11)

  • Figure 1: Relationship between model parameter count, computational complexity, and mean Dice similarity coefficient (mDSC). Sphere size indicates the model parameter count. Compared with other models, SegStitch achieves the highest mDSC while having a smaller model size and lower computational complexity.
  • Figure 2: SegStitch Network Encoder Structure Diagram. The encoder comprises four layers, each containing a downsampling module followed by a stage. The downsampling module consists of a convolution and Group Normalization, and each stage consists of three Transformer Blocks. Each Transformer Block operates on a different set of patches, and together the blocks cover the full output of the downsampling module.
  • Figure 3: Overall Architecture. SegStitch uses a hierarchical encoder-decoder structure. The outputs of the downsampling modules are passed to the decoder through skip connections, and the decoder modules, which incorporate ODE blocks, produce the final segmentation mask.
  • Figure 4: Position-weighted token. We generate a set of learnable parameters $\text{Token}_{wt}$ of size $1 \times d$, where $d$ is the total length of the flattened 3D patch. These parameters learn the spatial positional information of the 3D patch.
  • Figure 5: The shared-query structure of the internal self-attention mechanism in SegStitch. The input feature map $x$ is fed into both the fine-grained and the coarse-grained attention modules. The weights of the linear layer for $Q$ are shared between the two modules. In addition, a global learnable parameter matrix $M$ is multiplied with $\text{K}_{fine}$ and $\text{V}_{fine}$ to obtain $\text{K}_R$ and $\text{V}_R$, respectively, yielding a new set of QKV combinations. This allows each self-attention module to perform a different attention task.
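Figure 2 states that each downsampling module combines a convolution with Group Normalization. The sketch below shows only the Group Normalization step, in pure Python for a single sample stored channel-first; the convolution is omitted, and the function name `group_norm`, the group count, and the data layout are illustrative assumptions. In PyTorch, the same operation is provided by `torch.nn.GroupNorm(num_groups, num_channels)`.

```python
import math
from typing import List


def group_norm(x: List[List[float]], num_groups: int, gamma: List[float],
               beta: List[float], eps: float = 1e-5) -> List[List[float]]:
    """Group Normalization for a single sample.

    x is channel-first: x[c] holds the flattened voxels (D*H*W) of channel c.
    Channels are split into `num_groups` contiguous groups; each group is
    normalized with the mean and variance over all of its channels and voxels,
    then scaled and shifted per channel by gamma[c] and beta[c].
    """
    num_channels = len(x)
    if num_channels % num_groups != 0:
        raise ValueError("number of channels must be divisible by num_groups")
    per_group = num_channels // num_groups
    out: List[List[float]] = []
    for g in range(num_groups):
        channels = range(g * per_group, (g + 1) * per_group)
        values = [v for c in channels for v in x[c]]
        mean = sum(values) / len(values)
        var = sum((v - mean) ** 2 for v in values) / len(values)  # biased variance
        inv_std = 1.0 / math.sqrt(var + eps)
        for c in channels:
            out.append([gamma[c] * (v - mean) * inv_std + beta[c] for v in x[c]])
    return out


if __name__ == "__main__":
    # 4 channels with 2 voxels each, normalized in 2 groups of 2 channels.
    x = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0], [10.0, 14.0]]
    y = group_norm(x, num_groups=2, gamma=[1.0] * 4, beta=[0.0] * 4)
    print([[round(v, 3) for v in ch] for ch in y])
```

Unlike Batch Normalization, the statistics depend only on the current sample, which is one reason Group Normalization is a common choice for memory-heavy 3D volumes trained with small batches.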
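Figure 4 describes a learnable $1 \times d$ token, $\text{Token}_{wt}$, sized to the flattened 3D patch. The caption does not say how the token is combined with the patch; the sketch below assumes element-wise multiplication, as the name "position-weighted" suggests, although an additive positional embedding would be equally plausible. `flatten_patch` and `PositionWeightedToken` are hypothetical names, and the initialization to ones is an assumption.

```python
from typing import List

Patch3D = List[List[List[float]]]  # indexed [depth][height][width]


def flatten_patch(patch: Patch3D) -> List[float]:
    """Flatten a 3D patch in depth, height, width order into a vector of length d."""
    return [v for plane in patch for row in plane for v in row]


class PositionWeightedToken:
    """Hypothetical learnable 1 x d vector that weights each flattened patch position."""

    def __init__(self, d: int) -> None:
        # Initialized to ones (assumption), i.e. an identity weighting before training.
        self.weights: List[float] = [1.0] * d

    def __call__(self, flat_patch: List[float]) -> List[float]:
        if len(flat_patch) != len(self.weights):
            raise ValueError("flattened patch length must equal d")
        # Element-wise product: each voxel position gets its own learnable weight.
        return [w * v for w, v in zip(self.weights, flat_patch)]


if __name__ == "__main__":
    patch = [[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]  # 2 x 2 x 2 patch
    flat = flatten_patch(patch)             # d = 8
    token = PositionWeightedToken(len(flat))
    token.weights[-1] = 0.5                 # pretend training down-weighted the last position
    print(token(flat))  # prints [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 4.0]
```

Because the weights are indexed by flattened position, identical intensities at different voxel locations can contribute differently, which is how such a token could encode spatial position.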
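Figure 5 describes a single query projection shared by a fine-grained and a coarse-grained attention branch, with reduced keys and values $\text{K}_R = M\,\text{K}_{fine}$ and $\text{V}_R = M\,\text{V}_{fine}$. The pure-Python sketch below assumes $M$ has shape $r \times N$ for $N$ tokens, uses plain scaled dot-product attention in both branches, and returns the two outputs separately because the caption does not say how they are fused. All function names are hypothetical.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """(n x k) @ (k x m) -> (n x m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def softmax_rows(a: Matrix) -> Matrix:
    out = []
    for row in a:
        top = max(row)  # subtract the row max for numerical stability
        exps = [math.exp(v - top) for v in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention: softmax(Q K^T / sqrt(c)) V."""
    k_t = [list(col) for col in zip(*k)]
    scale = 1.0 / math.sqrt(len(q[0]))
    scores = [[s * scale for s in row] for row in matmul(q, k_t)]
    return matmul(softmax_rows(scores), v)


def shared_query_attention(x: Matrix, w_q: Matrix, w_k: Matrix, w_v: Matrix,
                           m: Matrix) -> Tuple[Matrix, Matrix]:
    """x: N x c tokens; w_q, w_k, w_v: c x c projections; m: r x N learnable matrix."""
    q = matmul(x, w_q)                                # one Q projection shared by both branches
    k_fine, v_fine = matmul(x, w_k), matmul(x, w_v)
    fine = attention(q, k_fine, v_fine)               # N x N scores: all tokens attend to all tokens
    k_r, v_r = matmul(m, k_fine), matmul(m, v_fine)   # r x c reduced keys and values
    coarse = attention(q, k_r, v_r)                   # N x r scores: linear in N for fixed r
    return fine, coarse


if __name__ == "__main__":
    identity = [[1.0, 0.0], [0.0, 1.0]]
    x = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]  # N = 4 tokens, c = 2
    m = [[0.5, 0.5, 0.0, 0.0], [0.0, 0.0, 0.5, 0.5]]       # r = 2: averages token pairs
    fine, coarse = shared_query_attention(x, identity, identity, identity, m)
    print(len(fine), len(fine[0]), len(coarse), len(coarse[0]))  # prints 4 2 4 2
```

For a fixed $r$, the coarse branch builds an $N \times r$ score matrix, so its cost grows linearly with the number of tokens; this is one plausible reading of the linear-complexity claim. The fine branch as written is quadratic, and the paper's fine-grained branch may instead be restricted to local patches, which this section does not specify.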