Rethinking Transfer Learning for Medical Image Classification

Le Peng, Hengyue Liang, Gaoxiang Luo, Taihui Li, Ju Sun

TL;DR

The paper addresses the suboptimality of full transfer learning in medical image classification (MIC) when data are limited. It proposes TruncatedTL (TTL), a simple cutoff-based approach that reuses bottom pretrained layers and discards the top layers, accompanied by a two-stage hierarchical search to identify effective truncation points and SVCCA-based transferability analysis. TTL consistently matches or surpasses existing differential TL methods (LWFT, TF) while yielding compact, faster models across 2D and 3D MIC tasks, with insights into feature reuse and when top layers are unnecessary. This work offers a practical, scalable TL strategy for MIC that reduces inference costs without sacrificing accuracy, supporting broader deployment in resource-constrained clinical settings.
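The two-stage hierarchical search mentioned above can be sketched roughly as follows. This is only an illustration under stated assumptions, not the authors' procedure: it assumes a coarse stage that scores one cutoff per block, followed by a fine stage that scores every cutoff in the winning block and the block after it. The `evaluate` callback, the block grouping, and the toy score are all hypothetical.

```python
from typing import Callable, Dict, List, Sequence, Tuple


def two_stage_search(
    blocks: Sequence[Sequence[int]],
    evaluate: Callable[[int], float],
) -> Tuple[int, int]:
    """Return (best cutoff, number of cutoffs evaluated).

    blocks: candidate cutoff indices grouped by coarse block, in depth order.
    evaluate: hypothetical callback that finetunes a model truncated at the
        given cutoff and returns a validation score (higher is better).
    """
    cache: Dict[int, float] = {}

    def score(cut: int) -> float:
        # Each cutoff is evaluated at most once.
        if cut not in cache:
            cache[cut] = evaluate(cut)
        return cache[cut]

    # Stage 1 (coarse): score only the last cutoff of each block.
    best_block = max(range(len(blocks)), key=lambda b: score(blocks[b][-1]))

    # Stage 2 (fine, assumed): score every cutoff in the winning block
    # and in the block that follows it, if there is one.
    fine: List[int] = list(blocks[best_block])
    if best_block + 1 < len(blocks):
        fine += list(blocks[best_block + 1])
    best_cut = max(fine, key=score)
    return best_cut, len(cache)


# Toy setup: 16 cutoffs grouped into 4 blocks, with a made-up score
# that peaks at cutoff 9. A real score would come from finetuning.
blocks = [[1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11, 12, 13], [14, 15, 16]]
best_cut, n_evals = two_stage_search(blocks, lambda c: -(c - 9) ** 2)
```

In this toy run the search evaluates 12 of the 16 cutoffs. The saving grows when blocks are many and each evaluation requires a full finetuning run.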

Abstract

Transfer learning (TL) from pretrained deep models is standard practice in modern medical image classification (MIC). However, which levels of features should be reused is problem-dependent, and uniformly finetuning all layers of a pretrained model may be suboptimal. This insight has partly motivated recent differential TL strategies, such as TransFusion (TF) and layer-wise finetuning (LWFT), which treat the layers of the pretrained model differentially. In this paper, we add one more strategy to this family, called TruncatedTL, which reuses and finetunes appropriate bottom layers and directly discards the remaining layers. Compared to other differential TL methods, this yields not only superior MIC performance but also compact models for efficient inference. Our code is available at: https://github.com/sun-umn/TTL
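To make the truncation idea concrete, here is a minimal, framework-free sketch. It assumes the standard ResNet50 layout of 3, 4, 6, and 3 bottleneck blocks per stage and treats the stem plus each residual block as one unit, which is one plausible way to arrive at the 17 cutoff points mentioned in the Figure 3 caption. The unit names, the `truncate` helper, and the `new_classifier` head are hypothetical and are not taken from the authors' code.

```python
from typing import List

# Bottleneck blocks per stage in the standard ResNet50 layout.
RESNET50_STAGES = [3, 4, 6, 3]


def cutoff_units(stages: List[int]) -> List[str]:
    """List units in depth order: the stem, then every residual block."""
    units = ["stem"]
    for s, n_blocks in enumerate(stages, start=1):
        units += [f"stage{s}.block{b}" for b in range(n_blocks)]
    return units


def truncate(units: List[str], cut: int, head: str = "new_classifier") -> List[str]:
    """Keep the first `cut` units, discard the rest, attach a fresh head.

    Cutting only between units means no skip connection is split.
    """
    if not 1 <= cut <= len(units):
        raise ValueError("cut must be in 1..len(units)")
    return units[:cut] + [head]


units = cutoff_units(RESNET50_STAGES)   # stem + 16 blocks
truncated = truncate(units, cut=8)      # keep stem, stage1, and stage2
```

In a real model the kept units would carry pretrained weights and be finetuned, while the head would be trained from scratch on the target task.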

Paper Structure

This paper contains 19 sections, 2 equations, 9 figures, 4 tables.

Figures (9)

  • Figure 1: (left) The feature hierarchy learned by typical DCNNs, see Appendix E for details; (right) Examples of diseases in a chest X-ray [NguyenEtAl2021VinBigData].
  • Figure 2: TL strategies and their usage scenarios
  • Figure 3: (i) Overview of the typical TL setup and the four TL methods that we focus on in this paper. (ii) Illustration of feature transferability and the performance of different levels of features on BIMCV. We take the ResNet50 model pretrained on ImageNet and perform full TL on BIMCV. We consider $17$ natural truncation/cutoff points that do not cut through the skip connections.
  • Figure 4: SVCCA on the COVID-19 diagnosis task. The bold curve shows the CCA coefficients for the learned features, while the light curve shows the coefficients for two uncorrelated random features, so the area between the two curves is a quantitative measure of the correlation between the said blocks. We normalize the indices of the CCA coefficients to $[0, 1]$. We compare (i) features from the trained and finetuned models at the same layer, and (ii) features from different blocks of the finetuned model.
  • Figure 5: Similar to TTL on ResNet/DenseNet, we identify the block structure in U-Net and truncate at the intersection between blocks. The image of U-Net is adapted from [ronneberger2015u].
  • ...and 4 more figures
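
The Figure 4 caption describes the area between the learned-feature curve and the random-feature curve as a quantitative measure of correlation, with coefficient indices normalized to $[0, 1]$. A minimal sketch of that area computation is given below. It assumes the CCA coefficients have already been computed by some SVCCA routine and uses the trapezoidal rule. The function names and the sample values are illustrative only and are not results from the paper.

```python
from typing import Sequence


def area_under(coeffs: Sequence[float]) -> float:
    """Trapezoidal area under a curve whose indices are rescaled to [0, 1]."""
    n = len(coeffs)
    if n < 2:
        raise ValueError("need at least two coefficients")
    step = 1.0 / (n - 1)  # spacing after normalizing indices to [0, 1]
    return sum((coeffs[i] + coeffs[i + 1]) / 2 * step for i in range(n - 1))


def similarity_gap(learned: Sequence[float], baseline: Sequence[float]) -> float:
    """Area between the learned-feature curve and the random-feature curve."""
    return area_under(learned) - area_under(baseline)


# Made-up coefficient curves, sorted in decreasing order as CCA
# coefficients usually are. The two curves may differ in length because
# each is rescaled to [0, 1] separately.
learned = [1.0, 0.8, 0.6, 0.4]
baseline = [0.4, 0.3, 0.2, 0.1]
gap = similarity_gap(learned, baseline)
```

A larger gap means the two sets of features are more strongly correlated than random features would be, while a gap near zero suggests little shared structure.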