CSAI: Conditional Self-Attention Imputation for Healthcare Time-series
Linglong Qian, Joseph Arul Raj, Hugh Logan Ellis, Ao Zhang, Yuezhou Zhang, Tao Wang, Richard JB Dobson, Zina Ibrahim
TL;DR
CSAI addresses missing data in healthcare time series by extending BRITS with three mechanisms: an attention-based conditional hidden-state initialisation that captures long-range dynamics, a domain-informed temporal decay aligned with feature-specific clinical recording patterns, and a non-uniform masking strategy that reflects non-random missingness. Across four EHR benchmark datasets, it achieves better imputation accuracy and downstream predictive performance than state-of-the-art architectures, and ablations confirm each component's contribution. CSAI is integrated into PyPOTS, an open-source Python toolbox for partially observed time series, which makes it a practical option for imputation and prognosis on clinical data, with potential benefits for tasks such as sepsis detection and outcome prediction.
Abstract
We introduce the Conditional Self-Attention Imputation (CSAI) model, a novel recurrent neural network architecture designed to address the challenges of complex missing data patterns in multivariate time series derived from hospital electronic health records (EHRs). CSAI extends state-of-the-art neural network-based imputation by introducing key modifications specific to EHR data: a) attention-based hidden state initialisation to capture both long- and short-range temporal dependencies prevalent in EHRs, b) domain-informed temporal decay to mimic clinical data recording patterns, and c) a non-uniform masking strategy that models non-random missingness by calibrating weights according to both temporal and cross-sectional data characteristics. Comprehensive evaluation across four EHR benchmark datasets demonstrates CSAI's effectiveness compared to state-of-the-art architectures in data restoration and downstream tasks. CSAI is integrated into PyPOTS, an open-source Python toolbox designed for machine learning tasks on partially observed time series. This work significantly advances the state of neural network imputation applied to EHRs by more closely aligning algorithmic imputation with clinical realities.
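To make the notion of temporal decay concrete, the following is a minimal sketch of the standard BRITS-style formulation that CSAI builds on: a per-feature time gap that accumulates while a feature stays unobserved, and a decay factor of the form exp(-max(0, w * delta + b)). It is not the CSAI implementation, and it does not reproduce the domain-informed variant described above. The function names `time_gaps` and `temporal_decay` are hypothetical, and `w` and `b` are fixed here for illustration, whereas in a trained model they would be learned parameters.

```python
import numpy as np


def time_gaps(mask, timestamps):
    """Time elapsed since each feature was last observed.

    mask:       (T, D) array, 1 = observed, 0 = missing
    timestamps: (T,) array of recording times
    """
    T, D = mask.shape
    delta = np.zeros((T, D), dtype=float)
    for t in range(1, T):
        step = timestamps[t] - timestamps[t - 1]
        # If the previous value was observed, the gap resets to `step`;
        # otherwise the previous gap carries over and keeps accumulating.
        delta[t] = step + (1 - mask[t - 1]) * delta[t - 1]
    return delta


def temporal_decay(delta, w, b):
    """Decay factor in (0, 1]; larger gaps give smaller factors when w > 0."""
    return np.exp(-np.maximum(0.0, w * delta + b))


# Toy example: 4 time steps, 2 features, irregular sampling.
mask = np.array([[1, 1],
                 [0, 1],
                 [0, 0],
                 [1, 1]])
timestamps = np.array([0.0, 1.0, 2.0, 4.0])

delta = time_gaps(mask, timestamps)
# Illustrative fixed parameters (assumed values, not learned).
gamma = temporal_decay(delta, w=np.array([0.5, 0.5]), b=np.zeros(2))
```

In this toy example the first feature is unobserved at steps 1 and 2, so its gap grows to 4.0 by the last step and its decay factor shrinks accordingly, while the second feature, observed more often, retains a larger factor.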
