MedNNS: Supernet-based Medical Task-Adaptive Neural Network Search

Lotfi Abdelkrim Mecharbat, Ibrahim Almakky, Martin Takac, Mohammad Yaqub

TL;DR

The paper tackles adapting deep learning (DL) to medical imaging by jointly optimizing architecture and initialization, addressing the limitations of ImageNet pretraining due to domain shift. It introduces MedNNS, a Supernet-based medical neural network search that builds a large meta-space from per-dataset Supernet subnetworks, encodes architectures and datasets into embeddings, and uses a composite loss with rank and FID terms to align task-specific representations. The authors demonstrate that this approach yields an average accuracy gain of about 1.7% across MedMNIST datasets and substantially faster convergence compared with both ImageNet-pretrained baselines and existing NAS methods. While effective across diverse tasks, generalization gaps remain for highly dissimilar datasets, motivating future work to broaden dataset coverage and incorporate hardware constraints to enhance applicability.

Abstract

Deep learning (DL) has achieved remarkable progress in the field of medical imaging. However, adapting DL models to medical tasks remains a significant challenge, primarily due to two key factors: (1) architecture selection, as different tasks necessitate specialized model designs, and (2) weight initialization, which directly impacts the convergence speed and final performance of the models. Although transfer learning from ImageNet is a widely adopted strategy, its effectiveness is constrained by the substantial differences between natural and medical images. To address these challenges, we introduce Medical Neural Network Search (MedNNS), the first Neural Network Search framework for medical imaging applications. MedNNS jointly optimizes architecture selection and weight initialization by constructing a meta-space that encodes datasets and models based on how well they perform together. We build this space using a Supernetwork-based approach, expanding the model zoo size by 51x over previous state-of-the-art (SOTA) methods. Moreover, we introduce rank loss and Fréchet Inception Distance (FID) loss into the construction of the space to capture inter-model and inter-dataset relationships, thereby achieving more accurate alignment in the meta-space. Experimental results across multiple datasets demonstrate that MedNNS significantly outperforms both ImageNet pre-trained DL models and SOTA Neural Architecture Search (NAS) methods, achieving an average accuracy improvement of 1.7% across datasets while converging substantially faster. The code and the processed meta-space are available at https://github.com/BioMedIA-MBZUAI/MedNNS.
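
The FID mentioned above compares two datasets through the Gaussian statistics of their image features. The sketch below is a minimal illustration of the standard Fréchet distance formula, not the authors' implementation: it assumes the features have already been extracted (FID normally uses Inception activations) and are passed in as NumPy arrays of shape (samples, dimensions). The function names `frechet_distance` and `_sqrtm_psd` are hypothetical.

```python
import numpy as np


def _sqrtm_psd(mat):
    """Square root of a symmetric positive semi-definite matrix."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # clamp tiny negative eigenvalues from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(feats_a, feats_b):
    """Frechet distance between Gaussians fitted to two feature sets.

    Assumption: feats_a and feats_b are pre-extracted feature arrays of
    shape (n_samples, n_dims); feature extraction itself is not shown.
    """
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.atleast_2d(np.cov(feats_a, rowvar=False))
    cov_b = np.atleast_2d(np.cov(feats_b, rowvar=False))

    # Tr((cov_a cov_b)^(1/2)) equals Tr((cov_a^(1/2) cov_b cov_a^(1/2))^(1/2)),
    # and the second form only needs square roots of symmetric matrices.
    root_a = _sqrtm_psd(cov_a)
    inner = root_a @ cov_b @ root_a
    inner = (inner + inner.T) / 2.0  # re-symmetrize against round-off
    cross_trace = np.trace(_sqrtm_psd(inner))

    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a) + np.trace(cov_b) - 2.0 * cross_trace)


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    feats = rng.normal(size=(200, 4))
    # Same features: distance is (numerically) zero.
    print(frechet_distance(feats, feats))
    # Shifting every feature by 1 leaves the covariance unchanged, so the
    # distance reduces to the squared mean shift.
    print(frechet_distance(feats, feats + 1.0))
```

A small distance indicates that two datasets have similar feature statistics, which is the sense in which the abstract uses FID to capture inter-dataset relationships.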

Paper Structure

This paper contains 6 sections, 5 equations, 3 figures, 2 tables.

Figures (3)

  • Figure 1: (Left) Heatmap showing pairwise FID distance between MedMNIST datasets and ImageNet. (Right) Training curves for ResNet-18 using random, ImageNet, and nearest-FID dataset pretraining.
  • Figure 2: Overview of our MedNNS framework. (A) During training: (A.1) A large model zoo is built by training a single Supernetwork per dataset and extracting thousands of subnetworks via weight sharing. (A.2) Models and datasets are embedded into the latent space. (A.3) The meta-space is optimized using a combination of rank loss, FID loss, and performance loss to align models and datasets according to their relative performance. (B) During inference, given an unseen dataset, its embedding is computed and used to query the meta-space, selecting the closest model embedding as the most suitable pre-trained model.
  • Figure 3: t-SNE visualization of dataset (query) and model embeddings in the meta-space. The central plot shows the MedNNS meta-space. The zoomed-in side plots provide a closer view of specific regions, illustrating the spatial arrangement of models, colored according to their true accuracy, around their corresponding datasets.
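
The inference step described in Figure 2 (B), where an unseen dataset's embedding is used to query the meta-space for the closest model embedding, can be sketched as a nearest-neighbor lookup. This is an illustrative sketch only: the use of Euclidean distance, the function name `select_closest_model`, and the toy embeddings and model identifiers are all assumptions, since the captions above do not state the distance measure or the embedding dimensionality.

```python
import numpy as np


def select_closest_model(dataset_embedding, model_embeddings, model_ids):
    """Return the id of the model whose embedding is nearest to the query.

    Assumption: closeness is measured with Euclidean distance; the actual
    metric used by MedNNS is not specified in this section.
    """
    dists = np.linalg.norm(model_embeddings - dataset_embedding, axis=1)
    best = int(np.argmin(dists))
    return model_ids[best], float(dists[best])


if __name__ == "__main__":
    # Toy 2-D meta-space with three hypothetical pre-trained subnetworks.
    model_embeddings = np.array([[0.0, 0.0], [1.0, 1.0], [3.0, 3.0]])
    model_ids = ["model_a", "model_b", "model_c"]
    query = np.array([0.9, 1.2])  # embedding of a hypothetical unseen dataset

    chosen, distance = select_closest_model(query, model_embeddings, model_ids)
    print(chosen, distance)
```

With a model zoo of many thousands of subnetworks, a brute-force scan like this may still be adequate, since it is a single vectorized distance computation per query.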