Causal Representation Learning with Observational Grouping for CXR Classification

Rajat Rasal, Avinash Kori, Ben Glocker

TL;DR

This work tackles distribution shifts in chest X-ray disease classification by learning identifiable, causal representations through observational grouping. It introduces an invariant loss $\mathcal{L}_{\mathrm{INV}}$ that, together with the supervised loss $\mathcal{L}_{\mathrm{BCE}}$, encourages a latent space that is invariant across groups (e.g., sex, race, imaging view) while preserving discriminative power, and it provides identifiability guarantees for the learned content. The authors validate the approach across CheXpert and MIMIC datasets and multiple architectures, showing improved AUROC and reduced latent variability, implying better generalisability and robustness. They discuss practical implications for fairness and reliability in clinical AI, and provide code to enable broader adoption. Overall, the method offers a principled path to robust, interpretable CXR classification under real-world non-IID conditions, with identifiable, causal latent representations as a key asset.
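As a minimal sketch of observational grouping, assuming each image carries a categorical attribute such as sex, race, or imaging view, the groups $G_1, \dots, G_K$ can be formed by partitioning sample indices by that attribute. The helper `group_indices` and the toy view labels are hypothetical illustrations, not part of the authors' code.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def group_indices(attributes: Sequence[str]) -> Dict[str, List[int]]:
    """Partition sample indices into groups by a categorical attribute.

    Hypothetical helper: each distinct attribute value (e.g. a sex, race,
    or imaging-view label) defines one group G_k of observations.
    """
    groups: Dict[str, List[int]] = defaultdict(list)
    for idx, value in enumerate(attributes):
        groups[value].append(idx)
    return dict(groups)


# Toy metadata: imaging view per image (illustrative values only).
views = ["PA", "AP", "PA", "LATERAL", "AP"]
groups = group_indices(views)
print(groups)  # {'PA': [0, 2], 'AP': [1, 4], 'LATERAL': [3]}
```

Grouping by index rather than copying images keeps the sketch cheap; a data loader can then draw samples from any $G_k$ on demand.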

Abstract

Identifiable causal representation learning seeks to uncover the true causal relationships underlying a data generation process. In medical imaging, this presents opportunities to improve the generalisability and robustness of task-specific latent features. This work introduces the concept of grouping observations to learn identifiable representations for disease classification in chest X-rays via an end-to-end framework. Our experiments demonstrate that these causal representations improve generalisability and robustness across multiple classification tasks when grouping is used to enforce invariance w.r.t. race, sex, and imaging views.
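The following NumPy sketch shows one forward pass of the procedure in the Figure 1 caption: pick two groups $k, k'$, sample a batch from each, encode both with a shared $\phi$, and combine a BCE loss from the classifier $\psi$ with an invariance term. Several choices are assumptions rather than the paper's: $\phi$ and $\psi$ are toy linear maps, $k \neq k'$ is enforced, the two losses are summed with a hypothetical weight `lam`, and `inv_penalty` (squared distance between the two groups' mean features) is a simple moment-matching stand-in for $\mathcal{L}_{\mathrm{INV}}$, whose exact form is not given here. Gradients and the optimiser are omitted; in practice an autodiff framework would handle them.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy data (illustrative only): K groups of N samples, D_IN inputs, binary labels.
K, N, D_IN, D_Z = 3, 32, 16, 4
groups = [(rng.normal(loc=float(k), size=(N, D_IN)), rng.integers(0, 2, size=N)) for k in range(K)]

# Hypothetical stand-ins: phi is a linear encoder with tanh, psi a logistic head.
W_phi = rng.normal(scale=0.1, size=(D_IN, D_Z))
w_psi = rng.normal(scale=0.1, size=D_Z)


def phi(x):
    return np.tanh(x @ W_phi)


def psi(z):
    return 1.0 / (1.0 + np.exp(-(z @ w_psi)))  # probability of the positive class


def bce(p, y, eps=1e-7):
    # Standard binary cross-entropy, with clipping to avoid log(0).
    p = np.clip(p, eps, 1.0 - eps)
    return float(-np.mean(y * np.log(p) + (1 - y) * np.log(1 - p)))


def inv_penalty(z_a, z_b):
    # ASSUMPTION: moment-matching stand-in for L_INV, not the paper's definition.
    return float(np.sum((z_a.mean(axis=0) - z_b.mean(axis=0)) ** 2))


def step_loss(lam=1.0, batch=8):
    # Sample two groups k, k' ~ [K]; ASSUMPTION: k != k'.
    k, k_prime = rng.choice(K, size=2, replace=False)
    (x_a, y_a), (x_b, y_b) = groups[k], groups[k_prime]
    i_a = rng.choice(N, size=batch, replace=False)
    i_b = rng.choice(N, size=batch, replace=False)
    z_a, z_b = phi(x_a[i_a]), phi(x_b[i_b])  # shared encoder for both groups
    p = psi(np.vstack([z_a, z_b]))
    y = np.concatenate([y_a[i_a], y_b[i_b]])
    l_bce = bce(p, y)
    l_inv = inv_penalty(z_a, z_b)
    return l_bce + lam * l_inv, l_bce, l_inv


total, l_bce, l_inv = step_loss(lam=0.5)
print(f"total={total:.4f}  bce={l_bce:.4f}  inv={l_inv:.4f}")
```

In this sketch `lam` trades discriminative fit against invariance: with `lam=0` the groups' feature distributions are left unconstrained, while a large `lam` pushes their means together at the cost of the BCE term.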

Paper Structure

This paper contains 21 sections, 4 equations, 2 figures, 3 tables, and 1 algorithm.

Figures (2)

  • Figure 1: Training with observational grouping organises the latent space such that representations are invariant to properties across groups in $G$. We select two groups randomly $k, k' \sim [K]$ and sample images $x_k \sim G_k$ and $x_{k'} \sim G_{k'}$. We use $\phi$ to extract features and feed them jointly into an invariant loss ($\mathcal{L}_{\mathrm{INV}}$) and a binary classification loss ($\mathcal{L}_{\mathrm{BCE}}$). $\mathcal{L}_{\mathrm{INV}}$ structures the latent space by relaxing the IID assumption, leading to a transformation from $\mathcal{Z}'$ to $\mathcal{Z}$, as illustrated. Here, the blue regions indicate the theoretical $\phi$-supported region of the distribution before and after the use of $\mathcal{L}_{\mathrm{INV}}$. The resulting representation is used for classification with $\psi$.
  • Figure 2: Density estimation on scaled PC1 embeddings of invariant and non-invariant latent features from $\phi$, implemented with a DenseNet backbone, for No Finding vs. pleural effusion classification.
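To reproduce the kind of analysis in Figure 2, one can project latent features onto their first principal component, rescale the scores, and estimate a density. The sketch below assumes that "scaled" means min-max scaling to [0, 1] and uses a Gaussian KDE with a normal-reference bandwidth; the caption specifies neither, and the random toy features stand in for real outputs of $\phi$. Running it once on invariant and once on non-invariant features and overlaying the two curves gives a comparison of this type.

```python
import numpy as np


def pc1_scores(z):
    """Project features (n_samples x n_dims) onto their first principal component."""
    zc = z - z.mean(axis=0)
    # Rows of vt are principal axes, ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(zc, full_matrices=False)
    return zc @ vt[0]  # note: the sign of PC1 is arbitrary


def minmax_scale(x):
    # ASSUMPTION: "scaled" is read as min-max scaling to [0, 1].
    return (x - x.min()) / (x.max() - x.min())


def gaussian_kde(samples, grid, bandwidth=None):
    """Gaussian kernel density estimate of 1-D `samples`, evaluated on `grid`."""
    n = samples.size
    if bandwidth is None:
        # Normal-reference rule of thumb: 1.06 * std * n^(-1/5).
        bandwidth = 1.06 * samples.std(ddof=1) * n ** (-1 / 5)
    u = (grid[:, None] - samples[None, :]) / bandwidth
    return np.exp(-0.5 * u**2).sum(axis=1) / (n * bandwidth * np.sqrt(2 * np.pi))


# Toy correlated "latent features" (illustrative only, not real CXR embeddings).
rng = np.random.default_rng(0)
z = rng.normal(size=(200, 8)) @ rng.normal(size=(8, 8))
scores = minmax_scale(pc1_scores(z))
grid = np.linspace(0.0, 1.0, 101)
density = gaussian_kde(scores, grid)
print(f"peak density {density.max():.2f} at scaled PC1 = {grid[density.argmax()]:.2f}")
```

A narrower, more overlapping pair of densities across groups would be consistent with the reduced latent variability the TL;DR reports; the sketch itself makes no claim about the paper's numbers.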

Theorems & Definitions (4)

  • Definitions 1 to 4 (titles not recoverable)