Multi-Dataset Multi-Task Learning for COVID-19 Prognosis
Filippo Ruffini, Lorenzo Tronchin, Zhuoru Wu, Wenting Chen, Paolo Soda, Linlin Shen, Valerio Guarrasi
TL;DR
The paper tackles COVID-19 prognosis prediction from chest radiographs when labeled data are scarce. It introduces a multi-dataset multi-task (MDMT) framework that learns jointly from two public datasets with distinct labeling schemes, using a shared backbone and task-specific heads trained with an indicator-based loss. Evaluated across 18 CNN backbones under several validation strategies, the approach shows significant improvements over single-task and transfer-learning baselines, particularly for prognosis classification, and demonstrates enhanced robustness across clinical settings. The work highlights the value of multi-source data integration for improving the generalization of imaging-based prognostic models and points toward future extensions with explainable AI and multi-modal data fusion.
Abstract
In the fight against the COVID-19 pandemic, leveraging artificial intelligence to predict disease outcomes from chest radiographic images represents a significant scientific aim. The challenge, however, lies in the scarcity of labeled datasets with compatible tasks that are large enough to train deep learning models without overfitting. To address this issue, we introduce a novel multi-dataset multi-task training framework that predicts COVID-19 prognostic outcomes from chest X-rays (CXR) by integrating correlated datasets from disparate sources. This departs from conventional multi-task learning approaches, which rely on datasets with multiple, correlated labeling schemes. Our framework hypothesizes that assessing severity scores enhances the model's ability to classify prognostic severity groups, thereby improving its robustness and predictive power. The proposed architecture comprises a deep convolutional network that receives inputs from two publicly available CXR datasets, AIforCOVID for severity prognostic prediction and BRIXIA for severity score assessment, and branches into task-specific fully connected output networks. Moreover, we propose a multi-task loss function that incorporates an indicator function to exploit multi-dataset integration. The effectiveness and robustness of the proposed approach are demonstrated through significant performance improvements in prognosis classification across 18 different convolutional neural network backbones and under different evaluation strategies. These improvements hold over single-task baselines and standard transfer learning strategies and are supported by extensive statistical analysis, indicating strong potential for practical application.
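To make the indicator-based loss described in the abstract more concrete, the following is a minimal sketch and not the authors' implementation. The abstract does not specify the loss terms, so the sketch assumes, purely for illustration, a binary prognosis label scored with cross-entropy, a scalar severity score scored with squared error, and equal task weights. The tags `AIFORCOVID` and `BRIXIA`, the dictionary keys, and all function names are hypothetical.

```python
import math

# Hypothetical dataset tags (illustrative names, not from the paper).
AIFORCOVID = "aiforcovid"  # samples carrying a prognosis class label
BRIXIA = "brixia"          # samples carrying a severity score label


def binary_cross_entropy(prob, label, eps=1e-12):
    """Cross-entropy for one sample; prob is the predicted P(label == 1)."""
    prob = min(max(prob, eps), 1.0 - eps)  # clamp to avoid log(0)
    return -(label * math.log(prob) + (1 - label) * math.log(1.0 - prob))


def squared_error(pred, target):
    """Squared error for one sample (assumed regression loss)."""
    return (pred - target) ** 2


def indicator_multitask_loss(batch, weight_prognosis=1.0, weight_score=1.0):
    """Mean loss over a mixed batch drawn from both datasets.

    Each sample is a dict with keys "source", "prognosis_prob",
    "score_pred", and "target". The branch on "source" plays the role
    of the indicator function: only the head matching the sample's
    dataset contributes, so the other head receives no error signal
    from that sample.
    """
    if not batch:
        raise ValueError("batch must contain at least one sample")
    total = 0.0
    for sample in batch:
        if sample["source"] == AIFORCOVID:
            total += weight_prognosis * binary_cross_entropy(
                sample["prognosis_prob"], sample["target"]
            )
        elif sample["source"] == BRIXIA:
            total += weight_score * squared_error(
                sample["score_pred"], sample["target"]
            )
        else:
            raise ValueError(f"unknown source: {sample['source']!r}")
    return total / len(batch)


# Usage: one sample from each dataset in the same batch.
example_batch = [
    {"source": AIFORCOVID, "prognosis_prob": 0.5, "score_pred": 0.0, "target": 1},
    {"source": BRIXIA, "prognosis_prob": 0.0, "score_pred": 2.0, "target": 3.0},
]
example_loss = indicator_multitask_loss(example_batch)
```

In a real training loop the two predictions would come from the task-specific heads on top of the shared backbone, and the per-sample selection would typically be applied as a mask over batched tensors rather than a Python loop.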
