Autoregressive Sequence Modeling for 3D Medical Image Representation

Siwen Wang, Churan Wang, Fei Gao, Lixian Su, Fandong Zhang, Yizhou Wang, Yizhou Yu

TL;DR

This work tackles the challenge of learning generalizable representations for 3D medical images by introducing an autoregressive pre-training framework that serializes diverse scans into patch sequences ordered by spatial, contrast, and semantic correlations. By dividing patches into visual tokens and predicting the next token with an autoregressive objective, combined with prefix attention that uses a random startup position, the method captures rich intra- and inter-scan context. The authors report state-of-the-art performance across nine public downstream tasks, including CT and MRI segmentation and COVID-19 and lung nodule classification, with notable gains over both general and medical self-supervised learning (SSL) baselines and robust performance when annotations are limited. The approach promises strong cross-modality generalization and a practical path toward label-efficient, clinically relevant 3D medical image analysis.

Abstract

Three-dimensional (3D) medical images, such as Computed Tomography (CT) and Magnetic Resonance Imaging (MRI), are essential for clinical applications. However, the need for diverse and comprehensive representations is particularly pronounced when considering the variability across different organs, diagnostic tasks, and imaging modalities. How to effectively interpret the intricate contextual information and extract meaningful insights from these images remains an open challenge to the community. While current self-supervised learning methods have shown potential, they often consider an image as a whole, thereby overlooking the extensive, complex relationships among local regions from one or multiple images. In this work, we introduce a pioneering method for learning 3D medical image representations through an autoregressive pre-training framework. Our approach sequences various 3D medical images based on spatial, contrast, and semantic correlations, treating them as interconnected visual tokens within a token sequence. By employing an autoregressive sequence modeling task, we predict the next visual token in the sequence, which allows our model to deeply understand and integrate the contextual information inherent in 3D medical images. Additionally, we implement a random startup strategy to avoid overestimating token relationships and to enhance the robustness of learning. The effectiveness of our approach is demonstrated by its superior performance over other methods on nine downstream tasks on public datasets.
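
To make the idea of treating patches as "visual tokens within a token sequence" concrete, here is a minimal, dependency-free sketch. It is not the authors' implementation: the function names `volume_to_tokens` and `build_token_sequence`, the cubic token shape, and the raster ordering are all assumptions made for illustration, and a real pipeline would use an array library rather than nested lists.

```python
def volume_to_tokens(volume, token_size):
    """Split a 3D patch (nested lists indexed [z][y][x]) into cubic tokens.

    Tokens are emitted in raster order (z, then y, then x), and each token is
    flattened to a list of token_size**3 voxel values.
    Hypothetical helper: the paper's exact tokenization may differ.
    """
    depth, height, width = len(volume), len(volume[0]), len(volume[0][0])
    p = token_size
    if depth % p or height % p or width % p:
        raise ValueError("each dimension must be divisible by token_size")
    tokens = []
    for z in range(0, depth, p):
        for y in range(0, height, p):
            for x in range(0, width, p):
                tokens.append([
                    volume[z + i][y + j][x + k]
                    for i in range(p) for j in range(p) for k in range(p)
                ])
    return tokens


def build_token_sequence(patches, token_size):
    """Concatenate the tokens of an ordered list of patches into one sequence.

    The order of `patches` is where spatial, contrast, or semantic
    relationships would be encoded; here the caller supplies that order.
    """
    sequence = []
    for patch in patches:
        sequence.extend(volume_to_tokens(patch, token_size))
    return sequence


# Toy 4x4x4 patch whose voxel value equals its flat index.
toy = [[[z * 16 + y * 4 + x for x in range(4)] for y in range(4)] for z in range(4)]
seq = build_token_sequence([toy, toy], token_size=2)
print(len(seq), len(seq[0]))
```

In this toy setup each 4x4x4 patch yields eight tokens of eight voxels, so two patches produce a sequence of 16 tokens. How the patches themselves are ordered (by position, by contrast, or by semantic similarity) is the part the paper contributes and is left to the caller here.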

Paper Structure (24 sections, 2 equations, 2 figures, 6 tables)

Figures (2)

  • Figure 1: Overview of our Autoregressive Sequence Modeling approach for 3D Medical Images. The purple box on the left shows the transformation of one or more 3D medical images into a patch sequence of N patches, highlighting spatial, contrast, and semantic relationships within 3D data. In the orange box on the right, patches within the sequence are divided into visual tokens, which are then concatenated to form an ordered token sequence. During pre-training, the start of the token sequence $t_i$ is selected randomly to enhance learning robustness. At the bottom of the orange box, schematic diagrams of the training mechanism show how our method uses autoregressive modeling to predict subsequent tokens and integrate contextual information. The green box shows that our method generalizes to various downstream tasks in the fine-tuning stage.
  • Figure 2: Visualizations of segmentation results for various organs and pathologies in CT and MRI scans, produced by our proposed method and the compared baselines. Each row corresponds to a different task and each column to a different method.
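
The random start and prefix attention described in the Figure 1 caption can be sketched as follows. This is one plausible reading, not the paper's confirmed mechanism: it assumes that tokens before a randomly chosen start index form a prefix that attends bidirectionally, that later tokens attend causally, and that only the later tokens are prediction targets. The names `prefix_attention_mask`, `sample_prefix_len`, and `target_positions` are hypothetical.

```python
import random


def prefix_attention_mask(seq_len, prefix_len):
    """Return mask[q][k], True when query position q may attend to key k.

    Assumption: positions inside the prefix see the whole prefix, while
    every other position sees only itself and earlier positions.
    """
    if not 0 <= prefix_len <= seq_len:
        raise ValueError("prefix_len must lie in [0, seq_len]")
    return [
        [k <= q or (q < prefix_len and k < prefix_len) for k in range(seq_len)]
        for q in range(seq_len)
    ]


def sample_prefix_len(seq_len, rng):
    """Pick a random start index (assumed uniform over 1..seq_len-1).

    Keeping it below seq_len leaves at least one token to predict.
    """
    return rng.randint(1, seq_len - 1)


def target_positions(seq_len, prefix_len):
    """Positions whose token is predicted from the preceding context."""
    return list(range(prefix_len, seq_len))


rng = random.Random(0)
start = sample_prefix_len(6, rng)
mask = prefix_attention_mask(6, start)
targets = target_positions(6, start)
```

With `seq_len=4` and `prefix_len=2`, the first two rows of the mask are `[True, True, False, False]` and the last two rows are lower-triangular, so positions 2 and 3 are predicted from everything before them. Resampling the start index for each training sequence is the assumed counterpart of the random startup strategy.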