
CoSAM: Self-Correcting SAM for Domain Generalization in 2D Medical Image Segmentation

Yihang Fu, Ziyang Chen, Yiwen Ye, Xingliang Lei, Zhisong Wang, Yong Xia

TL;DR

This work uses SAM in a prompt-free manner to generate coarse masks that serve as prior prompts for subsequent stages, which eliminates the need for prompt generators. It also introduces a generalized error decoder that simulates the correction process typically performed by clinicians.

Abstract

Medical images often exhibit distribution shifts due to variations in imaging protocols and scanners across medical centers. Domain Generalization (DG) methods aim to train models on source domains so that they generalize to unseen target domains. Recently, the Segment Anything Model (SAM) has demonstrated strong generalization capabilities due to its prompt-based design and has gained significant attention in image segmentation tasks. Existing SAM-based approaches try to remove the need for manual prompts by introducing prompt generators that produce these prompts automatically. However, we argue that auto-generated prompts may not be sufficiently accurate under distribution shifts, potentially leading to incorrect predictions that still require manual verification and correction by clinicians. To address this challenge, we propose a method for 2D medical image segmentation called Self-Correcting SAM (CoSAM). Our approach first uses SAM in a prompt-free manner to generate coarse masks, which serve as prior prompts for the subsequent stages and eliminate the need for prompt generators. To automatically refine these coarse masks, we introduce a generalized error decoder that simulates the correction process typically performed by clinicians. Furthermore, we generate diverse prompts as feedback from the corrected masks and use them to iteratively refine the predictions within a self-correcting loop, enhancing the generalization performance of our model. Extensive experiments on two medical image segmentation benchmarks across multiple scenarios demonstrate the superiority of CoSAM over state-of-the-art SAM-based methods.
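
The abstract describes correcting the coarse mask and turning the corrected mask into prompts. The following is a minimal sketch of that step, assuming the error map is a binary per-pixel map whose flagged pixels are simply flipped, and that prompts consist of a bounding box plus randomly sampled positive and negative points (random sampling being one simple way to obtain diverse prompts). The names `correct_mask` and `mask_to_prompts`, the flipping rule, and the sampling scheme are hypothetical illustrations, not CoSAM's actual implementation.

```python
import random
from typing import List, Optional, Tuple

Mask = List[List[int]]           # binary mask: rows of 0/1 values
Point = Tuple[int, int]          # (x, y) pixel coordinate
Box = Tuple[int, int, int, int]  # (x_min, y_min, x_max, y_max)


def correct_mask(coarse: Mask, error_map: Mask) -> Mask:
    """Hypothetical correction rule: flip every pixel the binary error map flags."""
    return [[c ^ e for c, e in zip(c_row, e_row)]
            for c_row, e_row in zip(coarse, error_map)]


def mask_to_prompts(mask: Mask, num_points: int = 3, seed: Optional[int] = None
                    ) -> Tuple[Optional[Box], List[Point], List[Point]]:
    """Derive a box prompt plus randomly sampled positive/negative point prompts.

    Assumption: positives come from foreground pixels, negatives from background.
    """
    fg = [(x, y) for y, row in enumerate(mask) for x, v in enumerate(row) if v]
    bg = [(x, y) for y, row in enumerate(mask) for x, v in enumerate(row) if not v]
    if not fg:
        return None, [], []  # empty mask: nothing to prompt with
    xs = [x for x, _ in fg]
    ys = [y for _, y in fg]
    box = (min(xs), min(ys), max(xs), max(ys))
    rng = random.Random(seed)  # different seeds give different (diverse) prompts
    positives = rng.sample(fg, min(num_points, len(fg)))
    negatives = rng.sample(bg, min(num_points, len(bg)))
    return box, positives, negatives


# Toy example: the error map flags one missed foreground pixel at (x=2, y=2).
coarse = [[0, 0, 0, 0],
          [0, 1, 1, 0],
          [0, 1, 0, 0],
          [0, 0, 0, 0]]
error_map = [[0, 0, 0, 0],
             [0, 0, 0, 0],
             [0, 0, 1, 0],
             [0, 0, 0, 0]]
corrected = correct_mask(coarse, error_map)
box, pos, neg = mask_to_prompts(corrected, num_points=2, seed=0)
print(box)  # (1, 1, 2, 2)
```

In a SAM-style pipeline, the box and labeled points would then be passed to the prompt encoder; how CoSAM encodes and weights them is not specified here.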

Paper Structure

This paper contains 18 sections, 9 equations, 3 figures, 5 tables, 2 algorithms.

Figures (3)

  • Figure 1: (a) Prompt-free methods use SAM directly to produce predictions without prompts. (b) Prompt-based methods train a prompt generator to produce prompts automatically to assist segmentation. (c) Our CoSAM builds a corrector that mimics clinicians by correcting the coarse masks to produce more accurate prompts, and then refines the predictions within the self-correcting loop.
  • Figure 2: Overview of our proposed CoSAM. (a) The training process of CoSAM. Given a training image and its label, (1) CoSAM first uses a fine-tuned mask decoder to produce a coarse mask, without prompts, from the image embeddings extracted by the frozen image encoder. The mask decoder is trained in a prompt-free manner. (2) We feed the coarse mask into the mask encoder to extract mask embeddings. The mask embeddings and image embeddings are concatenated and fed to the error decoder to generate an error map; the error decoder is trained to assess the quality of the coarse mask. (3) From the coarse mask/label, we generate refined/guided prompts to obtain the refined/guided mask, so that the mask decoder is trained to exploit defective/perfect prompts to assist segmentation. (b) The inference process of CoSAM. For each test image, we first generate the coarse mask and error map as in training. We then correct the error points in the coarse mask according to the error map to generate corrected prompts as feedback. These prompts are fed into the prompt encoder to produce prompt embeddings, and the mask decoder produces the refined mask from the image and prompt embeddings. We repeat this refinement process for $T$ iterations and also introduce an early-stop mechanism that terminates refinement when the number of error points in the error map increases. Best viewed in color.
  • Figure 3: Overall performance of CoSAM with different values of $\alpha$, $K$, and $T$ on the prostate segmentation task.
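
The inference procedure in Figure 2(b) can be sketched as a loop with early stopping. In this minimal sketch, the SAM components are replaced by hypothetical callables: `predict_error` stands in for the mask encoder plus error decoder, `make_prompts` for error correction and prompt generation, and `refine` for the prompt encoder plus mask decoder. When the error count increases, the loop is assumed to keep the previous mask; the caption does not say which mask is returned.

```python
from typing import Callable, List

Mask = List[List[int]]  # binary mask or binary error map, rows of 0/1


def count_error_points(error_map: Mask) -> int:
    """Number of pixels flagged as erroneous."""
    return sum(sum(row) for row in error_map)


def self_correct(coarse: Mask,
                 predict_error: Callable[[Mask], Mask],       # hypothetical error decoder
                 make_prompts: Callable[[Mask, Mask], object],  # hypothetical prompt generator
                 refine: Callable[[object], Mask],             # hypothetical prompted decoder
                 T: int = 3) -> Mask:
    """Refine a coarse mask for at most T iterations, stopping early if errors grow."""
    mask = coarse
    error_map = predict_error(mask)
    n_errors = count_error_points(error_map)
    for _ in range(T):
        prompts = make_prompts(mask, error_map)  # correct error points -> feedback prompts
        candidate = refine(prompts)              # decode a refined mask from the prompts
        candidate_error = predict_error(candidate)
        candidate_n = count_error_points(candidate_error)
        if candidate_n > n_errors:
            break  # early stop: error points increased, keep the previous mask (assumption)
        mask, error_map, n_errors = candidate, candidate_error, candidate_n
    return mask


# Toy stand-ins for illustration only: the "oracle" error map below is computed from
# a known target, which a real error decoder would not have at test time.
target = [[0, 1], [1, 1]]


def xor(a: Mask, b: Mask) -> Mask:
    """Pixel-wise XOR; used both as an oracle error map and to flip flagged pixels."""
    return [[p ^ q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


refined = self_correct([[0, 1], [0, 0]],
                       predict_error=lambda m: xor(m, target),
                       make_prompts=xor,      # "prompts" here are just the corrected mask
                       refine=lambda p: p)    # identity "decoder"
print(refined == target)  # True
```

Tracking only the previous error count makes the stopping rule match the caption's wording (terminate when the count increases); a variant could instead keep the mask with the fewest error points seen so far.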