Stride-Net: Fairness-Aware Disentangled Representation Learning for Chest X-Ray Diagnosis
Darakshan Rashid, Raza Imam, Dwarikanath Mahapatra, Brejesh Lall
TL;DR
Addresses fairness gaps in chest X-ray diagnosis across demographic subgroups. Proposes Stride-Net, which learns label-aligned, patch-level representations using a learnable stride mask, BioBERT-based disease label embeddings, Group Optimal Transport (GOT) alignment, and adversarial suppression of sensitive attributes. The approach optimizes a joint loss $\mathcal{L}_{total} = \mathcal{L}_c + \alpha \mathcal{L}_{GOT} + \beta \mathcal{L}_s - \gamma \mathcal{L}_{conf}$ with weights $\alpha$, $\beta$, $\gamma$ and an additional parameter $\lambda \in [0,1]$, enabling end-to-end training that emphasizes clinically meaningful regions while reducing demographic leakage. Empirical results on MIMIC-CXR and CheXpert show improved fairness metrics (PQD and EOM) with equal or better diagnostic accuracy across race and race–gender subgroups, including challenging intersectional cases, and the method outperforms existing debiasing baselines. The work advances fairness-aware representation learning in medical imaging and releases its code publicly for reproducibility.
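To make the structure of the joint objective concrete, here is a minimal Python sketch of how the four terms could be combined. It assumes each term has already been computed as a scalar; the names `LossWeights` and `total_loss` are hypothetical and not taken from the released code, and the individual term definitions are not given in this summary.

```python
from dataclasses import dataclass


@dataclass
class LossWeights:
    """Hypothetical container for the scalar weights in the joint loss."""
    alpha: float  # weight on the GOT alignment term L_GOT
    beta: float   # weight on the L_s term
    gamma: float  # weight on the confusion term L_conf (subtracted)


def total_loss(l_c: float, l_got: float, l_s: float, l_conf: float,
               w: LossWeights) -> float:
    """Combine the terms as L_c + alpha*L_GOT + beta*L_s - gamma*L_conf."""
    return l_c + w.alpha * l_got + w.beta * l_s - w.gamma * l_conf


# Illustrative values only; they are not results from the paper.
example = total_loss(1.0, 0.5, 0.2, 0.4, LossWeights(alpha=0.5, beta=1.0, gamma=0.25))
```

Note the minus sign on the confusion term: with this formula, a larger $\mathcal{L}_{conf}$ lowers the total, which matches the sign convention in the expression above.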
Abstract
Deep neural networks for chest X-ray classification achieve strong average performance, yet often underperform for specific demographic subgroups, raising critical concerns about clinical safety and equity. Existing debiasing methods frequently yield inconsistent improvements across datasets or attain fairness by degrading overall diagnostic utility, treating fairness as a post hoc constraint rather than a property of the learned representation. In this work, we propose Stride-Net (Sensitive Attribute Resilient Learning via Disentanglement and Learnable Masking with Embedding Alignment), a fairness-aware framework that learns disease-discriminative yet demographically invariant representations for chest X-ray analysis. Stride-Net operates at the patch level, using a learnable stride-based mask to select label-aligned image regions while suppressing sensitive attribute information through an adversarial confusion loss. To anchor representations in clinical semantics and discourage shortcut learning, we further enforce semantic alignment between image features and BioBERT-based disease label embeddings via Group Optimal Transport. We evaluate Stride-Net on the MIMIC-CXR and CheXpert benchmarks across race and intersectional race-gender subgroups. Across architectures including ResNet and Vision Transformers, Stride-Net consistently improves fairness metrics while matching or exceeding baseline accuracy, achieving a more favorable accuracy-fairness trade-off than prior debiasing approaches. Our code is available at https://github.com/Daraksh/Fairness_StrideNet.
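As an illustration of how subgroup fairness of this kind can be scored, the sketch below computes two ratio-style metrics over demographic groups. It assumes PQD is the ratio of the worst to the best subgroup accuracy and EOM is the class-averaged ratio of the worst to the best subgroup true positive rate; these definitions are an assumption here, and the paper's exact formulations may differ. The function names are hypothetical.

```python
def subgroup_accuracy(y_true, y_pred, groups):
    """Return a dict mapping each group to its accuracy."""
    correct, total = {}, {}
    for t, p, g in zip(y_true, y_pred, groups):
        total[g] = total.get(g, 0) + 1
        correct[g] = correct.get(g, 0) + int(t == p)
    return {g: correct[g] / total[g] for g in total}


def pqd(y_true, y_pred, groups):
    """Assumed PQD: worst subgroup accuracy divided by best (1.0 = parity)."""
    acc = subgroup_accuracy(y_true, y_pred, groups)
    best = max(acc.values())
    return min(acc.values()) / best if best > 0 else 0.0


def eom(y_true, y_pred, groups):
    """Assumed EOM: mean over classes of worst/best subgroup true positive rate."""
    ratios = []
    for c in sorted(set(y_true)):
        tpr = {}
        for g in set(groups):
            # Indices of samples in group g whose true label is class c.
            idx = [i for i in range(len(y_true))
                   if groups[i] == g and y_true[i] == c]
            if idx:
                tpr[g] = sum(y_pred[i] == c for i in idx) / len(idx)
        if tpr and max(tpr.values()) > 0:
            ratios.append(min(tpr.values()) / max(tpr.values()))
    return sum(ratios) / len(ratios) if ratios else 0.0


# Toy labels for two groups; illustrative only, not data from the paper.
y_true = [1, 1, 0, 0, 1, 1, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0, 0, 1]
groups = ["A", "A", "A", "A", "B", "B", "B", "B"]
```

Under these assumed definitions both scores lie in [0, 1], and values closer to 1 indicate smaller gaps between the best-served and worst-served subgroups.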
