Show and Segment: Universal Medical Image Segmentation via In-Context Learning

Yunhe Gao, Di Liu, Zhuowei Li, Yunsheng Li, Dongdong Chen, Mu Zhou, Dimitris N. Metaxas

TL;DR

This work introduces Iris, a decoupled in-context learning framework for universal 3D medical image segmentation that conditions a single segmentation function on a small set of reference image-label pairs. It combines a Task Encoding Module, which distills foreground and contextual cues into class-specific embeddings, with a Mask Decoding Module that uses cross-attention to generate multi-class masks in one forward pass, enabling one-shot inference, context ensembles, and object-level retrieval without fine-tuning. Episodic training with Dice and cross-entropy losses, extensive data augmentation, and a memory-bank extension for seen classes support robust generalization across 12 upstream datasets and 7 held-out datasets, including unseen anatomical structures. Iris achieves state-of-the-art performance on in-distribution tasks and strong out-of-distribution generalization. It also offers efficient inference ($O(k+m)$ complexity) and practical inference strategies for high-throughput clinical workflows, and its task encoding module automatically uncovers meaningful anatomical relationships from segmentation masks without explicit supervision.
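The TL;DR lists Dice and cross-entropy as the episodic training losses. The sketch below shows one common way to combine them for a multi-class 3D mask in NumPy. The equal weighting of the two terms, the smoothing constant `eps`, and the function names are assumptions for illustration, not the paper's exact formulation.

```python
import numpy as np

def log_softmax(logits, axis=0):
    # Subtract the per-voxel max before exponentiating for numerical stability.
    z = logits - logits.max(axis=axis, keepdims=True)
    return z - np.log(np.exp(z).sum(axis=axis, keepdims=True))

def dice_ce_loss(logits, target, eps=1e-6):
    """Soft Dice + cross-entropy, equally weighted (an assumption).

    logits: (C, D, H, W) raw class scores for one 3D volume.
    target: (D, H, W) integer labels in [0, C).
    """
    num_classes = logits.shape[0]
    log_probs = log_softmax(logits, axis=0)
    probs = np.exp(log_probs)
    # One-hot labels moved to (C, D, H, W) to match the logits.
    one_hot = np.moveaxis(np.eye(num_classes)[target], -1, 0)

    # Cross-entropy: mean over voxels of -log p(true class).
    ce = -(one_hot * log_probs).sum(axis=0).mean()

    # Soft Dice per class over all spatial axes, then averaged over classes.
    spatial = tuple(range(1, logits.ndim))
    inter = (probs * one_hot).sum(axis=spatial)
    denom = probs.sum(axis=spatial) + one_hot.sum(axis=spatial)
    dice = (2.0 * inter + eps) / (denom + eps)
    return float(ce + (1.0 - dice.mean()))

# Usage: a tiny 2x2x2 volume that contains all three classes.
target = np.array([0, 1, 2, 0, 1, 2, 0, 1]).reshape(2, 2, 2)
uniform = np.zeros((3, 2, 2, 2))                            # no preference between classes
confident = 20.0 * np.moveaxis(np.eye(3)[target], -1, 0)    # strongly favors the true class
print(dice_ce_loss(uniform, target), dice_ce_loss(confident, target))
```

Cross-entropy scores each voxel independently, while the Dice term is computed per class over the whole volume, so small structures are not swamped by the background; this complementarity is the usual motivation for summing the two.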

Abstract

Medical image segmentation remains challenging due to the vast diversity of anatomical structures, imaging modalities, and segmentation tasks. While deep learning has made significant advances, current approaches struggle to generalize because they require task-specific training or fine-tuning on unseen classes. We present Iris, a novel In-context Reference Image guided Segmentation framework that adapts flexibly to novel tasks using reference examples, without fine-tuning. At its core, Iris features a lightweight context task encoding module that distills task-specific information from reference image-label pairs. The resulting context embeddings guide the segmentation of target objects. By decoupling task encoding from inference, Iris supports diverse strategies, from one-shot inference and context-example ensembles to object-level context-example retrieval and in-context tuning. Through comprehensive evaluation across twelve datasets, we demonstrate that Iris performs strongly compared to task-specific models on in-distribution tasks. On seven held-out datasets, Iris shows superior generalization to out-of-distribution data and unseen classes. Further, Iris's task encoding module can automatically discover anatomical relationships across datasets and modalities, offering insights into medical objects without explicit anatomical supervision.
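The abstract's central design point is that task encoding is decoupled from inference. The NumPy sketch below illustrates that structure under heavy simplification: masked average pooling of reference features stands in for the Task Encoding Module, and cosine-similarity matching stands in for the cross-attention Mask Decoding Module, so this is not the authors' architecture. The features are synthetic stand-ins for a backbone's output, and all function names are hypothetical. Embeddings from k references are averaged once (a simple form of context ensemble) and then reused for every query.

```python
import numpy as np

def l2_normalize(x, axis=-1, eps=1e-8):
    # Scale vectors to unit length; all-zero rows stay zero.
    return x / (np.linalg.norm(x, axis=axis, keepdims=True) + eps)

def encode_task(ref_feats, ref_labels, num_classes):
    """Hypothetical stand-in for a task encoder: masked average pooling.

    ref_feats:  (N, F) per-voxel features of one reference volume, flattened.
    ref_labels: (N,) integer labels, 0 = background.
    Returns (num_classes, F) unit-norm class embeddings.
    """
    one_hot = np.eye(num_classes)[ref_labels]                   # (N, C)
    counts = one_hot.sum(axis=0)[:, None]                       # voxels per class
    pooled = (one_hot.T @ ref_feats) / np.maximum(counts, 1.0)  # (C, F) class means
    return l2_normalize(pooled)

def context_ensemble(embeddings):
    """Average per-reference embeddings of shape (k, C, F) into one (C, F) set."""
    return l2_normalize(embeddings.mean(axis=0))

def decode(query_feats, task_emb):
    """Hypothetical stand-in for a mask decoder: cosine similarity, then argmax.

    Scores every class at once, so one call yields a multi-class label map.
    """
    scores = l2_normalize(query_feats) @ task_emb.T             # (N, C)
    return scores.argmax(axis=1)

# Synthetic features: each class has a prototype direction plus small noise.
rng = np.random.default_rng(0)
prototypes = np.eye(3)

def make_volume(n_voxels):
    labels = rng.integers(0, 3, size=n_voxels)
    feats = prototypes[labels] + 0.05 * rng.standard_normal((n_voxels, 3))
    return feats, labels

# Encode k = 3 references once ...
refs = [make_volume(500) for _ in range(3)]
task_emb = context_ensemble(np.stack([encode_task(f, y, 3) for f, y in refs]))

# ... then reuse the same embeddings for all m = 4 queries.
queries = [make_volume(500) for _ in range(4)]
accuracies = [float(np.mean(decode(f, task_emb) == y)) for f, y in queries]
print(accuracies)
```

Because `encode_task` runs once per reference and `decode` once per query, the work in this sketch grows as k + m rather than k × m (k and m are our labels for references and queries). That separation is presumably what the $O(k+m)$ figure in the TL;DR refers to, though the sketch does not reproduce the paper's analysis.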

Paper Structure

This paper contains 16 sections, 7 equations, 8 figures, 5 tables, and 1 algorithm.

Figures (8)

  • Figure 1: Comparison of medical image segmentation approaches. A) Task-specific models require training a separate model for each task, limiting their flexibility and scalability. B) Multi-task universal models can handle diverse tasks and imaging modalities but fail on novel classes. C) SAM-based foundation models enable flexible segmentation through user interaction but are impractical for high-throughput automated processing. D) Our proposed Iris combines automatic processing with flexible adaptation via in-context learning, enabling segmentation of both seen and unseen tasks without any manual interaction or retraining.
  • Figure 2: Overview of the Iris framework. We design a task encoding module that extracts compact task embeddings from reference examples; these embeddings guide segmentation of the query image through the mask decoding module, enabling efficient and flexible adaptation to new tasks without fine-tuning.
  • Figure 3: Iris's flexible inference strategies. The red arrows indicate gradient backpropagation.
  • Figure 4: Analysis of different inference strategies.
  • Figure 5: Top: Visualizing the task embeddings with t-SNE. Colors represent datasets; circles and other marker shapes denote the classes of the embeddings. Bottom: Examples of similar tasks revealed by the t-SNE result.
  • ...and 3 more figures
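Figure 3 marks gradient backpropagation with red arrows, presumably for the in-context tuning strategy named in the abstract. The NumPy sketch below assumes that in-context tuning updates only the task embedding while the rest of the network stays frozen. A logistic read-out over synthetic features stands in for the real mask decoder, and `tune_embedding`, `bce`, the learning rate, and the step count are hypothetical choices, not the authors' settings.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def bce(logits, labels, eps=1e-7):
    # Mean binary cross-entropy; clipping only guards the log against 0.
    p = np.clip(sigmoid(logits), eps, 1.0 - eps)
    return float(-np.mean(labels * np.log(p) + (1 - labels) * np.log(1 - p)))

def tune_embedding(feats, labels, emb, lr=0.5, steps=100):
    """Gradient descent on the task embedding alone.

    `feats` stand in for features from a frozen model; only `emb` changes.
    The read-out logits = feats @ emb is a hypothetical simplification of
    the mask decoder.
    """
    emb = emb.copy()
    history = []
    for _ in range(steps):
        logits = feats @ emb
        history.append(bce(logits, labels))
        # d(mean BCE)/d(emb) = feats^T (sigmoid(logits) - labels) / N
        grad = feats.T @ (sigmoid(logits) - labels) / len(labels)
        emb -= lr * grad
    return emb, history

# Synthetic reference: one feature that separates the classes, plus a bias column.
rng = np.random.default_rng(0)
labels = rng.integers(0, 2, size=200).astype(float)
signal = (2 * labels - 1) + 0.3 * rng.standard_normal(200)
feats = np.stack([signal, np.ones(200)], axis=1)

tuned, history = tune_embedding(feats, labels, emb=np.zeros(2))
print(f"loss before: {history[0]:.3f}, after: {history[-1]:.3f}")
```

Because the logistic loss is convex in the embedding and the step size is small relative to the feature scale, the loss decreases at every step in this sketch; a real, nonlinear decoder would offer no such guarantee.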