
Fine-tuning a Multiple Instance Learning Feature Extractor with Masked Context Modelling and Knowledge Distillation

Juan I. Pisula, Katarzyna Bozek

TL;DR

A single epoch of the proposed fine-tuning task suffices to improve the downstream performance of the feature extractor when used in a MIL scenario. The fine-tuned model can even outperform the teacher model downstream while being considerably smaller and requiring a fraction of its compute.

Abstract

The first step in Multiple Instance Learning (MIL) algorithms for Whole Slide Image (WSI) classification consists of tiling the input image into smaller patches and computing their feature vectors with a pre-trained feature extractor model. Feature extractor models pre-trained with supervision on ImageNet have proven to transfer well to this domain; however, this pre-training task does not take into account that visual information in neighboring patches is highly correlated. Based on this observation, we propose to improve downstream MIL classification by fine-tuning the feature extractor model using Masked Context Modelling with Knowledge Distillation. In this task, the feature extractor model is fine-tuned by predicting masked patches within a larger context window. Since reconstructing the input image would require a powerful image generation model, and our goal is not to generate realistic-looking image patches, we instead predict the feature vectors produced by a larger teacher network. A single epoch of the proposed task suffices to improve the downstream performance of the feature extractor when used in a MIL scenario. The fine-tuned model can even outperform the teacher model downstream while being considerably smaller and requiring a fraction of its compute.
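
The fine-tuning objective described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the authors' implementation. The Transformer encoder and predictor network are replaced by a hypothetical `mean_context_predictor` that fills each masked position with the mean of the visible feature vectors. Masked tokens are set to zero, student and teacher features are assumed to share one dimension, and the L1 loss follows the Figure 2 caption. Only the loss is computed; no gradients or model weights are involved, and all function names are illustrative.

```python
import random
from typing import Callable, List, Sequence

Vector = List[float]


def l1(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean absolute difference between two equal-length vectors."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def mean_context_predictor(tokens: List[Vector], masked: List[bool]) -> List[Vector]:
    """Hypothetical stand-in for the Transformer encoder + predictor network:
    each masked position is predicted as the mean of the visible tokens."""
    visible = [t for t, m in zip(tokens, masked) if not m]
    dim = len(tokens[0])
    mean = [sum(t[i] for t in visible) / len(visible) for i in range(dim)]
    return [list(mean) if m else list(t) for t, m in zip(tokens, masked)]


def masked_context_kd_loss(
    student_feats: List[Vector],
    teacher_feats: List[Vector],
    mask_ratio: float,
    predictor: Callable[[List[Vector], List[bool]], List[Vector]],
    rng: random.Random,
) -> float:
    """Average L1 loss between predicted and teacher features over masked patches only.

    student_feats / teacher_feats: one vector per patch of a single context window
    (assumed here to have the same dimension).
    """
    n = len(student_feats)
    if n < 2 or len(teacher_feats) != n:
        raise ValueError("need >= 2 patches and one teacher vector per patch")
    # Mask a random subset, keeping at least one patch visible and one masked.
    n_masked = max(1, min(n - 1, round(mask_ratio * n)))
    masked_idx = set(rng.sample(range(n), n_masked))
    masked = [i in masked_idx for i in range(n)]
    # Replace masked student features with a zero "mask token" (assumption).
    dim = len(student_feats[0])
    tokens = [[0.0] * dim if m else list(f) for f, m in zip(student_feats, masked)]
    preds = predictor(tokens, masked)
    # The loss is evaluated only at masked positions, against the frozen teacher's features.
    losses = [l1(p, t) for p, t, m in zip(preds, teacher_feats, masked) if m]
    return sum(losses) / len(losses)


if __name__ == "__main__":
    student = [[1.0, 1.0]] * 4  # toy 2-D features for a 4-patch context window
    teacher = [[5.0, 5.0]] * 4
    print(masked_context_kd_loss(student, teacher, 0.5, mean_context_predictor, random.Random(0)))
```

Restricting the loss to masked positions is what forces the model to use neighboring patches: the prediction for a hidden patch can only come from the visible context.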

Paper Structure (14 sections, 1 equation, 3 figures, 5 tables)

Figures (3)

  • Figure 1: a) A cutout of a breast carcinoma H&E slide, in which a highlighted image patch shows a cluster of cells. Inspecting its neighborhood reveals that this cluster is not an isolated pattern but part of a mammary lobe. Masked Context Modelling with Knowledge Distillation aims to improve downstream performance by including context information in the feature extraction step. b) Comparison (number of parameters, FLOPs per forward pass, and AUROC on the downstream MIL classification task) of ImageNet pre-trained feature extraction models: EfficientNetV2-L, ResNet18, and ResNet18 fine-tuned with our method using EfficientNetV2-L as the teacher. CLAM was used as the MIL classification model, and performance is shown relative to the EfficientNetV2-L model.
  • Figure 2: Proposed pipeline. During the feature extractor fine-tuning stage (left), a pre-trained feature extractor is fed image patches from a larger context window. A random subset of the patches' feature vectors is masked, and a Transformer encoder with a predictor network predicts the feature vectors that a frozen teacher network produces for the masked instances, minimizing an L1 loss. In the downstream training stage (right), the Transformer and predictor networks are discarded, and the fine-tuned feature extractor can be used in any Multiple Instance Learning pipeline.
  • Figure 3: Heatmap visualizations of the cosine similarity between the feature vector of a patch from a metastasis region and all other feature vectors in the slide.
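
The Figure 3 heatmaps can be approximated with a short sketch. It assumes each patch's feature vector is stored under its (row, column) position in the slide's patch grid. The names `similarity_heatmap` and `cosine_similarity` are illustrative, and the zero-vector convention is an assumption. Rendering the grid as an image is left to any plotting library.

```python
import math
from typing import Dict, List, Sequence, Tuple

Coord = Tuple[int, int]  # (row, column) index of a patch in the slide grid (assumption)


def cosine_similarity(a: Sequence[float], b: Sequence[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        return 0.0  # assumption: a zero vector is treated as dissimilar to everything
    return dot / (norm_a * norm_b)


def similarity_heatmap(features: Dict[Coord, List[float]], query: Coord) -> List[List[float]]:
    """Grid of cosine similarities to the query patch; grid cells without a patch
    (e.g. background that was not tiled) are NaN."""
    rows = max(r for r, _ in features) + 1
    cols = max(c for _, c in features) + 1
    grid = [[math.nan] * cols for _ in range(rows)]
    q = features[query]
    for (r, c), f in features.items():
        grid[r][c] = cosine_similarity(q, f)
    return grid


if __name__ == "__main__":
    feats = {(0, 0): [1.0, 0.0], (0, 1): [0.0, 1.0], (1, 0): [1.0, 1.0]}
    for row in similarity_heatmap(feats, query=(0, 0)):
        print(row)
```

The query patch always scores 1.0 against itself, so in a visualization it marks the reference location.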