Bayesian Meta-Learning for Improving Generalizability of Health Prediction Models With Similar Causal Mechanisms
Sophie Wharrie, Lisa Eick, Lotta Mäkinen, Andrea Ganna, Samuel Kaski, FinnGen
TL;DR
This work addresses the challenge of generalizing health prediction models trained on related tasks with non-i.i.d. data. It introduces a Bayesian meta-learning framework that explicitly models the similarity between the tasks' causal mechanisms. The core idea is to learn a global prior over model parameters together with task-specific refinements, guided by a task similarity matrix derived from causal analyses (DAG-based, Mendelian randomization, ICP, or chi-square). The approach aims to mitigate negative transfer during meta-training and to improve adaptation to new patient populations. It is demonstrated in a stroke prediction case study using UK Biobank and FinnGen data, where it achieves substantial generalizability gains over baselines in several tasks. The results show that task-similarity-aware meta-learning can reduce shortcut learning and better capture intra-task variability, enabling more robust predictions across populations, and they suggest broad applicability to other health prediction problems.
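To make the idea of a DAG-based task similarity matrix concrete, the following is a minimal sketch under loudly stated assumptions: each task's estimated causal DAG is represented as a set of directed (cause, effect) edges, and the similarity between two tasks is taken to be the Jaccard overlap of their edge sets. The edge-set representation, the Jaccard measure, and the names `jaccard` and `dag_similarity_matrix` are hypothetical illustrations, not the similarity measure defined in the paper, and the toy variables are made up rather than taken from the study data.

```python
from itertools import combinations

# Hypothetical representation (an assumption, not the paper's definition):
# each task's estimated causal DAG is a set of directed (cause, effect) edges.
# Requires Python 3.9+ for the built-in generic type hints.
Edge = tuple[str, str]


def jaccard(a: set[Edge], b: set[Edge]) -> float:
    """Edge-set overlap len(a & b) / len(a | b); taken as 1.0 when both sets are empty."""
    union = a | b
    return 1.0 if not union else len(a & b) / len(union)


def dag_similarity_matrix(dags: dict[str, set[Edge]]) -> tuple[list[str], list[list[float]]]:
    """Return sorted task names and a symmetric task-by-task similarity matrix."""
    names = sorted(dags)
    n = len(names)
    sim = [[1.0] * n for _ in range(n)]  # diagonal: each task is identical to itself
    for i, j in combinations(range(n), 2):
        # Fill both triangles so the matrix stays symmetric.
        sim[i][j] = sim[j][i] = jaccard(dags[names[i]], dags[names[j]])
    return names, sim


# Toy example with made-up variables (not data from the study).
dags = {
    "task_a": {("smoking", "stroke"), ("blood_pressure", "stroke")},
    "task_b": {("smoking", "stroke"), ("blood_pressure", "stroke"), ("age", "blood_pressure")},
    "task_c": {("diabetes", "stroke")},
}
names, S = dag_similarity_matrix(dags)
for name, row in zip(names, S):
    print(name, [round(s, 2) for s in row])
```

In this toy case, `task_a` and `task_b` share two of three distinct edges and so receive a high similarity, while `task_c` shares no edges with either and receives zero; a meta-learner could then weight information sharing between tasks by these entries.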
Abstract
Machine learning strategies like multi-task learning, meta-learning, and transfer learning enable efficient adaptation of machine learning models to specific applications in healthcare, such as prediction of various diseases, by leveraging generalizable knowledge across large datasets and multiple domains. In particular, Bayesian meta-learning methods pool data across related prediction tasks to learn prior distributions for model parameters, which are then used to derive models for specific tasks. However, inter- and intra-task variability due to disease heterogeneity and other patient-level differences poses two challenges: negative transfer during shared learning and poor generalizability to new patients. We introduce a novel Bayesian meta-learning approach that aims to address this in two key settings: (1) predictions for new patients (from the same population as the training set) and (2) adaptation to new patient populations. Our main contribution is modeling the similarity between the causal mechanisms of the tasks, which is used (1) to mitigate negative transfer during training and (2) for fine-tuning that pools information from the tasks expected to aid generalizability. We propose an algorithm for implementing this approach for Bayesian deep learning and apply it to a case study of stroke prediction tasks using electronic health record data. Experiments using the UK Biobank dataset as the training population demonstrated significant generalizability improvements over standard meta-learning, non-causal task similarity measures, and local baselines (separate models for each task). These improvements were assessed on a variety of tasks covering both new patients from the training population (UK Biobank) and a new population (FinnGen).
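In Bayesian meta-learning, knowledge from related tasks reaches a new task through its prior. The toy example here illustrates one simple way task similarities could steer that pooling: the prior mean for the new task is a similarity-weighted average of the source tasks' parameter estimates, and the new task's own data then update that prior by a standard conjugate Gaussian update. This is a hypothetical sketch, not the authors' algorithm (which targets Bayesian deep learning); the one-parameter linear-Gaussian model, the weighting rule, the function names, and all numbers are assumptions chosen for illustration.

```python
# Minimal sketch (hypothetical, not the paper's algorithm): pool information
# into a new task's Gaussian prior using task-similarity weights, then apply a
# conjugate Bayesian update for a one-parameter linear model y = w * x + noise.


def similarity_weighted_prior_mean(source_means: list[float], similarities: list[float]) -> float:
    """Similarity-weighted average of the source tasks' parameter estimates."""
    total = sum(similarities)
    if total <= 0:
        raise ValueError("at least one source task needs positive similarity")
    return sum(s * m for s, m in zip(similarities, source_means)) / total


def gaussian_posterior(prior_mean: float, prior_var: float,
                       xs: list[float], ys: list[float], noise_var: float) -> tuple[float, float]:
    """Posterior N(mean, var) of w under prior w ~ N(prior_mean, prior_var)
    and likelihood y_i ~ N(w * x_i, noise_var) with known noise variance."""
    precision = 1.0 / prior_var + sum(x * x for x in xs) / noise_var
    mean = (prior_mean / prior_var + sum(x * y for x, y in zip(xs, ys)) / noise_var) / precision
    return mean, 1.0 / precision


# Toy numbers: two source tasks, the first judged more similar to the new task.
m0 = similarity_weighted_prior_mean(source_means=[2.0, 0.0], similarities=[0.75, 0.25])
post_mean, post_var = gaussian_posterior(m0, prior_var=1.0, xs=[1.0, 2.0], ys=[2.0, 4.0], noise_var=1.0)
print(m0, post_mean, post_var)
```

Here the more similar source task pulls the prior mean to 1.5, and the new task's two observations move the posterior mean from 1.5 toward their least-squares slope of 2 (to 11.5/6, about 1.92) while shrinking the variance from 1 to 1/6. A dissimilar task receives little weight, which is the intuition behind using similarity to limit negative transfer.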
