Graph Neural Networks: A suitable Alternative to MLPs in Latent 3D Medical Image Classification?
Johannes Kiechle, Daniel M. Lang, Stefan M. Fischer, Lina Felsner, Jan C. Peeken, Julia A. Schnabel
TL;DR
This work investigates replacing standard MLP prediction heads with graph neural networks for latent 3D medical image classification by building subject-level graphs from slice-wise DINOv2 features. Each volume yields 64 nodes whose features are 384-dimensional representations of 64 axial, coronal, and sagittal views, extracted by a frozen DINOv2 ViT; eight graph-construction strategies are paired with GraphSAGE or GAT convolutions to produce graph-level predictions. Across the MedMNIST3D datasets, GNN heads often outperform MLP heads in AUROC and accuracy (ACC) while offering substantial runtime savings, though the best topology and convolution type vary by dataset. The findings support GNNs as a viable, more efficient alternative to MLPs for latent 3D medical image classification, with robustness under perturbations comparable to that of MLPs and potential for further gains via adaptive graph learning.
Abstract
Recent studies have underscored the capability of natural-imaging foundation models to serve as powerful feature extractors for medical imaging data, even in a zero-shot setting. Most commonly, a shallow multi-layer perceptron (MLP) is appended to the feature extractor to enable end-to-end learning and downstream prediction tasks such as classification, making it the de facto standard. However, since graph neural networks (GNNs) have recently become a practicable choice for various tasks in medical research, we investigate how effective GNNs are compared to MLP prediction heads for 3D medical image classification and propose them as a potential alternative. In our experiments, we construct a subject-level graph for each volumetric dataset instance. Therein, the latent representations of all slices in the volume, encoded by a DINOv2-pretrained vision transformer (ViT), constitute the nodes and their respective node features. We use public datasets to compare the classification heads quantitatively and evaluate various graph construction and graph convolution methods. Our findings show that the GNN improves classification performance and substantially reduces runtime compared to an MLP prediction head. Additional robustness evaluations further support the promising performance of GNNs, promoting them as a suitable alternative to traditional MLP classification heads. Our code is publicly available at: https://github.com/compai-lab/2024-miccai-grail-kiechle
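To make the idea of a subject-level graph concrete, the following is a minimal sketch in plain Python and not the authors' implementation. It assumes 64 slice embeddings of 384 dimensions per volume (the figures given in the TL;DR) and replaces the DINOv2 features with random placeholders. The two topologies shown (a chain over neighboring slices and a fully connected graph) are illustrative assumptions and are not claimed to match the paper's eight construction strategies. The helper names `chain_edges`, `fully_connected_edges`, `neighbor_mean`, and `mean_readout` are hypothetical, and `neighbor_mean` is only an unweighted neighbor average, without the learned weights of GraphSAGE or GAT.

```python
import random

NUM_SLICES = 64    # nodes per volume, as stated in the TL;DR
FEATURE_DIM = 384  # per-slice feature size, as stated in the TL;DR


def chain_edges(num_nodes):
    """Edge list (both directions) linking each slice to its direct neighbors."""
    edges = []
    for i in range(num_nodes - 1):
        edges.append((i, i + 1))
        edges.append((i + 1, i))
    return edges


def fully_connected_edges(num_nodes):
    """Edge list linking every slice to every other slice (no self-loops)."""
    return [(i, j) for i in range(num_nodes) for j in range(num_nodes) if i != j]


def neighbor_mean(features, edges):
    """Replace each node feature by the mean of its neighbors' features.

    Nodes without neighbors keep their own feature. The learned weights and
    nonlinearities of a real graph convolution are deliberately left out.
    """
    neighbors = {i: [] for i in range(len(features))}
    for src, dst in edges:
        neighbors[dst].append(src)
    out = []
    for i, own in enumerate(features):
        group = [features[j] for j in neighbors[i]] or [own]
        out.append([sum(col) / len(group) for col in zip(*group)])
    return out


def mean_readout(features):
    """Average all node features into one graph-level vector."""
    return [sum(col) / len(features) for col in zip(*features)]


# Random placeholders standing in for frozen-encoder slice embeddings (assumption).
rng = random.Random(0)
slice_features = [
    [rng.gauss(0.0, 1.0) for _ in range(FEATURE_DIM)] for _ in range(NUM_SLICES)
]

edges = chain_edges(NUM_SLICES)
graph_vector = mean_readout(neighbor_mean(slice_features, edges))
print(len(edges), len(graph_vector))
```

In a trained model, the graph-level vector would be passed to a classifier, and the averaging step would be replaced by learned graph convolutions; swapping `chain_edges` for `fully_connected_edges` changes only the topology, which is the kind of design choice the experiments compare.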
