SeqSAM: Autoregressive Multiple Hypothesis Prediction for Medical Image Segmentation using SAM

Benjamin Towle, Xin Chen, Ke Zhou

TL;DR

Medical image segmentation often involves uncertainty, with several plausible annotations for the same image. SeqSAM is an autoregressive extension of the Segment Anything Model (SAM) that generates a sequence of masks $\{\hat{\mathbf{y}}^{(m)}\}_{m=1}^M$, each conditioned on the previous outputs, and is trained with a set-based Hungarian loss that aligns each prediction to one of the $K$ ground-truth masks. On LIDC-IDRI and QUBIQ Kidney, SeqSAM achieves state-of-the-art performance in $D_{\mathrm{avg}}$ and $\mathrm{GED}$, while supporting an arbitrary output count $M$ without retraining. The approach yields multiple clinically relevant segmentation hypotheses, which can support more robust decision making in practice.

Abstract

Pre-trained segmentation models are a powerful and flexible tool for segmenting images, and their use has recently extended to medical imaging. Yet these methods often produce only a single prediction per image, neglecting the uncertainty inherent in medical images, which arises from unclear object boundaries and errors introduced by the annotation tool. Multiple Choice Learning (MCL) is a technique for generating multiple masks through multiple learned prediction heads. However, it cannot readily be extended to produce more outputs than were fixed by its pre-training hyperparameters, because the sparse, winner-takes-all loss function makes it easy for one prediction head to become overly dominant, so the clinical relevance of each mask is not guaranteed. We introduce SeqSAM, a sequential, RNN-inspired approach to generating multiple masks, which uses a bipartite matching loss to ensure the clinical relevance of each mask and can produce an arbitrary number of masks. We show notable improvements in the quality of each mask produced across two publicly available datasets. Our code is available at https://github.com/BenjaminTowle/SeqSAM.
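
To make the bipartite matching idea concrete, the following is a minimal sketch, not the authors' implementation. It assumes a Dice-based pairwise cost, an equal number of predictions and labels, and flattened masks stored as plain lists; the names `dice_cost` and `match_predictions` are hypothetical. It also swaps the Hungarian algorithm for an exhaustive search over permutations, which finds the same minimum-cost assignment but is only practical for a handful of masks. For larger problems, `scipy.optimize.linear_sum_assignment` is the usual choice.

```python
from itertools import permutations


def dice_cost(pred, target, eps=1e-6):
    """Return 1 - soft Dice between two flattened masks (values in [0, 1])."""
    inter = sum(p * t for p, t in zip(pred, target))
    return 1.0 - (2.0 * inter + eps) / (sum(pred) + sum(target) + eps)


def match_predictions(preds, targets):
    """Assign each prediction to exactly one label at minimum total cost.

    Assumption: len(preds) == len(targets). Exhaustive search stands in
    for the Hungarian algorithm and gives the same optimal assignment.
    Returns (assignment, mean_cost), where assignment[m] is the index of
    the label matched to prediction m.
    """
    if len(preds) != len(targets):
        raise ValueError("this sketch assumes as many predictions as labels")
    # Pairwise cost matrix: cost[m][k] compares prediction m with label k.
    cost = [[dice_cost(p, t) for t in targets] for p in preds]

    def total(perm):
        return sum(cost[m][k] for m, k in enumerate(perm))

    best = min(permutations(range(len(targets))), key=total)
    return list(best), total(best) / len(preds)


if __name__ == "__main__":
    # Three toy "annotations" over six pixels, predicted in a shuffled order.
    targets = [[1, 1, 0, 0, 0, 0], [0, 0, 1, 1, 0, 0], [0, 0, 0, 0, 1, 1]]
    preds = [targets[2], targets[0], targets[1]]
    assignment, loss = match_predictions(preds, targets)
    print(assignment, loss)
```

Because every label is matched to exactly one prediction, no single output can absorb all of the training signal, which is the failure mode the abstract attributes to the winner-takes-all loss.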

Paper Structure

This paper contains 13 sections, 2 equations, 2 figures, 1 table.

Figures (2)

  • Figure 1: (A) Overview of SeqSAM comprising an encoder ($\mathrm{SAM_{enc}}$) and a decoder ($\mathrm{SAM_{dec}}$). (B) A Recurrent Module sequentially generates masks. (C) When more predictions than labels are generated, we chunk the predictions into groups, where $C^{(k)}$ represents the $k$-th group, and perform subsampling. Image and annotations are cropped to highlight the region of interest.
  • Figure 2: Qualitative results on the LIDC-IDRI test set using SAM, showing annotated lung nodules. (Col. 1) Input images; (Col. 2-5) labels from multiple human annotators; (Col. 6-8) MCL baseline (SAM); (Col. 9-11) SeqSAM, our method, for $M = 3$.
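
The chunking step in Figure 1 (C) can be illustrated with a short sketch. The caption only states that predictions are chunked into groups $C^{(k)}$ and subsampled; the contiguous, near-equal chunk sizes and the uniform random draw of one prediction per group used here are assumptions, and `chunk_and_subsample` is a hypothetical name rather than a function from the SeqSAM repository.

```python
import random


def chunk_and_subsample(preds, num_labels, rng=None):
    """Split M predictions into num_labels groups and draw one from each.

    Assumptions (not confirmed by the source): groups are contiguous and
    near-equal in size, and one prediction is sampled uniformly per group.
    Returns (chunks, sampled), where chunks[k] plays the role of C^(k).
    """
    rng = rng or random.Random()
    m = len(preds)
    if num_labels <= 0 or m < num_labels:
        raise ValueError("need at least one prediction per label")
    base, extra = divmod(m, num_labels)
    chunks, start = [], 0
    for k in range(num_labels):
        # The first `extra` groups take one additional prediction.
        size = base + (1 if k < extra else 0)
        chunks.append(preds[start:start + size])
        start += size
    sampled = [rng.choice(chunk) for chunk in chunks]
    return chunks, sampled


if __name__ == "__main__":
    # Seven predictions (represented by their indices) and three labels.
    chunks, sampled = chunk_and_subsample(list(range(7)), 3, random.Random(0))
    print(chunks)
    print(sampled)
```

After subsampling, the number of retained predictions equals the number of labels, so a one-to-one matching between predictions and labels is well defined.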