MiM: Mask in Mask Self-Supervised Pre-Training for 3D Medical Image Analysis

Jiaxin Zhuang, Linshan Wu, Qiong Wang, Peng Fei, Varut Vardhanabhuti, Lin Luo, Hao Chen

TL;DR

MiM introduces a hierarchical Mask in Mask pre-training framework for 3D medical image analysis. It addresses the single-scale limitation of MAE by learning multi-scale, cross-level representations with a three-level volume strategy. The method combines multi-level reconstruction with cross-level alignment and uses a 3D hybrid backbone that efficiently processes only the unmasked tokens, achieving state-of-the-art results on numerous segmentation and classification benchmarks. Scaling the pre-training data to more than 10k unlabeled volumes consistently improves performance, and cross-modality transfer from CT to MRI demonstrates MiM's generalizability. Overall, MiM highlights the importance of multi-scale, hierarchical SSL and large-scale pre-training for healthcare foundation models in 3D medical imaging.
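The three-level volume strategy (multi-level masked volume generation, illustrated in Figure 3) can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions, not the authors' implementation: the names `patchify` and `generate_levels`, the 4x4x4 patch grid, the 0.75 mask ratio, the cubic input, and the choice to keep each sub-volume at its native resolution (rather than resampling it) are all hypothetical.

```python
import numpy as np

def patchify(volume: np.ndarray, patch: int) -> np.ndarray:
    """Split a (D, H, W) volume into non-overlapping patch^3 cubes, in raster order."""
    d, h, w = volume.shape
    assert d % patch == 0 and h % patch == 0 and w % patch == 0
    x = volume.reshape(d // patch, patch, h // patch, patch, w // patch, patch)
    x = x.transpose(0, 2, 4, 1, 3, 5)  # (grid_d, grid_h, grid_w, p, p, p)
    return x.reshape(-1, patch, patch, patch)

def generate_levels(volume, num_levels=3, grid=4, mask_ratio=0.75, rng=None):
    """Hypothetical multi-level masked volume generation.

    Level l is cut into grid^3 patches and a random subset is masked; one masked
    patch is then taken as the level-(l+1) volume. Assumes a cubic volume whose
    side is divisible by grid ** num_levels; sub-volumes are NOT resampled.
    """
    rng = np.random.default_rng() if rng is None else rng
    levels, current = [], volume
    for level in range(num_levels):
        patches = patchify(current, current.shape[0] // grid)
        n = patches.shape[0]
        # Randomly mask a fixed fraction of this level's patches.
        mask = np.zeros(n, dtype=bool)
        mask[rng.permutation(n)[: int(round(n * mask_ratio))]] = True
        entry = {"volume": current, "patches": patches, "mask": mask, "next_index": None}
        levels.append(entry)
        if level + 1 < num_levels:
            # The next-level volume is drawn from a masked patch of this level (cf. Figure 3).
            entry["next_index"] = int(rng.choice(np.flatnonzero(mask)))
            current = patches[entry["next_index"]]
    return levels

rng = np.random.default_rng(0)
levels = generate_levels(rng.standard_normal((128, 128, 128)), rng=rng)
print([lv["volume"].shape for lv in levels])  # [(128, 128, 128), (32, 32, 32), (8, 8, 8)]
```

Each level records its patches, its mask, and the index of the masked patch that seeded the next level; that bookkeeping is what a cross-level objective would need to pair adjacent levels.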

Abstract

The Vision Transformer (ViT) has demonstrated remarkable performance in Self-Supervised Learning (SSL) for 3D medical image analysis. Masked AutoEncoder (MAE) pre-training can further unleash the potential of ViT on various medical vision tasks. However, because 3D medical images have large spatial sizes and higher dimensionality, the lack of a hierarchical design in MAE may hinder performance on downstream tasks. In this paper, we propose a novel Mask in Mask (MiM) pre-training framework for 3D medical images, which aims to advance MAE by learning discriminative representations from hierarchical visual tokens across varying scales. We introduce multiple levels of granularity for masked inputs from the volume, which are then reconstructed simultaneously at both fine and coarse levels. Additionally, a cross-level alignment mechanism is applied to volumes from adjacent levels to enforce anatomical similarity hierarchically. Furthermore, we adopt a hybrid backbone to enhance hierarchical representation learning efficiently during pre-training. MiM was pre-trained on a large collection of available 3D volumetric images, i.e., Computed Tomography (CT) images covering various body parts. Extensive experiments on thirteen public datasets demonstrate the superiority of MiM over other SSL methods in organ/lesion/tumor segmentation and disease classification. We further scale up MiM to large pre-training datasets with more than 10k volumes, showing that large-scale pre-training can further enhance downstream performance. These improvements also suggest that the research community should pay more attention to the scale of the pre-training dataset when building healthcare foundation models for 3D medical images.
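The abstract states that the masked inputs at every level are reconstructed simultaneously. One way to picture this is an MAE-style reconstruction loss computed per level and summed. The sketch below is an assumption-laden illustration, not the paper's equation: it uses mean squared error on masked patches only, optional per-patch normalization modeled on MAE's `norm_pix_loss` option, and equal level weights by default. The function names and the weighting scheme are hypothetical.

```python
import numpy as np

def masked_mse(pred, target, mask, normalize=True):
    """MAE-style reconstruction loss over masked patches only.

    pred, target: (N, P) arrays of N flattened patches with P voxels each.
    mask: (N,) boolean array, True where the patch was masked.
    """
    if normalize:
        # Per-patch voxel normalization of the target (assumed, as in MAE's norm_pix_loss).
        mean = target.mean(axis=-1, keepdims=True)
        var = target.var(axis=-1, keepdims=True)
        target = (target - mean) / np.sqrt(var + 1e-6)
    per_patch = ((pred - target) ** 2).mean(axis=-1)
    # Visible (unmasked) patches do not contribute to the loss.
    return float(per_patch[mask].mean())

def multi_level_reconstruction_loss(preds, targets, masks, weights=None, normalize=True):
    """Hypothetical weighted sum of the per-level masked MSE (equal weights by default)."""
    if weights is None:
        weights = [1.0] * len(preds)
    return sum(w * masked_mse(p, t, m, normalize)
               for w, p, t, m in zip(weights, preds, targets, masks))
```

Summing the levels lets one backward pass update the shared backbone from fine and coarse targets at once; whether MiM weights the levels equally is not stated here.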

Paper Structure (22 sections, 10 equations, 10 figures, 14 tables)

Figures (10)

  • Figure 1: Different SSL approaches for 3D medical image analysis. Current Masked Image Modeling methods for 3D medical images primarily (a) rely on pretext tasks, e.g., inpainting, at a single level, using hybrid transformers that process all tokens, or (b) employ an MAE that reconstructs at a single level from unmasked tokens. In contrast, (c) we observe that 3D medical images inherently exhibit hierarchical properties. Our Mask in Mask (MiM) framework therefore learns multi-level representations of 3D medical images across hierarchical visual tokens at various scales through multi-level reconstruction and cross-level alignment (the number of levels $L$ is set to 3 in this figure). Additionally, our framework employs a hybrid transformer that processes only unmasked tokens.
  • Figure 2: Overview of our MiM pre-training framework. The number of levels $L$ is set to 3 for illustration. We first generate the multi-level masked volumes. The multi-level reconstruction module reconstructs the masked volumes at different levels. The cross-level alignment module aligns the representations of volumes from adjacent levels to enforce anatomical similarity hierarchically.
  • Figure 3: Illustration of multi-level masked volume generation. 2D slices of the 3D medical images are shown for illustration only. The Level-$(l+1)$ volume $x^{l+1}$ is randomly sampled from the masked patches of the Level-$l$ volume $x^l$ (i.e., the patch in the red box).
  • Figure 4: Illustration of the cross-level alignment module. The number of levels $L$ is set to 2 for illustration.
  • Figure 5: Architecture of the FPN adapted from MCMAE (gao2022convmae), illustrating multi-scale feature extraction and fusion through hierarchical feature maps.
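The cross-level alignment module in Figure 4 can be illustrated with a simple cosine-similarity objective between adjacent levels. This is a stand-in chosen for illustration: the paper's alignment loss is not given in this summary, and which features are compared (here, a feature vector for the level-$l$ patch that seeded level $l+1$, versus the average-pooled tokens of the level-$(l+1)$ volume) is our assumption. All function names are hypothetical.

```python
import numpy as np

def l2_normalize(x: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    """Scale vectors along the last axis to (approximately) unit length."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)

def pool_tokens(tokens: np.ndarray) -> np.ndarray:
    """Global average pooling of (num_tokens, channels) features -> (channels,)."""
    return tokens.mean(axis=0)

def alignment_loss(parent_feat: np.ndarray, child_tokens: np.ndarray) -> float:
    """Hypothetical loss: 1 - cosine similarity between the level-l patch feature
    and the pooled feature of the level-(l+1) volume cropped from that patch.
    Ranges from 0 (same direction) to 2 (opposite directions)."""
    p = l2_normalize(parent_feat)
    c = l2_normalize(pool_tokens(child_tokens))
    return float(1.0 - p @ c)

def total_alignment(parent_feats, child_token_sets) -> float:
    """Sum the alignment loss over the L-1 adjacent level pairs (l, l+1)."""
    return sum(alignment_loss(p, c) for p, c in zip(parent_feats, child_token_sets))
```

Because NumPy has no autograd, this only computes the loss value; a training implementation would express the same arithmetic in a framework such as PyTorch.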