MedNNS: Supernet-based Medical Task-Adaptive Neural Network Search
Lotfi Abdelkrim Mecharbat, Ibrahim Almakky, Martin Takac, Mohammad Yaqub
TL;DR
The paper tackles the problem of adapting deep learning (DL) models to medical imaging by jointly optimizing architecture and weight initialization. This addresses the limitations of ImageNet pretraining, which stem from the domain shift between natural and medical images. It introduces MedNNS (Medical Neural Network Search), a Supernet-based search framework that builds a large meta-space from the subnetworks of per-dataset Supernets and encodes architectures and datasets into embeddings in that space. The embeddings are trained with a composite loss that combines rank and FID terms to align task-specific representations. The authors report an average accuracy gain of about 1.7% across MedMNIST datasets and substantially faster convergence compared with both ImageNet-pretrained baselines and existing NAS methods. Although the approach is effective across diverse tasks, generalization gaps remain for highly dissimilar datasets. This motivates future work on broadening dataset coverage and incorporating hardware constraints to improve applicability.
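To make the idea of a "rank term" concrete, here is a minimal sketch of a generic pairwise hinge rank loss over a meta-space. It assumes, purely for illustration, that a model's compatibility with a dataset is scored as the dot product of their embeddings, and that the loss should push better-performing models to score higher by at least a margin. The names `score`, `pairwise_rank_loss`, and `margin`, as well as the toy embeddings and accuracies, are hypothetical. This summary does not give MedNNS's exact rank-loss formulation, so treat this as an illustration of the technique, not the authors' implementation.

```python
from itertools import combinations


def score(dataset_emb, model_emb):
    """Hypothetical compatibility score: dot product of a dataset and a model embedding."""
    return sum(d * m for d, m in zip(dataset_emb, model_emb))


def pairwise_rank_loss(dataset_emb, model_embs, accuracies, margin=0.1):
    """Generic pairwise hinge rank loss (an assumption, not MedNNS's exact loss).

    For every pair of models where one reached higher accuracy on this dataset,
    penalize the pair unless the better model's score exceeds the worse
    model's score by at least `margin`. Returns the mean over ordered pairs.
    """
    scores = [score(dataset_emb, m) for m in model_embs]
    total, pairs = 0.0, 0
    for i, j in combinations(range(len(model_embs)), 2):
        if accuracies[i] == accuracies[j]:
            continue  # ties carry no ordering information
        better, worse = (i, j) if accuracies[i] > accuracies[j] else (j, i)
        total += max(0.0, margin - (scores[better] - scores[worse]))
        pairs += 1
    return total / pairs if pairs else 0.0


# Hypothetical toy data: one 2-D dataset embedding and three model embeddings.
dataset_emb = [1.0, 0.0]
model_embs = [[0.9, 0.1], [0.5, 0.5], [0.1, 0.9]]
accuracies = [0.92, 0.85, 0.70]  # made-up accuracies whose order matches the scores

print(pairwise_rank_loss(dataset_emb, model_embs, accuracies))  # → 0.0
```

Because the made-up accuracies already agree with the score ordering by more than the margin, the loss is zero. Reversing the accuracies makes every pair violate the ordering and produces a positive loss. A hinge on score differences is only one common choice; listwise or softmax-based ranking losses are alternatives.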
Abstract
Deep learning (DL) has achieved remarkable progress in the field of medical imaging. However, adapting DL models to medical tasks remains a significant challenge, primarily due to two key factors: (1) architecture selection, as different tasks require specialized model designs, and (2) weight initialization, which directly affects the convergence speed and final performance of the models. Although transfer learning from ImageNet is a widely adopted strategy, its effectiveness is limited by the substantial differences between natural and medical images. To address these challenges, we introduce Medical Neural Network Search (MedNNS), the first Neural Network Search framework for medical imaging applications. MedNNS jointly optimizes architecture selection and weight initialization by constructing a meta-space that encodes datasets and models according to how well they perform together. We build this space using a Supernetwork-based approach, expanding the model zoo 51-fold relative to previous state-of-the-art (SOTA) methods. Moreover, we introduce a rank loss and a Fréchet Inception Distance (FID) loss into the construction of the space to capture inter-model and inter-dataset relationships, thereby achieving more accurate alignment in the meta-space. Experimental results across multiple datasets demonstrate that MedNNS significantly outperforms both ImageNet-pretrained DL models and SOTA Neural Architecture Search (NAS) methods, achieving an average accuracy improvement of 1.7% across datasets while converging substantially faster. The code and the processed meta-space are available at https://github.com/BioMedIA-MBZUAI/MedNNS.
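The FID loss builds on the standard Fréchet distance between two Gaussians fitted to feature sets: ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^{1/2}). The sketch below computes that standard quantity with NumPy. It assumes the inputs are already-extracted feature matrices with one sample per row. Which feature extractor MedNNS uses, and how it turns this distance into a loss term over dataset embeddings, are not described here. The function name `frechet_distance` and the random stand-in features are hypothetical.

```python
import numpy as np


def frechet_distance(feats_a, feats_b):
    """Standard Fréchet distance between Gaussians fitted to two feature sets.

    feats_a, feats_b: arrays of shape (n_samples, n_features).
    Returns ||mu_a - mu_b||^2 + Tr(S_a + S_b - 2 (S_a S_b)^{1/2}).
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    s_a = np.atleast_2d(np.cov(feats_a, rowvar=False))
    s_b = np.atleast_2d(np.cov(feats_b, rowvar=False))

    # Tr((S_a S_b)^{1/2}) equals the sum of square roots of the eigenvalues of
    # S_a^{1/2} S_b S_a^{1/2}. That matrix is symmetric PSD, so eigh/eigvalsh
    # apply and no general matrix square root is needed.
    w, v = np.linalg.eigh(s_a)
    sqrt_a = (v * np.sqrt(np.clip(w, 0.0, None))) @ v.T
    inner = sqrt_a @ s_b @ sqrt_a
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(inner), 0.0, None)).sum()

    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(s_a) + np.trace(s_b) - 2.0 * tr_covmean)


# Random stand-in features (not Inception features, not medical data).
rng = np.random.default_rng(0)
a = rng.normal(0.0, 1.0, size=(500, 4))
print(frechet_distance(a, a))        # close to 0 for identical feature sets
print(frechet_distance(a, a + 0.5))  # close to 1.0: same covariance, mean shifted by 0.5 in 4 dims
```

Shifting every feature by a constant leaves the covariance unchanged, so only the mean term contributes. That is why the second call lands near 4 × 0.5² = 1.0.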
