Pretrained ViTs Yield Versatile Representations For Medical Images

Christos Matsoukas, Johan Fredin Haslum, Moein Sorkhei, Magnus Söderberg, Kevin Smith

TL;DR

While CNNs perform better when trained from scratch, off-the-shelf vision transformers can perform on par with CNNs when pretrained on ImageNet, in both supervised and self-supervised settings, making them a viable alternative to CNNs.

Abstract

Convolutional Neural Networks (CNNs) have reigned for a decade as the de facto approach to automated medical image diagnosis, pushing the state of the art in classification, detection, and segmentation tasks. In recent years, vision transformers (ViTs) have emerged as a competitive alternative to CNNs, yielding impressive levels of performance in the natural image domain while possessing several interesting properties that could prove beneficial for medical imaging tasks. In this work, we explore the benefits and drawbacks of transformer-based models for medical image classification. We conduct a series of experiments on several standard 2D medical image benchmark datasets and tasks. Our findings show that, while CNNs perform better when trained from scratch, off-the-shelf vision transformers can perform on par with CNNs when pretrained on ImageNet, in both supervised and self-supervised settings, making them a viable alternative to CNNs.

Paper Structure

This paper contains 25 sections, 5 figures, 5 tables.

Figures (5)

  • Figure 1: Performance comparison of ResNet50 and DeiT-S, two commonly used CNN-based and ViT-based architectures. The comparison covers several standard medical image classification datasets and different types of initialization, including random initialization, ImageNet pretraining, and self-supervision using DINO (Caron et al., 2021). Performance is measured after fine-tuning on the dataset, as well as using $k$-NN evaluation without fine-tuning. We report the median over 5 repetitions; error bars represent the standard deviation. Numeric values appear in Table \ref{tab:knn-eval} (Appendix \ref{apx:knn}).
  • Figure 2: Medical image segmentation results comparing DeepLab3-ResNet50 (blue) and DeepLab3-DeiT-S (red). The ground-truth mask appears in yellow. Note that the ViT segmentations tend to do a better job of segmenting distant regions.
  • Figure 3: Impact of model capacity on performance for the ResNet and DeiT families on standard medical image classification datasets. Both model types seem to perform better with increasing capacity and scale roughly similarly. Numeric results appear in Table \ref{tab:capacity} of Appendix \ref{apx:capacity}.
  • Figure 4: Comparing saliency for ResNet50 (2$^{nd}$ row) and DeiT-S (3$^{rd}$ row) on medical classification. Each column contains the original image, a Grad-CAM visualization for ResNet50 (Selvaraju et al., 2017), and the top-$50\%$ attention map of the cls token of DeiT-S.
  • Figure 5: Mean attention distance with respect to the attention head and the network depth. Each point is the average over 512 test samples of the mean of the element-wise product of each query token's attention and its distance from the other tokens.
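
The mean attention distance described in the Figure 5 caption can be illustrated with a small NumPy sketch. This is not the authors' code: the function names, the square patch grid, and the pixel units are assumptions made for illustration. It also assumes the cls token has already been dropped and that each query's attention weights sum to 1 over the patch tokens, so that the per-query value is an attention-weighted distance (a common convention, which may differ from the exact normalization used in the paper).

```python
import numpy as np

def patch_distance_matrix(grid_size, patch_size):
    """Pairwise Euclidean distances (in pixels) between patches on a square grid."""
    ys, xs = np.meshgrid(np.arange(grid_size), np.arange(grid_size), indexing="ij")
    # (N, 2) patch coordinates in pixels, with N = grid_size ** 2
    coords = np.stack([ys.ravel(), xs.ravel()], axis=1) * patch_size
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))  # (N, N)

def mean_attention_distance(attn, grid_size, patch_size):
    """Mean attention distance per head for one image and one layer.

    attn: array of shape (heads, N, N); row i of each head holds query i's
          attention over the N patch tokens (assumed to sum to 1).
    Returns an array of shape (heads,).
    """
    dist = patch_distance_matrix(grid_size, patch_size)
    # Element-wise product of attention and distance, summed over keys,
    # gives one attention-weighted distance per query token.
    per_query = (attn * dist[None]).sum(axis=-1)  # (heads, N)
    # Average over query tokens to get one value per head.
    return per_query.mean(axis=-1)

# Hypothetical usage: random attention for 6 heads on a 14x14 grid of 16-pixel patches.
rng = np.random.default_rng(0)
scores = rng.random((6, 196, 196))
attn = scores / scores.sum(axis=-1, keepdims=True)  # normalize rows to sum to 1
per_head = mean_attention_distance(attn, grid_size=14, patch_size=16)
```

To obtain one point of the kind plotted in the figure, the per-head values would then be averaged over the test samples (512 in the caption) for each layer.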