
Prompt-driven Latent Domain Generalization for Medical Image Classification

Siyuan Yan, Chi Liu, Zhen Yu, Lie Ju, Dwarikanath Mahapatra, Brigid Betz-Stablein, Victoria Mar, Monika Janda, Peter Soyer, Zongyuan Ge

TL;DR

Medical imaging models suffer from distribution shifts and often lack reliable domain labels, which limits traditional domain generalization (DG). The authors propose PLDG, a prompt-driven latent DG framework that first discovers pseudo-domain groups by clustering style features from shallow ViT layers (motivated by simplicity bias), then learns domain-specific prompts with cross-domain knowledge sharing through a domain prompt generator and a domain mixup strategy. The approach is built on Vision Transformer backbones with an adapter-based weighting mechanism. Without using domain labels, it matches or outperforms DG methods that require them on melanoma, diabetic retinopathy, and histopathology classification, as well as on a trap-set debiasing task. The work suggests that latent DG with prompt learning can generalize to unseen medical domains, which is useful when domain annotations are unavailable or unreliable; the authors state that the code will be released publicly.
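The pseudo-domain discovery step can be illustrated with a minimal sketch. It assumes the CLS-token features from a shallow ViT block have already been extracted into an (N, D) NumPy array. It then applies plain k-means with farthest-point initialisation after L2 normalisation. The clustering algorithm, the normalisation, and the helper name `discover_pseudo_domains` are illustrative assumptions, not necessarily the choices made in PLDG.

```python
import numpy as np


def discover_pseudo_domains(cls_feats: np.ndarray, num_domains: int,
                            iters: int = 100, seed: int = 0) -> np.ndarray:
    """Cluster shallow-layer CLS features into pseudo-domain labels.

    Hypothetical helper: `cls_feats` is an (N, D) array assumed to hold
    CLS-token features from a shallow ViT block, where simple style/bias
    cues are expected to dominate.
    """
    rng = np.random.default_rng(seed)
    # Assumption: L2-normalise so clusters reflect feature direction, not scale.
    x = cls_feats / np.linalg.norm(cls_feats, axis=1, keepdims=True)

    # Farthest-point initialisation: start from a random sample, then keep
    # adding the sample farthest from every centroid chosen so far.
    centroids = [x[rng.integers(len(x))]]
    for _ in range(1, num_domains):
        dist_to_nearest = np.min([((x - c) ** 2).sum(axis=1) for c in centroids], axis=0)
        centroids.append(x[dist_to_nearest.argmax()])
    centroids = np.stack(centroids)

    # Lloyd iterations: assign each sample to its nearest centroid, then
    # move each centroid to the mean of its assigned samples.
    for _ in range(iters):
        dists = ((x[:, None, :] - centroids[None, :, :]) ** 2).sum(axis=-1)
        labels = dists.argmin(axis=1)
        new_centroids = np.stack([
            x[labels == k].mean(axis=0) if np.any(labels == k) else centroids[k]
            for k in range(num_domains)
        ])
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids
    return labels


# Toy usage: two synthetic "style" groups pointing in different directions.
demo_rng = np.random.default_rng(1)
group_a = demo_rng.normal(0.0, 0.1, size=(20, 8)) + np.eye(8)[0] * 5.0
group_b = demo_rng.normal(0.0, 0.1, size=(20, 8)) + np.eye(8)[1] * 5.0
pseudo_labels = discover_pseudo_domains(np.vstack([group_a, group_b]), num_domains=2)
print(np.bincount(pseudo_labels))
```

Because the clustering is run once before training, its output can simply be cached as an extra per-image label for the prompt-learning stage.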

Abstract

Deep learning models for medical image analysis are prone to distribution shifts caused by dataset artifact bias, camera variations, differences between imaging stations, and similar factors, leading to unreliable diagnoses in real-world clinical settings. Domain generalization (DG) methods, which train models on multiple domains so that they perform well on unseen domains, offer a promising way to address this problem. However, existing DG methods assume that the domain label of each image is available and accurate, which is typically feasible for only a limited number of medical datasets. To address these challenges, we propose a novel DG framework for medical image classification that does not rely on domain labels, called Prompt-driven Latent Domain Generalization (PLDG). PLDG consists of unsupervised domain discovery and prompt learning. The framework first discovers pseudo-domain labels by clustering bias-associated style features, then leverages collaborative domain prompts to guide a Vision Transformer to learn from the diverse discovered domains. To facilitate cross-domain knowledge learning between different prompts, we introduce a domain prompt generator that enables knowledge sharing between the domain prompts and a shared prompt. A domain mixup strategy is additionally employed to obtain more flexible decision margins and to mitigate the risk of incorrect domain assignments. Extensive experiments on three medical image classification tasks and one debiasing task demonstrate that our method achieves performance comparable or even superior to conventional DG algorithms without relying on domain labels. Our code will be made publicly available once the paper is accepted.
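The relationship between the shared prompt, the domain prompts, and the adapter-based weighting can be sketched in NumPy. The construction below is an assumption borrowed from multitask prompt tuning, not the paper's confirmed design. Each domain prompt is the shared prompt multiplied element-wise by a rank-one domain-specific matrix. A hypothetical linear adapter followed by a softmax turns an image's CLS feature into weights over the K domain prompts. The names `DomainPromptGenerator` and `weighted_prompt` are illustrative.

```python
import numpy as np


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


class DomainPromptGenerator:
    """Hypothetical generator: domain prompt k = shared prompt * (u_k v_k^T)."""

    def __init__(self, num_domains: int, prompt_len: int, dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.shared = rng.normal(0.0, 0.02, size=(prompt_len, dim))    # shared across domains
        self.u = rng.normal(1.0, 0.1, size=(num_domains, prompt_len))  # per-domain row factors
        self.v = rng.normal(1.0, 0.1, size=(num_domains, dim))         # per-domain column factors

    def domain_prompts(self) -> np.ndarray:
        # Outer products give one rank-one (L, D) mask per domain: shape (K, L, D).
        masks = self.u[:, :, None] * self.v[:, None, :]
        # Hadamard product with the shared prompt, so every domain prompt
        # inherits the shared knowledge and differs only through its mask.
        return self.shared[None, :, :] * masks


def weighted_prompt(cls_feat, adapter_w, adapter_b, prompts):
    """Mix K domain prompts per image using weights from an assumed linear adapter."""
    weights = softmax(cls_feat @ adapter_w + adapter_b, axis=-1)  # (B, K)
    mixed = np.einsum("bk,kld->bld", weights, prompts)            # (B, L, D)
    return mixed, weights


# Toy usage with K=3 domains, prompt length L=4, embedding dim D=16, batch B=2.
rng = np.random.default_rng(0)
gen = DomainPromptGenerator(num_domains=3, prompt_len=4, dim=16)
prompts = gen.domain_prompts()
cls_feat = rng.normal(size=(2, 16))
mixed, weights = weighted_prompt(cls_feat, rng.normal(0.0, 0.1, size=(16, 3)), np.zeros(3), prompts)
print(prompts.shape, mixed.shape, weights.sum(axis=1))
```

In a full model the mixed prompt would presumably be prepended to the ViT's patch tokens, as in visual prompt tuning, and the adapter could be supervised with the pseudo-domain labels. Both steps are omitted from this sketch.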

Paper Structure (25 sections, 7 equations, 8 figures, 5 tables)

Figures (8)

  • Figure 1: Comparison between conventional domain generalization (DG) and our latent domain generalization. Conventional DG trains a model on multiple labeled domains so that it generalizes well to unseen domains. Latent domain generalization instead automatically discovers the essential domain information in the training set, so that a DG algorithm capable of generalizing to unseen domains can be trained without domain labels.
  • Figure 2: Illustration of our prompt-driven latent domain generalization (PLDG) algorithm. (a) We perform one-time clustering on the CLS token from a shallow layer of the ViT model to discover bias-related pseudo-domain labels (see the domain discovery section). (b) We train a domain prompt-based ViT to learn domain-specific knowledge for unseen-domain prediction (see the domain prompt learning section). A domain prompt generator is further employed to facilitate cross-domain knowledge learning (see the domain prompt generator section).
  • Figure 3: Illustration of (a) domain prompt generator and (b) domain Mixup strategy.
  • Figure 4: Ablation analysis of (a) prompt length and (b) number of clusters on six datasets from two tasks.
  • Figure 5: Relationship between domain prompt weights and domain distance for our method (a) with domain labels and (b) without domain labels.
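Figure 3(b) refers to a domain mixup strategy. The sketch below assumes standard mixup, with a single Beta-distributed coefficient per batch, applied between samples from different pseudo-domains, and interpolates both the class targets and the pseudo-domain targets. The pairing rule, the value of α, and the name `domain_mixup` are assumptions rather than the paper's exact procedure. Interpolating the domain targets is one plausible way to obtain the softer decision margins described in the abstract, because no training sample is then tied to a single, possibly wrong, pseudo-domain.

```python
import numpy as np


def domain_mixup(x, y_onehot, d_onehot, alpha=0.2, seed=0):
    """Mix each sample with a partner drawn from a different pseudo-domain.

    x: (N, ...) inputs or features; y_onehot: (N, C) class targets;
    d_onehot: (N, K) one-hot pseudo-domain labels from the clustering step.
    """
    rng = np.random.default_rng(seed)
    domains = d_onehot.argmax(axis=1)
    n = len(x)
    partner = np.empty(n, dtype=int)
    for i in range(n):
        others = np.flatnonzero(domains != domains[i])
        # Fall back to any sample if the batch contains a single pseudo-domain.
        pool = others if others.size > 0 else np.arange(n)
        partner[i] = rng.choice(pool)

    lam = rng.beta(alpha, alpha)  # one mixing coefficient per batch, as in standard mixup

    def mix(a):
        return lam * a + (1.0 - lam) * a[partner]

    return mix(x), mix(y_onehot), mix(d_onehot), lam


# Toy usage: 4 samples, 2 classes, 2 pseudo-domains.
x = np.arange(12, dtype=float).reshape(4, 3)
y = np.eye(2)[[0, 1, 0, 1]]
d = np.eye(2)[[0, 0, 1, 1]]
x_mix, y_mix, d_mix, lam = domain_mixup(x, y, d, alpha=0.4, seed=3)
print(round(lam, 3), d_mix)
```

Since every partner comes from another pseudo-domain, each mixed domain target keeps weight λ on the sample's own domain and spreads the remaining 1 − λ onto a different one.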