Enhancing 3D Transformer Segmentation Model for Medical Image with Token-level Representation Learning

Xinrong Hu, Dewen Zeng, Yawen Wu, Xueyang Li, Yiyu Shi

TL;DR

This work tackles pre-training for 3D medical image segmentation by introducing token-level representation learning on Swin Transformer encoders. It proposes SimPROT, a patch-wise contrastive framework with a novel rotate-and-restore mechanism to prevent representation collapse, plus a weighted contrastive term to discriminate tokens across volumes. Across Synapse and Brain Tumor datasets, SimPROT-W consistently outperforms self-supervised baselines (e.g., MAE, SimMIM, SimCLR, BYOL) and transformer-specific pretexts, with nnFormer achieving the best overall results. The approach is data-efficient and architecture-agnostic, improving segmentation performance without requiring external unlabeled datasets, thereby enhancing practical applicability in medical imaging.

Abstract

In medical imaging, although various works have found Swin Transformer effective for pixel-wise dense prediction, whether pre-training these models without extra datasets can further boost performance on downstream semantic segmentation remains unexplored. Applying previous representation learning methods is hindered by the limited number of 3D volumes and the high computational cost. In addition, most pretext tasks designed specifically for Transformers are not applicable to the hierarchical structure of Swin Transformer. Thus, this work proposes a token-level representation learning loss that maximizes agreement between individual token embeddings from different augmented views, instead of between volume-level global features. Moreover, we identify a potential representation collapse caused exclusively by this new loss. To prevent collapse, we introduce a simple "rotate-and-restore" mechanism, which rotates and flips one augmented view of the input volume and later restores the order of tokens in the feature maps. We also modify the contrastive loss to discriminate between tokens at the same position but from different volumes. We test our pre-training scheme on two public medical segmentation datasets, and the results on the downstream segmentation task show that our method yields larger improvements than other state-of-the-art pre-training methods.
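The token-level objective described in the abstract can be sketched as an InfoNCE loss in which each token embedding of one view is pulled toward the token at the same (restored) position of the same volume in the other view, and pushed away from every other token in the batch. This is a minimal pure-Python sketch under our own assumptions, not the authors' implementation: the name `token_info_nce`, the temperature of 0.1, and the one-directional, cross-view-only negatives are choices made for brevity. The `same_pos_weight` factor is hypothetical: the abstract only says the loss is modified to discriminate same-position tokens from different volumes, and up-weighting those negatives in the denominator is one plausible way to do that, not the paper's formula.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]
Position = Tuple[int, int, int]  # (d, h, w) index in the feature map


def _cosine(a: Vector, b: Vector) -> float:
    """Cosine similarity of two embeddings (0.0 if either is all zeros)."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    if na == 0.0 or nb == 0.0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def token_info_nce(
    view_a: Sequence[Vector],
    view_b: Sequence[Vector],
    volume_ids: Sequence[int],
    positions: Sequence[Position],
    temperature: float = 0.1,
    same_pos_weight: float = 1.0,
) -> float:
    """Mean token-level InfoNCE loss, view_a tokens as anchors, view_b tokens as candidates.

    Index i must refer to the same volume and the same (restored) grid position in both
    views; (view_a[i], view_b[i]) is the positive pair and every other view_b token is a
    negative. HYPOTHETICAL: negatives from a different volume at the same position are
    scaled by same_pos_weight; 1.0 recovers the plain, unweighted loss.
    """
    n = len(view_a)
    total = 0.0
    for i in range(n):
        logits: List[float] = []
        weights: List[float] = []
        for j in range(n):
            logits.append(_cosine(view_a[i], view_b[j]) / temperature)
            same_pos_other_vol = positions[j] == positions[i] and volume_ids[j] != volume_ids[i]
            weights.append(same_pos_weight if same_pos_other_vol else 1.0)
        # -log(exp(l_ii) / sum_j w_ij * exp(l_ij)), computed with a shifted log-sum-exp.
        m = max(logits)
        log_den = m + math.log(sum(w * math.exp(l - m) for w, l in zip(weights, logits)))
        total += log_den - logits[i]
    return total / n


# Usage: two volumes, each contributing a 1x1x2 feature map (2 tokens) with 2-D embeddings.
positions = [(0, 0, 0), (0, 0, 1), (0, 0, 0), (0, 0, 1)]
volume_ids = [0, 0, 1, 1]
view_a = [[1.0, 0.0], [0.0, 1.0], [0.9, 0.1], [0.1, 0.9]]
view_b = [[1.0, 0.1], [0.1, 1.0], [0.8, 0.2], [0.2, 0.8]]
print(token_info_nce(view_a, view_b, volume_ids, positions))
print(token_info_nce(view_a, view_b, volume_ids, positions, same_pos_weight=2.0))
```

Because the weighting only enlarges terms in the denominator, any `same_pos_weight` above 1.0 can only increase the loss, which strengthens the push between same-position tokens of different volumes. How strongly to weight them, and whether the paper uses this form at all, is left open here.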

Paper Structure (12 sections, 3 equations, 6 figures, 4 tables)

Figures (6)

  • Figure 1: Illustration of three popular pre-training methods for Transformers and our SimPROT when the encoder backbone is changed to Swin Transformer. "$\times$" means the method is not feasible, while "✓" means it is applicable. (a) BEiT [bao2021beit] has a mismatch between the number of output patches and the number of visual tokens; (b) for MAE [he2022masked], the volumes left after masking are not valid inputs for Swin Transformer. Note that (d) is a simplified version of our method in the SimCLR [chen2020simple] framework, with some details omitted.
  • Figure 2: Detailed workflow of SimTROT, our proposed token-wise representation learning embedded in a SimCLR framework. $\Gamma_{t}$ and $\Gamma_{s}$ represent random texture transformations and random spatial transformations (rotation and flip), respectively. A Swin encoder projects the different views $\tilde{x}^{i}$ and $\hat{x}^{i}$ into 3D feature maps. Each cell in the cube is the feature representation of a group of neighboring input patches. A contrastive loss is then employed to learn useful features for the downstream task.
  • Figure 3: Visualization of representation collapse in the feature space. There are four different feature maps, each containing $4\times4\times4$ embeddings. (a) shows a collapsed pattern in which all four feature maps are exactly the same and embeddings at the same coordinates are gathered together. In contrast, the ideal learned representations should look like (b), where only embeddings of different views from the same volume are clustered by their position in the feature map, and all others are separated. Best viewed in color.
  • Figure 4: Visualization of segmentation results of different methods on the Synapse and Brain Tumor datasets. The top two rows show multi-organ segmentation results, and the bottom two rows show tumor segmentation results. "SU" is short for Swin UNETR, and "nn" is short for nnFormer. Best viewed in color.
  • Figure 5: Performance of training from scratch and of our SimTROT on the Synapse dataset with different percentages of labeled training data.
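One way to see why rotate-and-restore counters the collapse in Figure 3(a) is a toy experiment on a token grid. If the second view has the same layout as the first, an encoder that ignores content and encodes only each token's coordinate already makes every same-position pair identical, which is exactly the collapsed pattern. Rotating and flipping one view before encoding, then restoring the token order of its feature map, keeps genuine positives aligned but breaks that coordinate-only shortcut. This sketch is illustrative only: nested lists indexed `[d][h][w]` stand in for feature maps, `content_encoder` and `positional_encoder` are hypothetical stand-ins for a Swin encoder, and we assume 90° rotations in the (h, w) plane plus an optional flip along w; the axes and angles actually used are not specified here.

```python
from typing import List

Grid = List[List[List[object]]]  # toy token grid indexed as grid[d][h][w]


def flip_w(grid: Grid) -> Grid:
    """Mirror every row along the w axis."""
    return [[row[::-1] for row in s] for s in grid]


def rot90_hw(grid: Grid, k: int) -> Grid:
    """Rotate each depth slice by k * 90 degrees counter-clockwise in the (h, w) plane."""
    for _ in range(k % 4):
        grid = [
            [[s[r][len(s[0]) - 1 - c] for r in range(len(s))] for c in range(len(s[0]))]
            for s in grid
        ]
    return grid


def rotate(grid: Grid, k: int, flip: bool) -> Grid:
    """Hypothetical spatial augmentation for one view: optional flip, then rotation."""
    return rot90_hw(flip_w(grid) if flip else grid, k)


def restore(grid: Grid, k: int, flip: bool) -> Grid:
    """Invert `rotate` on a feature map so token (d, h, w) lines up with the other view."""
    grid = rot90_hw(grid, -k)
    return flip_w(grid) if flip else grid


def content_encoder(grid: Grid) -> Grid:
    """Toy encoder whose embedding depends only on each token's content."""
    return [[[("emb", tok) for tok in row] for row in s] for s in grid]


def positional_encoder(grid: Grid) -> Grid:
    """Toy 'collapsed' encoder whose embedding depends only on the grid coordinate."""
    return [
        [[(d, h, w) for w in range(len(row))] for h, row in enumerate(s)]
        for d, s in enumerate(grid)
    ]


# Usage: a 2x3x3 volume of labelled tokens; view 2 is rotated by 90 degrees and flipped.
x = [[[f"v{d}{h}{w}" for w in range(3)] for h in range(3)] for d in range(2)]
k, flip = 1, True
for encode in (content_encoder, positional_encoder):
    view1 = encode(x)
    view2 = restore(encode(rotate(x, k, flip)), k, flip)
    print(encode.__name__, "positives match:", view1 == view2)
```

Running it prints `True` for the content-based encoder and `False` for the coordinate-only encoder; with `k=0, flip=False` both would match, which is the shortcut the mechanism removes. In a real network the restore step only realigns token positions: each token's contents are still seen in a rotated orientation, so the encoder has to learn features that agree despite the rotation.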