
Expanding the Horizon: Enabling Hybrid Quantum Transfer Learning for Long-Tailed Chest X-Ray Classification

Skylar Chan, Pranav Kulkarni, Paul H. Yi, Vishwa S. Parekh

TL;DR

This work tackles long-tailed, multi-label chest X-ray classification by exploring hybrid quantum transfer learning, in which a Dressed Quantum Circuit (DQC) is combined with a classical backbone. It introduces an open-source, Jax-based framework to efficiently simulate medium-sized qubit architectures on workstation hardware and evaluates scalability across 8, 14, and 19 disease labels using the NIH-CXR-LT and MIMIC-CXR-LT datasets. The quantum workflow runs substantially faster in wall-clock time than PyTorch and TensorFlow implementations. Under the current settings, however, the DQC converges more slowly and reaches slightly lower AUROC than the classical deep learning (CDL) baseline on both internal (NIH) and external (MIMIC) test data, with a smaller gap on the external set. The work demonstrates the practicality of quantum machine learning (QML) for medical imaging while identifying the hyperparameter and architectural improvements needed to close the performance gap and enable broader clinical impact.
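As a rough illustration of the DQC structure described here (and in Figure 2), the following pure-JAX sketch simulates a small statevector: a classical linear layer maps backbone features to encoding angles, each qubit receives a Hadamard followed by an RY angle-encoding rotation, variational layers alternate a CNOT chain with trainable RY rotations, and Pauli-Z expectation values feed a linear output layer with one logit per disease label. This is not the authors' framework. The qubit count (`N_QUBITS = 4`), depth (`N_LAYERS = 2`), the `tanh(x) * pi/2` angle scaling, the use of RY rotations, the initialization, and all function names are assumptions made for illustration. The 2048-dimensional input used below matches the pooled feature size of ResNet50.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration; not the paper's configuration.
N_QUBITS = 4
N_LAYERS = 2

HADAMARD = jnp.array([[1.0, 1.0], [1.0, -1.0]]) / jnp.sqrt(2.0)


def ry(theta):
    """Single-qubit Y rotation RY(theta) as a real 2x2 matrix."""
    c, s = jnp.cos(theta / 2), jnp.sin(theta / 2)
    return jnp.stack([jnp.stack([c, -s]), jnp.stack([s, c])])


def apply_single(state, gate, wire):
    """Apply a 2x2 gate to one qubit of a statevector shaped (2,) * n."""
    state = jnp.tensordot(gate, state, axes=([1], [wire]))
    return jnp.moveaxis(state, 0, wire)


def apply_cnot(state, control, target):
    """CNOT: flip the target qubit only in the control=|1> branch."""
    s0 = jnp.take(state, 0, axis=control)
    s1 = jnp.take(state, 1, axis=control)
    t = target - 1 if target > control else target  # axis shift after dropping control
    s1 = jnp.flip(s1, axis=t)  # swapping |0>/|1> amplitudes is an X gate
    return jnp.stack([s0, s1], axis=control)


def expval_z(state, wire):
    """<Z> on one qubit: P(0) - P(1) from the marginal distribution."""
    probs = jnp.abs(state) ** 2
    others = tuple(a for a in range(state.ndim) if a != wire)
    marginal = jnp.sum(probs, axis=others)
    return marginal[0] - marginal[1]


def quantum_circuit(angles, weights):
    """Hadamard + RY angle encoding, then (CNOT chain, RY) variational layers."""
    n = angles.shape[0]
    state = jnp.zeros((2,) * n).at[(0,) * n].set(1.0)  # start in |0...0>
    for w in range(n):
        state = apply_single(state, HADAMARD, w)
        state = apply_single(state, ry(angles[w]), w)
    for layer in range(weights.shape[0]):
        for w in range(n - 1):
            state = apply_cnot(state, w, w + 1)
        for w in range(n):
            state = apply_single(state, ry(weights[layer, w]), w)
    return jnp.stack([expval_z(state, w) for w in range(n)])


def init_params(key, feature_dim, n_labels, n_qubits=N_QUBITS, n_layers=N_LAYERS):
    """Hypothetical initialization for the dressed (pre/quantum/post) parameters."""
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        "w_pre": jax.random.normal(k1, (n_qubits, feature_dim)) / jnp.sqrt(feature_dim),
        "b_pre": jnp.zeros(n_qubits),
        "q_weights": 0.01 * jax.random.normal(k2, (n_layers, n_qubits)),
        "w_post": jax.random.normal(k3, (n_labels, n_qubits)) / jnp.sqrt(n_qubits),
        "b_post": jnp.zeros(n_labels),
    }


def dqc_forward(params, features):
    """Classical pre-layer -> quantum circuit -> classical post-layer (one sample)."""
    # tanh * pi/2 keeps encoding angles bounded (assumed scaling)
    angles = jnp.tanh(params["w_pre"] @ features + params["b_pre"]) * (jnp.pi / 2)
    q_out = quantum_circuit(angles, params["q_weights"])
    return params["w_post"] @ q_out + params["b_post"]  # one logit per disease label


# Vectorize over the batch axis of the features; parameters are shared.
batched_forward = jax.jit(jax.vmap(dqc_forward, in_axes=(None, 0)))


def bce_loss(params, features, labels):
    """Multi-label loss: independent sigmoid cross-entropy per label."""
    logits = batched_forward(params, features)
    return -jnp.mean(labels * jax.nn.log_sigmoid(logits)
                     + (1 - labels) * jax.nn.log_sigmoid(-logits))


grad_fn = jax.jit(jax.grad(bce_loss))
```

For example, `init_params(jax.random.PRNGKey(0), feature_dim=2048, n_labels=19)` followed by `batched_forward(params, feats)` on an `(8, 2048)` feature batch returns an `(8, 19)` logit array. A dense statevector keeps the code short, but its memory grows as 2^n in the number of qubits, so this approach only scales to modest circuit sizes.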

Abstract

Quantum machine learning (QML) has the potential to improve the multi-label classification of rare, albeit critical, diseases in large-scale chest X-ray (CXR) datasets due to theoretical quantum advantages over classical machine learning (CML) in sample efficiency and generalizability. While prior literature has explored QML with CXRs, it has focused on binary classification tasks with small datasets due to limited access to quantum hardware and computationally expensive simulations. To address this, we implemented a Jax-based framework that enables the simulation of medium-sized qubit architectures with significant improvements in wall-clock time over current software offerings. We evaluated our Jax-based framework in terms of both computational efficiency and classification performance for hybrid quantum transfer learning for long-tailed classification across 8, 14, and 19 disease labels using large-scale CXR datasets. The Jax-based framework achieved speed-ups of up to 58% and 95% compared to PyTorch and TensorFlow implementations, respectively. However, compared to CML, QML demonstrated slower convergence and an average AUROC of 0.70, 0.73, and 0.74 for the classification of 8, 14, and 19 CXR disease labels, respectively. In comparison, the CML models had an average AUROC of 0.77, 0.78, and 0.80, respectively. In conclusion, our work presents an accessible implementation of hybrid quantum transfer learning for long-tailed CXR classification with a computationally efficient Jax-based framework.
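The abstract reports a single average AUROC per task. Assuming "average" means the unweighted (macro) mean of per-label AUROCs, the NumPy sketch below computes it with the Mann-Whitney rank formulation, giving tied scores their average rank. Skipping labels that lack positives or negatives in a split is an assumption about how undefined AUROCs would be handled, and the function names are hypothetical.

```python
import numpy as np


def average_ranks(x):
    """1-based ranks of x, with tied values sharing their average rank."""
    x = np.asarray(x, dtype=float)
    order = np.argsort(x, kind="mergesort")
    sorted_x = x[order]
    ranks = np.empty(len(x), dtype=float)
    _, starts, counts = np.unique(sorted_x, return_index=True, return_counts=True)
    for start, count in zip(starts, counts):
        # sorted positions start..start+count-1 hold ranks start+1..start+count
        ranks[order[start:start + count]] = start + (count + 1) / 2.0
    return ranks


def auroc(y_true, y_score):
    """Binary AUROC via the Mann-Whitney U statistic."""
    y_true = np.asarray(y_true).astype(bool)
    n_pos, n_neg = y_true.sum(), (~y_true).sum()
    if n_pos == 0 or n_neg == 0:
        return np.nan  # undefined when only one class is present
    ranks = average_ranks(y_score)
    u = ranks[y_true].sum() - n_pos * (n_pos + 1) / 2.0
    return u / (n_pos * n_neg)


def mean_auroc(y_true, y_score):
    """Macro mean of per-label AUROCs for (n_samples, n_labels) arrays."""
    y_true, y_score = np.asarray(y_true), np.asarray(y_score)
    per_label = np.array([auroc(y_true[:, j], y_score[:, j])
                          for j in range(y_true.shape[1])])
    return np.nanmean(per_label), per_label  # NaN labels are skipped
```

When every label has both classes present, scikit-learn's `roc_auc_score` with `average="macro"` computes the same macro mean.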

Paper Structure

This paper contains 24 sections, 3 equations, 13 figures, 3 tables.

Figures (13)

  • Figure 1: Classical deep learning model. Image features are extracted with ResNet50 and preprocessed with a linear layer before obtaining predictions.
  • Figure 2: Dressed quantum circuit model. Image features are extracted with ResNet50, reduced by a linear preprocessing layer to the size of the quantum circuit input, and embedded into the circuit with angle encoding, which is applied after a Hadamard gate places each qubit in an equal (50/50) superposition of $\lvert {0} \rangle$ and $\lvert {1} \rangle$. Variational parameters and CNOT gates are then applied, and the measurements are fed into a classical postprocessing layer to obtain predictions. Preprocessing layers are highlighted in blue, variational quantum parameters in yellow, CNOT gates in white, and postprocessing layers in orange.
  • Figure 3: Dataloader implemented for experiments. Preprocessed images are serialized, compressed, and cached on disk ahead of time. At runtime, they are retrieved from the cache, decompressed, and deserialized. Image augmentations occur at runtime. A SQLite database tracks the file paths of the cached images.
  • Figure 4: Training and validation loss between CDL and DQC models across CXR-8, CXR-14, and CXR-19 multi-label classification tasks.
  • Figure 5: Mean AUROC between CDL and DQC models on NIH-CXR-LT and MIMIC-CXR-LT test sets across CXR-8, CXR-14, and CXR-19 classification tasks. (ns: $p>0.05$, *: $p<0.05$, **: $p<0.01$, ***: $p<0.001$).
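A minimal sketch of the caching pattern described for Figure 3, assuming preprocessed images are NumPy arrays serialized to `.npy` bytes with `np.save`, compressed with zlib, written to disk, and indexed by a single-table SQLite database that maps a (hypothetical) image ID to its cache file path. The class name, table schema, file suffix, and the intensity-jitter augmentation are illustrative assumptions, not the authors' implementation.

```python
import io
import sqlite3
import zlib
from pathlib import Path

import numpy as np


class CompressedImageCache:
    """Hypothetical on-disk cache: serialize -> compress -> write, indexed in SQLite."""

    def __init__(self, cache_dir, db_path):
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self.conn = sqlite3.connect(db_path)
        self.conn.execute(
            "CREATE TABLE IF NOT EXISTS cache (image_id TEXT PRIMARY KEY, path TEXT NOT NULL)"
        )
        self.conn.commit()

    def put(self, image_id, image):
        """Ahead of time: serialize to .npy bytes, compress, and record the path."""
        buf = io.BytesIO()
        np.save(buf, image, allow_pickle=False)
        blob = zlib.compress(buf.getvalue(), level=6)
        path = self.cache_dir / f"{image_id}.npy.z"
        path.write_bytes(blob)
        self.conn.execute(
            "INSERT OR REPLACE INTO cache (image_id, path) VALUES (?, ?)",
            (image_id, str(path)),
        )
        self.conn.commit()

    def get(self, image_id):
        """At runtime: look up the path, read, decompress, and deserialize."""
        row = self.conn.execute(
            "SELECT path FROM cache WHERE image_id = ?", (image_id,)
        ).fetchone()
        if row is None:
            raise KeyError(image_id)
        raw = zlib.decompress(Path(row[0]).read_bytes())
        return np.load(io.BytesIO(raw), allow_pickle=False)

    def close(self):
        self.conn.close()


def augment(image, rng):
    """Illustrative runtime augmentation: random intensity scaling, clipped to [0, 1]."""
    scale = rng.uniform(0.9, 1.1)
    return np.clip(image * scale, 0.0, 1.0).astype(image.dtype)
```

In this arrangement, `put` runs once per image during preprocessing, while `get` followed by `augment` runs inside the per-batch loader, so decoding and resizing costs are paid only once and augmentations still vary between epochs.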