Causal Representation Learning with Observational Grouping for CXR Classification
Rajat Rasal, Avinash Kori, Ben Glocker
TL;DR
This work tackles distribution shifts in chest X-ray disease classification by learning identifiable, causal representations through observational grouping. It introduces an invariance loss $L_{INV}$ that, together with the supervised binary cross-entropy loss $L_{BCE}$, encourages a latent space that is invariant across groups (e.g., sex, race, imaging view) while preserving discriminative power, and it provides identifiability guarantees for the learned content. The authors validate the approach on the CheXpert and MIMIC datasets with multiple architectures, showing improved AUROC and reduced latent variability, which implies better generalisability and robustness. They discuss practical implications for fairness and reliability in clinical AI and release code to enable broader adoption. Overall, the method offers a principled path to robust, interpretable CXR classification under real-world non-IID conditions, with identifiable, causal latent representations as its key asset.
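The TL;DR names the two loss terms but not the exact form of $L_{INV}$. As an illustration only, the sketch below assumes a simple variance-style invariance penalty: it measures how far each group's mean latent vector (e.g., one group per sex, race, or imaging view) lies from the average of all group means, and adds it to the BCE term with a weight `lam`. This is not necessarily the authors' formulation. The function names `bce_loss`, `invariance_loss`, `total_loss` and the parameter `lam` are hypothetical.

```python
import numpy as np


def bce_loss(logits, labels):
    """Mean binary cross-entropy from raw logits, in the numerically stable form
    max(x, 0) - x*y + log(1 + exp(-|x|))."""
    logits = np.asarray(logits, dtype=float)
    labels = np.asarray(labels, dtype=float)
    per_sample = np.maximum(logits, 0.0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))
    return float(np.mean(per_sample))


def invariance_loss(z, groups):
    """Hypothetical stand-in for L_INV (an assumption, not the paper's definition):
    mean squared distance between each group's mean latent vector and the average
    of all group means. It is zero when every group has the same mean latent."""
    z = np.asarray(z, dtype=float)
    groups = np.asarray(groups)
    group_ids = np.unique(groups)
    # One mean latent vector per group, shape (num_groups, latent_dim).
    means = np.stack([z[groups == g].mean(axis=0) for g in group_ids])
    centre = means.mean(axis=0)
    return float(np.mean(np.sum((means - centre) ** 2, axis=1)))


def total_loss(logits, labels, z, groups, lam=1.0):
    """Supervised BCE plus a weighted invariance penalty (weighting is assumed)."""
    return bce_loss(logits, labels) + lam * invariance_loss(z, groups)


if __name__ == "__main__":
    # Toy batch: four latent vectors split into two groups (e.g., two imaging views).
    z = np.array([[1.0, 0.0], [-1.0, 0.0], [3.0, 0.0], [1.0, 0.0]])
    groups = np.array([0, 0, 1, 1])
    labels = np.array([0.0, 1.0, 0.0, 1.0])
    logits = np.zeros(4)
    print(total_loss(logits, labels, z, groups, lam=0.5))
```

Here the two groups have mean latents $[0, 0]$ and $[2, 0]$, so the penalty is nonzero and pushes the encoder towards group-invariant features. In practice such a term would be computed on mini-batches inside an autodiff framework rather than in NumPy.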
Abstract
Identifiable causal representation learning seeks to uncover the true causal relationships underlying a data-generating process. In medical imaging, this presents opportunities to improve the generalisability and robustness of task-specific latent features. This work introduces the concept of grouping observations to learn identifiable representations for disease classification in chest X-rays via an end-to-end framework. Our experiments demonstrate that these causal representations improve generalisability and robustness across multiple classification tasks when grouping is used to enforce invariance w.r.t. race, sex, and imaging views.
