
Bayesian Meta-Learning for Improving Generalizability of Health Prediction Models With Similar Causal Mechanisms

Sophie Wharrie, Lisa Eick, Lotta Mäkinen, Andrea Ganna, Samuel Kaski, FinnGen

TL;DR

This work addresses the challenge of generalizing health prediction models trained on related tasks with non-i.i.d. data by introducing a Bayesian meta-learning framework that explicitly models similarity between the causal mechanisms of tasks. The core idea is a hierarchical model: a global prior over model parameters is shared across all tasks, additional parameters are shared among related tasks, and task-specific models are derived from both. Relatedness is given by a task similarity matrix computed with one of four analyses: DAG-based similarity, Mendelian Randomization (MR), Invariant Causal Prediction (ICP), or a simpler chi-square independence test. The approach aims to mitigate negative transfer during meta-training and to improve adaptation to new patient populations. It is demonstrated in a stroke prediction case study using UK Biobank and FinnGen data, with substantial generalizability gains over baselines in several tasks. The results show that task-similarity-aware meta-learning can reduce shortcut learning and better capture intra-task variability, yielding more robust predictions across populations, and they suggest the approach may apply broadly to other health prediction problems.
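The role of the similarity matrix in the hierarchy can be illustrated with a small sketch. It assumes, purely for illustration, an additive decomposition in which each task's prior mean is the global parameter vector plus a related-task term, and that the related-task term is a similarity-weighted average of the other tasks' offsets from the global parameters. The function name `task_prior_means`, the `deltas` input, and the additive form are hypothetical and not taken from the paper, whose actual model is a Bayesian deep learning hierarchy.

```python
import numpy as np


def task_prior_means(theta: np.ndarray, deltas: np.ndarray, w: np.ndarray) -> np.ndarray:
    """Hypothetical prior means theta + gamma_t for each task t.

    theta:  (D,) global parameters shared across all tasks.
    deltas: (T, D) hypothetical per-task offsets from theta (e.g. current fits minus theta).
    w:      (T, T) non-negative task-similarity weights; w[t, s] = similarity of task s to task t.
    """
    w = np.array(w, dtype=float)          # copy, so the caller's matrix is not modified
    np.fill_diagonal(w, 0.0)              # a task does not pool its own offset
    row_sums = w.sum(axis=1, keepdims=True)
    # Row-normalize; tasks with no similar tasks get all-zero weights (gamma_t = 0).
    norm_w = np.divide(w, row_sums, out=np.zeros_like(w), where=row_sums > 0)
    gamma = norm_w @ deltas               # (T, D) related-task terms
    return theta[None, :] + gamma


# Toy example with synthetic numbers: tasks 0 and 1 are similar, task 2 is unrelated.
theta = np.array([0.5, -1.0])
deltas = np.array([[1.0, 0.0], [0.0, 1.0], [3.0, 3.0]])
w = np.array([[1, 1, 0], [1, 1, 0], [0, 0, 1]])
means = task_prior_means(theta, deltas, w)
```

Here `means` equals `[[0.5, 0.0], [1.5, -1.0], [0.5, -1.0]]`: tasks 0 and 1 borrow each other's offsets, while task 2, having no similar tasks, falls back to `theta`. This is one simple way a similarity matrix could keep unrelated tasks from influencing each other.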

Abstract

Machine learning strategies like multi-task learning, meta-learning, and transfer learning enable efficient adaptation of machine learning models to specific applications in healthcare, such as prediction of various diseases, by leveraging generalizable knowledge across large datasets and multiple domains. In particular, Bayesian meta-learning methods pool data across related prediction tasks to learn prior distributions for model parameters, which are then used to derive models for specific tasks. However, inter- and intra-task variability due to disease heterogeneity and other patient-level differences leads to two challenges: negative transfer during shared learning and poor generalizability to new patients. We introduce a novel Bayesian meta-learning approach that aims to address these challenges in two key settings: (1) predictions for new patients (from the same population as the training set) and (2) adaptation to new patient populations. Our main contribution is modeling similarity between the causal mechanisms of the tasks, both for (1) mitigating negative transfer during training and (2) fine-tuning that pools information from tasks that are expected to aid generalizability. We propose an algorithm for implementing this approach for Bayesian deep learning and apply it to a case study of stroke prediction tasks using electronic health record data. Experiments using the UK Biobank dataset as the training population demonstrated significant generalizability improvements compared to standard meta-learning, non-causal task similarity measures, and local baselines (separate models for each task). Generalizability was assessed on a variety of tasks covering both new patients from the training population (UK Biobank) and a new population (FinnGen).
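The basic mechanism described in the abstract, pooling related tasks to learn a prior and then deriving task-specific estimates from it, can be shown in its simplest form with empirical Bayes for one scalar parameter per task. This is a swapped-in textbook technique (conjugate Gaussian shrinkage with a method-of-moments prior), not the paper's Bayesian deep learning algorithm. The function name and the assumption that each task supplies an estimate `b` with a known sampling variance `s2` are hypothetical, and the numbers in the example are synthetic.

```python
import numpy as np


def empirical_bayes_shrinkage(b, s2):
    """Pool per-task estimates into a Gaussian prior N(mu, tau2), then compute posterior means.

    b:  per-task estimates of one scalar parameter, shape (T,), T >= 2.
    s2: their sampling variances, shape (T,), assumed known.
    Returns (mu, tau2, posterior_means).
    """
    b = np.asarray(b, dtype=float)
    s2 = np.asarray(s2, dtype=float)
    mu = b.mean()                                   # prior mean learned from all tasks
    # Method-of-moments prior variance: spread of the estimates minus average noise.
    tau2 = max(b.var(ddof=1) - s2.mean(), 0.0)
    if tau2 == 0.0:
        # No detectable between-task variation: every task collapses to the pooled mean.
        return mu, tau2, np.full_like(b, mu)
    # Conjugate Gaussian posterior mean: precision-weighted average of data and prior.
    post_mean = (b / s2 + mu / tau2) / (1.0 / s2 + 1.0 / tau2)
    return mu, tau2, post_mean


mu, tau2, post = empirical_bayes_shrinkage([1.0, 2.0, 3.0, 6.0], [1.0, 1.0, 1.0, 1.0])
```

Each posterior mean is pulled toward the pooled mean `mu` (here 3.0), more strongly when the between-task variance `tau2` is small relative to a task's noise. Because this pooling treats all tasks as exchangeable, it can also pull a task toward dissimilar tasks, which is roughly the kind of negative transfer the abstract describes.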

Paper Structure (17 sections, 5 equations, 5 figures, 6 tables, 1 algorithm)

Figures (5)

  • Figure 1: Problem setup and methods. Top left: General problem setup for task environment and data. Top right: Specific case study for stroke prediction. Bottom: Bayesian meta-learning with task similarity. A hierarchical model derives task-specific models ($\bm{\phi}_t$) from global parameters ($\bm{\theta}$) shared across all tasks and parameters ($\bm{\gamma}_t$) shared across related tasks, governed by task similarity weights ($\bm{w}$).
  • Figure 2: Task similarity is measured by mapping tasks to a space where similar tasks have small distances. Four methods are compared with different assumptions about the relationships between features, target variables, potential confounders, and additional variables that are assumed to aid identifiability of causal relationships: (1) Similarity in probabilistic graphical models, (2) Instrumental variables approach with Mendelian Randomization, (3) Invariant causal prediction across environments, and (4) Simple independence testing using the chi-square test.
  • Figure 3: Comparison of odds ratios for local baseline (blue), meta-learning baseline (red), and meta-learning with MR-based task similarity (green) for predicting different stroke types (G45, I60, I61, I62, I63, I64) with high confidence (above median entropy). The plots show odds ratios with 95% confidence intervals for various factors, which were selected based on their significance ($p < 0.001$) and positive association (odds ratio $> 1$) in at least one model.
  • Figure 4: Feature importance scores for different stroke types (G45, I60, I61, I62, I63, I64) calculated using the integrated gradients approach. Top row: Meta-learning baseline model. Bottom row: Meta-learning with task similarity, using the Mendelian Randomization (MR) approach. Features shown on the y-axis were ranked in the top 20 for at least one stroke type. The x-axis represents the importance score on a logarithmic scale.
  • Figure 5: Comparison of task similarity analyses using four different measures: Directed Acyclic Graph (DAG), Mendelian Randomization (MR), Invariant Causal Prediction (ICP), and Chi-square test (CHI2). Heatmaps on the left show pairwise task similarities (lighter colors indicate higher similarity) and graphs on the right visualize the relationships between tasks (edge inclusion defined as top 20% most similar tasks).
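The edge rule in Figure 5 (keep the top 20% most similar task pairs) can be sketched as follows, assuming a symmetric similarity matrix in which larger values mean more similar tasks. The function name, the rounding up of the edge count, and the arbitrary tie-breaking are assumptions, and the values in the example are synthetic, not the paper's results.

```python
import math

import numpy as np


def top_fraction_edges(sim, labels, fraction=0.2):
    """Return the most similar task pairs as (label_i, label_j, similarity), most similar first.

    sim:      (T, T) symmetric similarity matrix (larger = more similar).
    labels:   T task names.
    fraction: share of unordered task pairs kept as edges (0.2 = top 20%).
    """
    sim = np.asarray(sim, dtype=float)
    rows, cols = np.triu_indices(sim.shape[0], k=1)   # each unordered pair once, no self-pairs
    vals = sim[rows, cols]
    k = max(1, math.ceil(fraction * vals.size))       # assumption: round the edge count up
    order = np.argsort(vals)[::-1][:k]                # indices of the k largest similarities
    return [(labels[rows[i]], labels[cols[i]], float(vals[i])) for i in order]


# Synthetic 4-task example: 6 pairs, so ceil(0.2 * 6) = 2 edges are kept.
sim = [[1.0, 0.9, 0.1, 0.2],
       [0.9, 1.0, 0.3, 0.8],
       [0.1, 0.3, 1.0, 0.4],
       [0.2, 0.8, 0.4, 1.0]]
edges = top_fraction_edges(sim, ["a", "b", "c", "d"])
```

Here `edges` is `[("a", "b", 0.9), ("b", "d", 0.8)]`. With six tasks there are 15 pairs, so this rule keeps 3 edges.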