J-RAS: Enhancing Medical Image Segmentation via Retrieval-Augmented Joint Training

Salma J. Ahmed, Emad A. Mohammed, Azam Asilian Bidgoli

TL;DR

Medical image segmentation suffers from data scarcity and inter-patient variability, which limit generalization. J-RAS tackles this by jointly training a segmentation model with a retrieval model: the top-$K$ retrieved image–mask guides provide contextual priors, and both networks are updated in an end-to-end-like loop with a shared segmentation loss. Across four backbones (U-Net, TransUNet, SegFormer, SAM) and two cardiac MRI datasets (ACDC and M&Ms), J-RAS yields consistent improvements in Dice and Hausdorff Distance (HD). For example, on ACDC, SegFormer with J-RAS improves mean Dice from $0.8708\pm0.042$ to $0.9115\pm0.031$ and mean HD from $1.8130\pm2.49$ to $1.1489\pm0.30$, indicating sharper boundary delineation and greater robustness. The approach reduces reliance on large annotated datasets and improves generalization, with practical impact for reliable, context-informed medical image analysis.
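The retrieval step described above can be illustrated with a minimal sketch. It assumes that images are already embedded as fixed-length vectors, that similarity is cosine similarity, and that guides are fused with the query by channel-wise concatenation. None of these choices is confirmed by the text here, and the names `retrieve_top_k` and `fuse_with_guides` are hypothetical.

```python
import numpy as np


def retrieve_top_k(query_emb, kb_embs, k):
    """Return indices and similarities of the k knowledge-base entries
    closest to the query. Cosine similarity is an assumption here."""
    q = query_emb / np.linalg.norm(query_emb)
    kb = kb_embs / np.linalg.norm(kb_embs, axis=1, keepdims=True)
    sims = kb @ q                      # one cosine similarity per entry
    order = np.argsort(-sims)[:k]      # highest similarity first
    return order, sims[order]


def fuse_with_guides(query_img, guide_imgs, guide_masks):
    """Stack the query (H, W) with K guide images and K guide masks
    (each K, H, W) into a (1 + 2K, H, W) input. Channel concatenation
    is an illustrative fusion choice, not necessarily the paper's."""
    return np.concatenate([query_img[None], guide_imgs, guide_masks], axis=0)


# Toy knowledge base of four 2-D embeddings (made-up values).
kb_embs = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.0]])
idx, sims = retrieve_top_k(np.array([1.0, 0.1]), kb_embs, k=2)
print(idx.tolist())  # [0, 2]

# Placeholder 4x4 images and masks standing in for the retrieved guides.
fused = fuse_with_guides(np.zeros((4, 4)), np.ones((2, 4, 4)), np.ones((2, 4, 4)))
print(fused.shape)   # (5, 4, 4)
```

In joint training, the segmentation loss would also have to reach the retrieval embeddings. A hard `argsort` as used here is not differentiable, so this sketch covers only the forward retrieval and fusion.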

Abstract

Image segmentation, the process of dividing images into meaningful regions, is critical in medical applications for accurate diagnosis, treatment planning, and disease monitoring. Although manual segmentation by healthcare professionals produces precise outcomes, it is time-consuming, costly, and prone to variability due to differences in human expertise. Artificial intelligence (AI)-based methods have been developed to address these limitations by automating segmentation tasks; however, they often require large, annotated datasets that are rarely available in practice and frequently struggle to generalize across diverse imaging conditions due to inter-patient variability and rare pathological cases. In this paper, we propose Joint Retrieval Augmented Segmentation (J-RAS), a joint training method for guided image segmentation that integrates a segmentation model with a retrieval model. Both models are jointly optimized, enabling the segmentation model to leverage retrieved image-mask pairs to enrich its anatomical understanding, while the retrieval model learns segmentation-relevant features beyond simple visual similarity. This joint optimization ensures that retrieval actively contributes meaningful contextual cues to guide boundary delineation, thereby enhancing the overall segmentation performance. We validate J-RAS across multiple segmentation backbones, including U-Net, TransUNet, SAM, and SegFormer, on two benchmark datasets: ACDC and M&Ms, demonstrating consistent improvements. For example, on the ACDC dataset, SegFormer without J-RAS achieves a mean Dice score of $0.8708\pm0.042$ and a mean Hausdorff Distance (HD) of $1.8130\pm2.49$, whereas with J-RAS, the performance improves substantially to a mean Dice score of $0.9115\pm0.031$ and a mean HD of $1.1489\pm0.30$. These results highlight the method's effectiveness and its generalizability across architectures and datasets.
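The Dice scores reported above measure overlap between predicted and reference masks. The following is a minimal sketch of a per-class Dice computation on integer label maps. It assumes label 0 is background and labels 1 to 3 are the foreground structures (RV, MYO, LV); that ordering is an assumption, and `dice_per_class` is a hypothetical helper, not the authors' evaluation code.

```python
import numpy as np


def dice_per_class(pred, target, num_classes):
    """Dice = 2|P ∩ T| / (|P| + |T|) for each foreground class.
    Classes absent from both masks get NaN so they can be skipped
    when averaging (one possible convention)."""
    scores = []
    for c in range(1, num_classes):        # label 0 (background) is skipped
        p = pred == c
        t = target == c
        denom = p.sum() + t.sum()
        if denom == 0:
            scores.append(np.nan)
        else:
            scores.append(2.0 * np.logical_and(p, t).sum() / denom)
    return np.array(scores)


# Toy 2x2 label maps (made-up values): class 1 is over-segmented by one
# pixel, class 2 matches exactly, class 3 is absent from both masks.
pred = np.array([[1, 1], [2, 0]])
target = np.array([[1, 0], [2, 0]])
scores = dice_per_class(pred, target, num_classes=4)
print(scores)              # class 1: 2/3, class 2: 1.0, class 3: nan
print(np.nanmean(scores))  # mean over the classes that are present
```

How empty classes are handled (NaN, 0, or 1) changes the reported mean, so comparisons across papers should confirm which convention was used.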

Paper Structure

This paper contains 22 sections, 5 equations, 12 figures, 4 tables, 1 algorithm.

Figures (12)

  • Figure 1: Overview of the proposed Joint Retrieval-Augmented Segmentation (J-RAS) method. Part A: Independent training phase. The retrieval model (left) is first fine-tuned to learn discriminative embeddings for images, while the segmentation model (right) is independently fine-tuned to predict anatomical masks. Part B: Joint training phase. Given a query image, the top-$K$ most similar guide images and their corresponding masks are retrieved from the knowledge base. The query image and retrieved guides are fused and passed to the segmentation model to predict the final mask. The segmentation loss computed on the predicted mask is then used to update both the segmentation and retrieval models simultaneously.
  • Figure 2: Mean Dice scores on the ACDC test set across different numbers of retrieved guide images and masks (top-$K$, $K = 1$–$10$). $K = 0$ represents the baseline models without J-RAS, where SAM attains a score of 0.66.
  • Figure 3: Class-wise Dice scores (RV, MYO, and LV) using top-$K$ retrieved images and masks ($K = 1$–$10$) on the ACDC test set with SegFormer, TransUNet, U-Net, and SAM. Bars represent class-wise performance. $K = 0$ corresponds to the baseline method.
  • Figure 4: Case-level analysis of improvements and degradations on the ACDC test set using the J-RAS method with SegFormer, TransUNet, SAM, and U-Net. The number of patients whose Dice scores improved or degraded compared to the baseline segmentation models is shown on the respective segment.
  • Figure 5: Top five most improved patients (out of 98) and the two degraded cases on the ACDC test set when using the J-RAS method with SegFormer. Bars indicate the magnitude of change in Dice score relative to the baseline model.
  • ...and 7 more figures