Learnable Prompting SAM-induced Knowledge Distillation for Semi-supervised Medical Image Segmentation

Kaiwen Huang, Tao Zhou, Huazhu Fu, Yizhe Zhang, Yi Zhou, Chen Gong, Dong Liang

TL;DR

KnowSAM tackles the challenge of limited labeled data in medical image segmentation by distilling SAM's generalization ability through a framework that combines learnable prompting with multi-view co-training. The approach pairs two independently updated subnets with a Hybrid Aggregation Module, a Learnable Prompt Strategy that generates dense prompts, and SAM-induced Knowledge Distillation that transfers guidance from SAM to the subnets, complemented by Uncertainty-Guided Data Augmentation. Empirical results on colonoscopy, ultrasound, ISIC-2018, ACDC, and BCSS datasets demonstrate state-of-the-art performance and robustness, with ablations confirming the effectiveness of each component. The framework is adaptable and can enhance other semi-supervised segmentation methods, offering practical gains for resource-constrained clinical imaging tasks.

Abstract

The limited availability of labeled data has driven advances in semi-supervised learning for medical image segmentation. Modern large-scale models designed for general-purpose segmentation, such as the Segment Anything Model (SAM), have demonstrated strong generalization capabilities. However, applying these models directly to medical image segmentation still leads to performance degradation. In this paper, we propose a learnable prompting SAM-induced knowledge distillation framework (KnowSAM) for semi-supervised medical image segmentation. First, we propose a Multi-view Co-training (MC) strategy in which two distinct sub-networks follow a co-teaching paradigm, resulting in more robust outcomes. Second, we present a Learnable Prompt Strategy (LPS) that dynamically produces dense prompts, and we integrate an adapter to fine-tune SAM specifically for medical image segmentation tasks. Moreover, we propose SAM-induced Knowledge Distillation (SKD) to transfer useful knowledge from SAM to the two sub-networks, enabling them to learn from SAM's predictions and alleviate the effects of incorrect pseudo-labels during training. Notably, the predictions generated by our subnets are used to produce mask prompts for SAM, facilitating effective information exchange between modules. Extensive experimental results on various medical segmentation tasks demonstrate that our model outperforms state-of-the-art semi-supervised segmentation approaches. Crucially, our SAM distillation framework can be seamlessly integrated into other semi-supervised segmentation methods to enhance their performance. The code will be released upon acceptance of this manuscript at: https://github.com/taozh2017/KnowSAM
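To make the interplay among MC, SKD, and the subnet-generated mask prompts more concrete, here is a minimal NumPy sketch under stated assumptions. The abstract does not specify the loss forms, so this sketch replaces the Hybrid Aggregation Module with a plain average, models SAM-induced distillation as a pixel-wise mean squared error between softmax maps, models co-teaching as cross pseudo-supervision with argmax labels, and takes the mask prompt to be the aggregated foreground probability. All function names (`aggregate`, `skd_loss`, `cross_pseudo_loss`, `mask_prompt`) are hypothetical and are not taken from the released code.

```python
import numpy as np


def softmax(logits, axis=0):
    # Numerically stable softmax over the class axis (axis 0 of a (C, H, W) map).
    z = logits - logits.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def aggregate(prob_a, prob_b):
    # Hypothetical stand-in for the Hybrid Aggregation Module: a plain average of both subnets.
    return 0.5 * (prob_a + prob_b)


def distillation_loss(student_prob, teacher_prob):
    # Assumed SKD form: pixel-wise MSE between softmax maps; SAM's map acts as a fixed target.
    return float(np.mean((student_prob - teacher_prob) ** 2))


def skd_loss(logits_a, logits_b, logits_sam):
    # Both subnets are pulled toward SAM's prediction.
    p_sam = softmax(logits_sam)
    return distillation_loss(softmax(logits_a), p_sam) + distillation_loss(softmax(logits_b), p_sam)


def cross_pseudo_loss(logits_student, prob_other, eps=1e-8):
    # Assumed co-teaching form: the other subnet's argmax serves as a hard pseudo-label,
    # and the student is scored with pixel-wise cross-entropy against it.
    pseudo = prob_other.argmax(axis=0)                       # (H, W) class indices
    log_p = np.log(softmax(logits_student) + eps)            # (C, H, W) log-probabilities
    picked = np.take_along_axis(log_p, pseudo[None], axis=0)[0]
    return float(-picked.mean())


def mask_prompt(prob_f):
    # Hypothetical mask prompt for SAM: the aggregated foreground probability (class 1).
    return prob_f[1]


# Toy usage with random logits for a binary (background/foreground) 8x8 map.
rng = np.random.default_rng(0)
C, H, W = 2, 8, 8
la, lb, ls = (rng.normal(size=(C, H, W)) for _ in range(3))
total = skd_loss(la, lb, ls) + cross_pseudo_loss(la, softmax(lb)) + cross_pseudo_loss(lb, softmax(la))
prompt = mask_prompt(aggregate(softmax(la), softmax(lb)))
print(prompt.shape)  # (8, 8)
```

In an actual training loop these terms would be computed in an autograd framework and weighted against the supervised loss on labeled data; the weights and exact loss forms are not given in this section.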

Paper Structure

This paper contains 34 sections, 14 equations, 7 figures, 9 tables.

Figures (7)

  • Figure 1: A comparison of different semi-supervised frameworks: (a) MT framework with a student model and a teacher model, (b) co-training framework with two subnets, (c) our proposed semi-supervised framework without SAM, which enhances the dual-stream network with Multi-view Co-training (MC), and (d) our model with Knowledge Distillation (KD) from SAM.
  • Figure 2: Overview of our KnowSAM framework. An input image is first processed by two distinct subnets, $\mathcal{F}_A$ and $\mathcal{F}_B$, to obtain $\hat{Y}_a$ and $\hat{Y}_b$, which are then fed into a hybrid aggregation module to produce the composite map $\hat{Y}_f$. Concurrently, the input image is processed by the SAM encoder to extract feature embeddings, which are refined by $\psi(\cdot)$ to produce the learnable feature prompt. The two types of prompts ($\hat{Y}_f$ and the learnable feature prompt) are provided to the SAM decoder to produce predictions, which serve as the basis for knowledge distillation. Furthermore, we leverage an uncertainty-guided data augmentation approach to generate new training samples that enhance robustness.
  • Figure 3: Pipeline of the proposed UGDA strategy. (a) Labeled data undergo weak augmentation with ground-truth labels, while unlabeled data receive strong augmentation with pseudo labels. (b) An uncertainty map is obtained from the aggregated prediction $\hat{Y}_f$, and the top five regions with the highest uncertainty are identified. (c) Interactive bidirectional copy-paste is employed based on the uncertainty map (a single region is depicted here for demonstration purposes). (d) Mixed images with labels are used as additional training samples.
  • Figure 4: Visual results of different methods on five segmentation tasks. The comparison methods include MT and BCP, which utilize teacher-student architectures, while DTC, MC+, MCF, CDMA, CauSSL, and our method are based on consistency learning architectures. The first two rows correspond to endoscopic images, while the remaining rows represent results on the ISIC-2018, ACDC, ultrasound, and pathological image datasets, respectively.
  • Figure 5: Visualization results of progressively adding different view information without SAM distillation using $30\%$ labeled data.
  • ...and 2 more figures
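The UGDA pipeline described in the Figure 3 caption can be sketched in NumPy as follows. This is an illustrative sketch, not the authors' implementation. It assumes pixel-wise entropy as the uncertainty measure, computes that entropy from the unlabeled image's aggregated prediction $\hat{Y}_f$, and ranks the cells of a non-overlapping square grid (the caption does not say how the regions are shaped). It also omits the weak and strong augmentations of step (a). The function names and the `patch` and `k` parameters are hypothetical.

```python
import numpy as np


def entropy_map(prob, eps=1e-8):
    # Assumed uncertainty measure: pixel-wise entropy of Y_f, (C, H, W) -> (H, W).
    return -(prob * np.log(prob + eps)).sum(axis=0)


def top_uncertain_patches(unc, patch=4, k=5):
    # Assumption: regions are cells of a non-overlapping grid, ranked by mean uncertainty.
    h, w = unc.shape
    scores = []
    for y in range(0, h - patch + 1, patch):
        for x in range(0, w - patch + 1, patch):
            scores.append((unc[y:y + patch, x:x + patch].mean(), y, x))
    scores.sort(key=lambda s: s[0], reverse=True)
    return [(y, x) for _, y, x in scores[:k]]


def bidirectional_copy_paste(img_l, lab_l, img_u, pl_u, boxes, patch=4):
    # Swap each selected region between the labeled and unlabeled samples, carrying the
    # ground-truth labels and pseudo-labels along with the pixels.
    mix_l, mix_lab_l = img_l.copy(), lab_l.copy()
    mix_u, mix_pl_u = img_u.copy(), pl_u.copy()
    for y, x in boxes:
        sl = (slice(y, y + patch), slice(x, x + patch))
        mix_l[sl], mix_u[sl] = img_u[sl], img_l[sl]
        mix_lab_l[sl], mix_pl_u[sl] = pl_u[sl], lab_l[sl]
    return (mix_l, mix_lab_l), (mix_u, mix_pl_u)


# Toy usage on 16x16 grayscale images with a binary task.
rng = np.random.default_rng(0)
H = W = 16
img_l, img_u = rng.random((H, W)), rng.random((H, W))
lab_l = (img_l > 0.5).astype(np.int64)                                     # toy ground truth
logits = rng.normal(size=(2, H, W))
prob_f = np.exp(logits) / np.exp(logits).sum(axis=0, keepdims=True)        # toy Y_f (unlabeled image)
pl_u = prob_f.argmax(axis=0)                                               # pseudo-labels
boxes = top_uncertain_patches(entropy_map(prob_f), patch=4, k=5)
(mix_l, y_l), (mix_u, y_u) = bidirectional_copy_paste(img_l, lab_l, img_u, pl_u, boxes, patch=4)
print(len(boxes))  # 5
```

Pasting in both directions means that each mixed sample contains high-uncertainty content from the other sample together with its label, which matches the "additional training samples" in step (d) of the caption.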