Autoregressive Sequence Modeling for 3D Medical Image Representation
Siwen Wang, Churan Wang, Fei Gao, Lixian Su, Fandong Zhang, Yizhou Wang, Yizhou Yu
TL;DR
This work tackles the challenge of learning generalizable representations for 3D medical images by introducing an autoregressive pre-training framework that serializes diverse scans into patch sequences ordered by spatial, contrast, and semantic correlations. Patches are tokenized into visual tokens, and the model is trained to predict the next token with an autoregressive objective; a random-startup prefix attention scheme lets it capture rich intra- and inter-scan context. The authors report state-of-the-art performance on nine public downstream tasks, including CT and MRI segmentation and COVID-19 and lung nodule classification, with notable gains over both general-purpose and medical self-supervised learning (SSL) baselines and robust performance when annotations are limited. The results suggest strong generalization across imaging modalities and a practical path toward label-efficient, clinically relevant 3D medical image analysis.
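The serialization-and-tokenization step can be illustrated with a small sketch. It assumes non-overlapping cubic patches read in raster (z, y, x) order, a toy codebook with nearest-codeword lookup standing in for the visual tokenizer, and multiple scans (for example, different contrasts of one study) simply concatenated. The tokenizer, ordering rules, and all function names here are hypothetical illustrations, not the authors' implementation, and semantic grouping is omitted.

```python
from typing import List, Sequence

# A volume is indexed as volume[z][y][x] (hypothetical plain-list representation).
Volume = List[List[List[float]]]


def extract_patches(volume: Volume, p: int) -> List[List[float]]:
    """Split a D x H x W volume into non-overlapping p x p x p patches.

    Patches are flattened and returned in raster (z, y, x) order, which is
    one simple way to encode spatial adjacency. Assumes D, H, W are divisible by p.
    """
    d, h, w = len(volume), len(volume[0]), len(volume[0][0])
    patches = []
    for z0 in range(0, d, p):
        for y0 in range(0, h, p):
            for x0 in range(0, w, p):
                patch = [volume[z][y][x]
                         for z in range(z0, z0 + p)
                         for y in range(y0, y0 + p)
                         for x in range(x0, x0 + p)]
                patches.append(patch)
    return patches


def nearest_code(patch: Sequence[float], codebook: Sequence[Sequence[float]]) -> int:
    """Return the index of the codeword with the smallest squared Euclidean distance."""
    best, best_d = 0, float("inf")
    for i, code in enumerate(codebook):
        dist = sum((a - b) ** 2 for a, b in zip(patch, code))
        if dist < best_d:
            best, best_d = i, dist
    return best


def serialize_scans(scans: Sequence[Volume], p: int,
                    codebook: Sequence[Sequence[float]]) -> List[int]:
    """Tokenize each scan and concatenate the token sequences (assumed contrast ordering)."""
    tokens: List[int] = []
    for vol in scans:
        tokens.extend(nearest_code(patch, codebook) for patch in extract_patches(vol, p))
    return tokens


if __name__ == "__main__":
    # 4x4x4 toy volume: zeros for z < 2, ones elsewhere; its inverse acts as a second "contrast".
    vol = [[[0.0 if z < 2 else 1.0 for _ in range(4)] for _ in range(4)] for z in range(4)]
    inv = [[[1.0 - v for v in row] for row in plane] for plane in vol]
    codebook = [[0.0] * 8, [1.0] * 8]  # two codewords for 2x2x2 patches
    print(serialize_scans([vol, inv], 2, codebook))
    # → [0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0]
```

A learned tokenizer would replace the fixed codebook in practice; the point of the sketch is only that each scan becomes a discrete token sequence and related scans are placed next to each other so the model can exploit inter-scan context.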
Abstract
Three-dimensional (3D) medical images, such as Computed Tomography (CT) and Magnetic Resonance Imaging (MRI), are essential for clinical applications. However, the need for diverse and comprehensive representations is particularly pronounced given the variability across organs, diagnostic tasks, and imaging modalities. How to effectively interpret the intricate contextual information in these images and extract meaningful insights from them remains an open challenge for the community. While current self-supervised learning methods have shown potential, they often treat an image as a whole, thereby overlooking the extensive, complex relationships among local regions within one image or across multiple images. In this work, we introduce a pioneering method for learning 3D medical image representations through an autoregressive pre-training framework. Our approach serializes diverse 3D medical images based on spatial, contrast, and semantic correlations, treating their patches as interconnected visual tokens within a token sequence. Through an autoregressive sequence modeling task, we predict the next visual token in the sequence, which allows our model to deeply understand and integrate the contextual information inherent in 3D medical images. Additionally, we implement a random startup strategy to avoid overestimating token relationships and to enhance the robustness of learning. The effectiveness of our approach is demonstrated by its superior performance over competing methods on nine downstream tasks built on public datasets.
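One plausible reading of the random startup strategy is a prefix-LM style mask: a prefix of random length is visible bidirectionally, the remaining tokens attend causally, and only tokens after the prefix are supervised with next-token targets. The sketch below builds such a mask and target list under that assumption; the exact scheme in the paper may differ, and the function names are hypothetical.

```python
import random
from typing import List, Tuple


def prefix_causal_mask(n: int, prefix_len: int) -> List[List[bool]]:
    """mask[i][j] is True when query position i may attend to key position j.

    Positions inside the prefix see the whole prefix (bidirectional);
    later positions see every position up to and including themselves (causal).
    """
    mask = []
    for i in range(n):
        if i < prefix_len:
            mask.append([j < prefix_len for j in range(n)])
        else:
            mask.append([j <= i for j in range(n)])
    return mask


def next_token_targets(tokens: List[int], prefix_len: int) -> List[Tuple[int, int]]:
    """(position, target) pairs: the output at position i predicts tokens[i + 1].

    Supervision starts at the last prefix position, so every token after the
    prefix is predicted exactly once and no prefix token is used as a target.
    """
    return [(i, tokens[i + 1]) for i in range(prefix_len - 1, len(tokens) - 1)]


def random_startup(tokens: List[int], rng: random.Random):
    """Sample a prefix length in [1, len(tokens) - 1] so at least one target remains."""
    k = rng.randint(1, len(tokens) - 1)  # randint bounds are inclusive
    return k, prefix_causal_mask(len(tokens), k), next_token_targets(tokens, k)


if __name__ == "__main__":
    seq = [5, 3, 8, 1, 9]
    for row in prefix_causal_mask(len(seq), 2):
        print(["x" if v else "." for v in row])
    print(next_token_targets(seq, 2))  # → [(1, 8), (2, 1), (3, 9)]
```

With a prefix length of 1 this reduces to ordinary causal next-token prediction; sampling the prefix length per sequence means the model cannot rely on always starting from the first token, which is one way to read "avoid overestimating token relationships".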
