MTMed3D: A Multi-Task Transformer-Based Model for 3D Medical Imaging
Fan Li, Arun Iyengar, Lanyu Xu
TL;DR
MTMed3D is an end-to-end multi-task architecture for 3D medical imaging that jointly addresses detection, segmentation, and classification using a shared Swin Transformer encoder. The model attaches three task-specific decoders to this encoder (a modified RetinaNet with PANet for detection, a Swin UNETR-inspired segmentation path, and a DenseNet-121 classification branch) and balances the tasks with GradNorm. Evaluated on BraTS 2018 and 2019, MTMed3D achieves detection results that surpass prior work, competitive segmentation accuracy, and solid tumor grading, all at substantially lower computational cost than separate single-task models. The work highlights the practicality of Transformer-based multi-task learning in clinical settings, offering a unified, resource-efficient framework that can be adapted to other 3D medical imaging tasks.
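To make the task-balancing step more concrete, the following is a minimal sketch of a GradNorm-style weight update for three task losses. It is not the authors' implementation: the function name `gradnorm_step`, the values of `alpha` and `lr`, and all numeric inputs are assumptions chosen purely for illustration. In a real training loop the gradient norms would be measured on the shared encoder's parameters.

```python
def gradnorm_step(weights, grad_norms, losses, initial_losses, alpha=1.5, lr=0.025):
    """One GradNorm-style update of the per-task loss weights (illustrative sketch).

    weights        -- current loss weights w_i
    grad_norms     -- norm of the gradient of each unweighted task loss
                      with respect to the shared parameters
    losses         -- current task losses L_i(t)
    initial_losses -- task losses at the start of training L_i(0)
    alpha, lr      -- hypothetical hyperparameter values, not taken from the paper
    """
    n = len(weights)

    # Weighted gradient norm per task, G_i = w_i * ||grad L_i||, and its mean.
    weighted = [w * g for w, g in zip(weights, grad_norms)]
    mean_norm = sum(weighted) / n

    # Relative inverse training rate: tasks whose loss has dropped less get r_i > 1.
    ratios = [cur / init for cur, init in zip(losses, initial_losses)]
    mean_ratio = sum(ratios) / n
    rel_rates = [r / mean_ratio for r in ratios]

    # Target gradient norm for each task, treated as a constant.
    targets = [mean_norm * r ** alpha for r in rel_rates]

    # Gradient step on sum_i |G_i - target_i| with respect to w_i.
    updated = []
    for w, g, big_g, target in zip(weights, grad_norms, weighted, targets):
        sign = (big_g > target) - (big_g < target)
        updated.append(max(w - lr * sign * g, 1e-6))

    # Renormalize so the weights sum to the number of tasks.
    total = sum(updated)
    return [n * w / total for w in updated]


# Made-up numbers for the detection, segmentation, and classification losses.
updated = gradnorm_step(
    weights=[1.0, 1.0, 1.0],
    grad_norms=[2.0, 1.0, 0.5],
    losses=[0.5, 0.8, 0.9],
    initial_losses=[1.0, 1.0, 1.0],
)
print(updated)
```

In this toy example the first task has the largest gradient norm and the fastest-falling loss, so its weight decreases, while the two slower tasks gain weight.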
Abstract
In the field of medical imaging, AI-assisted techniques such as object detection, segmentation, and classification are widely employed to alleviate the workload of physicians. However, single-task models are predominantly used, overlooking the information shared across tasks. This oversight leads to inefficiencies in real-life applications. In this work, we propose MTMed3D, a novel end-to-end Multi-task Transformer-based model that addresses the limitations of single-task models by jointly performing 3D detection, segmentation, and classification in medical imaging. Our model uses a Transformer as the shared encoder to generate multi-scale features, followed by CNN-based task-specific decoders. The proposed framework was evaluated on the BraTS 2018 and 2019 datasets, achieving promising results across all three tasks, especially in detection, where our method achieves better results than prior works. Additionally, we compare our multi-task model with equivalent single-task variants trained separately. Our multi-task model significantly reduces computational costs and achieves faster inference speed while maintaining performance comparable to the single-task models, highlighting its efficiency advantage. To the best of our knowledge, this is the first work to leverage Transformers for multi-task learning that simultaneously covers detection, segmentation, and classification tasks in 3D medical imaging, demonstrating its potential to enhance diagnostic processes. The code is available at https://github.com/fanlimua/MTMed3D.git.
