Enhancing 3D Transformer Segmentation Model for Medical Image with Token-level Representation Learning
Xinrong Hu, Dewen Zeng, Yawen Wu, Xueyang Li, Yiyu Shi
TL;DR
This work tackles pre-training for 3D medical image segmentation by introducing token-level representation learning on Swin Transformer encoders. It proposes SimPROT, a patch-wise contrastive framework with a novel rotate-and-restore mechanism to prevent representation collapse, plus a weighted contrastive term to discriminate tokens across volumes. On the Synapse and Brain Tumor datasets, SimPROT-W consistently outperforms self-supervised baselines (e.g., MAE, SimMIM, SimCLR, BYOL) and transformer-specific pretext tasks, with nnFormer achieving the best overall results. The approach is data-efficient and architecture-agnostic: it improves segmentation performance without requiring external unlabeled datasets, which makes it more practical for medical imaging.
Abstract
In the field of medical imaging, although various works have found Swin Transformer to be effective for pixel-wise dense prediction, whether pre-training these models without any extra dataset can further boost performance on downstream semantic segmentation remains unexplored. Applications of previous representation learning methods are hindered by the limited number of 3D volumes and high computational cost. In addition, most pretext tasks designed specifically for Transformers are not applicable to the hierarchical structure of Swin Transformer. Thus, this work proposes a token-level representation learning loss that maximizes agreement between individual token embeddings from different augmented views, instead of between volume-level global features. Moreover, we identify a potential representation collapse caused exclusively by this new loss. To prevent collapse, we devise a simple "rotate-and-restore" mechanism, which rotates and flips one augmented view of the input volume and later restores the order of tokens in the feature maps. We also modify the contrastive loss to discriminate between tokens at the same position but from different volumes. We test our pre-training scheme on two public medical segmentation datasets, and on the downstream segmentation task our methods yield larger improvements than other state-of-the-art pre-training methods.
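The two core ideas of the abstract, a token-level contrastive loss and rotate-and-restore, can be illustrated with a minimal NumPy sketch. It rests on several assumptions of ours that do not come from the paper. `toy_patch_encoder` is a hypothetical stand-in for a Swin Transformer encoder (non-overlapping patches plus a linear projection). Rotations are limited to 90° steps in the (H, W) plane plus an optional single-axis flip. The loss is a plain symmetric InfoNCE over tokens with a temperature `tau` we picked. The paper's weighted cross-volume term is not reproduced, and all function names are hypothetical.

```python
import numpy as np

def rotate_and_flip(x, k, flip_axis):
    """Rotate spatial axes (1, 2) of a (D, H, W, ...) array by k*90 degrees,
    then optionally flip one spatial axis (0, 1 or 2)."""
    x = np.rot90(x, k=k, axes=(1, 2))
    return np.flip(x, axis=flip_axis) if flip_axis is not None else x

def restore(x, k, flip_axis):
    """Exact inverse of rotate_and_flip: undo the flip, then rotate by -k."""
    if flip_axis is not None:
        x = np.flip(x, axis=flip_axis)
    return np.rot90(x, k=-k, axes=(1, 2))

def toy_patch_encoder(vol, proj, p):
    """Hypothetical Swin stand-in: split a (D, H, W) volume into p^3 patches
    and project each one, giving a token grid of shape (D/p, H/p, W/p, C).
    Assumes every spatial size is divisible by p."""
    D, H, W = vol.shape
    patches = vol.reshape(D // p, p, H // p, p, W // p, p).transpose(0, 2, 4, 1, 3, 5)
    return patches.reshape(D // p, H // p, W // p, p ** 3) @ proj

def log_softmax(logits, axis):
    """Numerically stable log-softmax along one axis."""
    m = logits.max(axis=axis, keepdims=True)
    return logits - m - np.log(np.exp(logits - m).sum(axis=axis, keepdims=True))

def token_info_nce(f1, f2, tau=0.1):
    """Symmetric InfoNCE in which token i of view 1 is positive only with
    token i of view 2; every other token acts as a negative."""
    z1 = f1.reshape(-1, f1.shape[-1])
    z2 = f2.reshape(-1, f2.shape[-1])
    z1 = z1 / np.linalg.norm(z1, axis=1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=1, keepdims=True)
    logits = z1 @ z2.T / tau
    idx = np.arange(len(z1))
    loss_12 = -log_softmax(logits, axis=1)[idx, idx].mean()
    loss_21 = -log_softmax(logits, axis=0)[idx, idx].mean()
    return 0.5 * (loss_12 + loss_21)

# Usage on a random toy volume (all sizes are illustrative).
rng = np.random.default_rng(0)
vol = rng.standard_normal((16, 16, 16))
p, C = 4, 8
proj = rng.standard_normal((p ** 3, C))
k, flip_axis = 1, 0  # would be sampled at random per iteration

view1 = vol + 0.05 * rng.standard_normal(vol.shape)
view2 = rotate_and_flip(vol + 0.05 * rng.standard_normal(vol.shape), k, flip_axis)

f1 = toy_patch_encoder(view1, proj, p)                          # original token order
f2 = restore(toy_patch_encoder(view2, proj, p), k, flip_axis)   # token order restored
loss = token_info_nce(f1, f2)
```

As we read the abstract, the encoder sees the second view in a rotated and flipped layout, and restoring the output token grid re-aligns it so that position i in both feature maps covers the same voxels and can serve as a positive pair. In this sketch the alignment is exact only because rotating a volume whose sizes are divisible by `p` moves whole patches. A real Swin encoder with windowed attention and downsampling would only approximately preserve that correspondence.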
