Expanding the Horizon: Enabling Hybrid Quantum Transfer Learning for Long-Tailed Chest X-Ray Classification
Skylar Chan, Pranav Kulkarni, Paul H. Yi, Vishwa S. Parekh
TL;DR
This work tackles long-tailed, multi-label chest X-ray classification by exploring hybrid quantum transfer learning that pairs a Dressed Quantum Circuit (DQC) with a classical backbone. It introduces an open-source, Jax-based framework that efficiently simulates medium-sized qubit architectures on workstation hardware, and it evaluates scalability across 8, 14, and 19 disease labels using the NIH-CXR-LT and MIMIC-CXR-LT datasets. The quantum workflow achieves substantial wall-clock speed-ups over PyTorch and TensorFlow implementations. Under the current settings, however, the DQC converges more slowly and reaches a lower AUROC than the classical machine learning (CML) baseline on both internal NIH data and external MIMIC data, with a smaller gap on the external set. The work demonstrates the practicality of quantum machine learning (QML) for medical imaging while identifying the hyperparameter and architectural improvements needed to close the performance gap and enable broader clinical impact.
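As a rough illustration of what a DQC head on top of a classical backbone can look like when simulated directly in Jax, the sketch below implements a small statevector simulator and a dressed circuit. A classical linear layer maps backbone features to qubit angles (tanh scaled to ±π/2), a Hadamard plus RY embedding loads them, alternating CNOT and RY layers form the trainable circuit, Pauli-Z expectations are read out, and a final linear layer produces one logit per disease label. This layout follows the commonly used dressed-quantum-circuit pattern and is an assumption, not the authors' exact architecture; the qubit count, depth, 512-dimensional feature size, and every function name (`quantum_layer`, `dressed_circuit`, `init_params`) are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes: the paper's qubit count and circuit depth are not stated here.
N_QUBITS = 4
Q_DEPTH = 2

H = jnp.array([[1, 1], [1, -1]], dtype=jnp.complex64) / jnp.sqrt(2.0)


def ry(theta):
    """Single-qubit RY rotation matrix."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.array([[c, -s], [s, c]]).astype(jnp.complex64)


def apply_1q(psi, gate, wire):
    """Apply a 2x2 gate to one wire of a state tensor of shape (2,) * n."""
    out = jnp.tensordot(gate, psi, axes=([1], [wire]))
    return jnp.moveaxis(out, 0, wire)


def apply_cnot(psi, control, target):
    """CNOT: flip the target axis on the slice where the control qubit is 1."""
    psi0 = jnp.take(psi, 0, axis=control)
    psi1 = jnp.take(psi, 1, axis=control)
    t = target if target < control else target - 1  # target axis after dropping control
    psi1 = jnp.flip(psi1, axis=t)
    return jnp.stack([psi0, psi1], axis=control)


def quantum_layer(angles, weights):
    """angles: (n,) embedding angles; weights: (depth, n) trainable RY angles.
    Returns the Pauli-Z expectation value of every wire, shape (n,)."""
    n = angles.shape[0]
    psi = jnp.zeros((2,) * n, dtype=jnp.complex64).at[(0,) * n].set(1.0)
    for w in range(n):  # embedding: H then RY(angle) on each wire
        psi = apply_1q(psi, H, w)
        psi = apply_1q(psi, ry(angles[w]), w)
    for layer in range(weights.shape[0]):
        for w in range(0, n - 1, 2):  # entangle even pairs
            psi = apply_cnot(psi, w, w + 1)
        for w in range(1, n - 1, 2):  # then odd pairs
            psi = apply_cnot(psi, w, w + 1)
        for w in range(n):  # trainable rotations
            psi = apply_1q(psi, ry(weights[layer, w]), w)
    probs = jnp.abs(psi) ** 2
    z_sign = jnp.array([1.0, -1.0])  # Z eigenvalues for |0> and |1>
    return jnp.stack(
        [jnp.moveaxis(probs, w, 0).reshape(2, -1).sum(axis=1) @ z_sign for w in range(n)]
    )


def init_params(key, d_in, n_labels, n_qubits=N_QUBITS, depth=Q_DEPTH):
    """Small random weights for the pre-net, circuit, and post-net (hypothetical init)."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w_in": jax.random.normal(k1, (d_in, n_qubits)) * 0.01,
        "b_in": jnp.zeros(n_qubits),
        "q_weights": jax.random.normal(k2, (depth, n_qubits)) * 0.01,
        "w_out": jax.random.normal(k3, (n_qubits, n_labels)) * 0.01,
        "b_out": jnp.zeros(n_labels),
    }


def dressed_circuit(params, features):
    """Backbone features (d_in,) -> per-label logits (n_labels,)."""
    angles = jnp.tanh(features @ params["w_in"] + params["b_in"]) * (jnp.pi / 2)
    z = quantum_layer(angles, params["q_weights"])
    return z @ params["w_out"] + params["b_out"]


# Usage: 16 hypothetical 512-d backbone embeddings, 8 disease labels.
params = init_params(jax.random.PRNGKey(0), d_in=512, n_labels=8)
batched = jax.jit(jax.vmap(dressed_circuit, in_axes=(None, 0)))
logits = batched(params, jnp.ones((16, 512)))  # shape (16, 8)
```

Because the circuit is written in plain `jax.numpy`, `jax.jit` compiles the unrolled gate sequence and `jax.vmap` batches it over images. The statevector holds 2^n amplitudes, so memory grows exponentially with the qubit count. For multi-label training, the logits would typically feed a sigmoid binary cross-entropy loss such as `optax.sigmoid_binary_cross_entropy`.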
Abstract
Quantum machine learning (QML) has the potential to improve the multi-label classification of rare, albeit critical, diseases in large-scale chest X-ray (CXR) datasets because of its theoretical quantum advantages over classical machine learning (CML) in sample efficiency and generalizability. While prior literature has explored QML with CXRs, it has focused on binary classification tasks with small datasets owing to limited access to quantum hardware and computationally expensive simulations. To address this, we implemented a Jax-based framework that enables the simulation of medium-sized qubit architectures with significant improvements in wall-clock time over current software offerings. We evaluated the framework's computational efficiency and classification performance for hybrid quantum transfer learning in long-tailed classification across 8, 14, and 19 disease labels using large-scale CXR datasets. The Jax-based framework achieved speed-ups of up to 58% and 95% compared to PyTorch and TensorFlow implementations, respectively. However, compared to CML, QML converged more slowly and achieved an average AUROC of 0.70, 0.73, and 0.74 for the classification of 8, 14, and 19 CXR disease labels, respectively, whereas the CML models achieved an average AUROC of 0.77, 0.78, and 0.80. In conclusion, our work presents an accessible implementation of hybrid quantum transfer learning for long-tailed CXR classification with a computationally efficient Jax-based framework.
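The abstract reports an average AUROC for each label set. Assuming this means the unweighted (macro) mean of per-label AUROCs, which is what scikit-learn's `roc_auc_score(..., average="macro")` computes, a minimal Jax sketch follows. Each label's AUROC is the probability that a randomly chosen positive image scores higher than a randomly chosen negative one, with ties counted as one half. The pairwise formulation uses O(N²) memory and is meant only for illustration; the function names and toy scores are hypothetical.

```python
import jax
import jax.numpy as jnp


def binary_auroc(scores, labels):
    """AUROC for one label: P(positive score > negative score), ties count 1/2."""
    labels = labels.astype(bool)
    diff = scores[:, None] - scores[None, :]  # diff[i, j] = s_i - s_j
    wins = jnp.where(diff > 0, 1.0, jnp.where(diff == 0, 0.5, 0.0))
    pairs = labels[:, None] & ~labels[None, :]  # i positive, j negative
    # Yields NaN if the label has no positives or no negatives in the split.
    return jnp.sum(wins * pairs) / jnp.sum(pairs)


def macro_auroc(scores, labels):
    """scores, labels: (n_samples, n_labels). Returns (mean AUROC, per-label AUROCs)."""
    per_label = jax.vmap(binary_auroc, in_axes=(1, 1))(scores, labels)
    return per_label.mean(), per_label


# Toy example with 4 images and 2 hypothetical labels.
scores = jnp.array([[0.1, 0.9], [0.4, 0.2], [0.35, 0.7], [0.8, 0.1]])
labels = jnp.array([[0, 1], [0, 0], [1, 1], [1, 0]])
mean_auc, per_label = macro_auroc(scores, labels)  # per_label is [0.75, 1.0], mean_auc is 0.875
```

In a long-tailed setting, rare labels may have no positives in a small evaluation split, so the NaN case above is worth checking before averaging.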
