Federated Learning for Estimating Heterogeneous Treatment Effects

Disha Makhija, Joydeep Ghosh, Yejin Kim

TL;DR

This work proposes a novel framework for collaborative learning of HTE estimators across institutions via Federated Learning, and shows that even under a diversity of interventions and subject populations across clients, one can jointly learn a common feature representation while concurrently and privately learning the specific predictive functions for outcomes under distinct interventions across institutions.

Abstract

Machine learning methods for estimating heterogeneous treatment effects (HTE) facilitate large-scale personalized decision-making across domains such as healthcare, policy making, and education. Current machine learning approaches for HTE require access to substantial amounts of data per treatment, and the high costs associated with interventions make centrally collecting that much data for each intervention a formidable challenge. To overcome this obstacle, we propose a novel framework for collaborative learning of HTE estimators across institutions via Federated Learning. We show that even under a diversity of interventions and subject populations across clients, one can jointly learn a common feature representation while concurrently and privately learning the specific predictive functions for outcomes under distinct interventions across institutions. Our framework and the associated algorithm are based on this insight, and leverage tabular transformers to map multiple input data to feature representations, which are then used for outcome prediction via multi-task learning. We also propose a novel way of federated training of personalised transformers that can work with heterogeneous input feature spaces. Experimental results on real-world clinical trial data demonstrate the effectiveness of our method.
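
The split between a jointly learned representation and privately learned predictors can be illustrated with a minimal sketch. It is not the paper's algorithm: it assumes a FedAvg-style sample-weighted average for the shared parameters, and the function name `average_shared`, the keys `"encoder"` and `"predictor"`, and the toy values are all hypothetical. Only the shared entries are aggregated; each client's predictor stays local.

```python
def average_shared(client_params, shared_keys, num_samples):
    """Average only the shared parameters across clients (assumed FedAvg-style).

    client_params: one dict per client, mapping a name to a list of floats.
    shared_keys:   names of the parameters that are aggregated.
    num_samples:   per-client sample counts, used as averaging weights.
    """
    total = float(sum(num_samples))
    weights = [n / total for n in num_samples]

    averaged = {}
    for key in shared_keys:
        length = len(client_params[0][key])
        averaged[key] = [
            sum(w * params[key][i] for w, params in zip(weights, client_params))
            for i in range(length)
        ]

    # Every client receives its own copy of the averaged shared parameters;
    # entries not listed in shared_keys (the personalised predictor) are kept.
    return [
        {**params, **{k: list(v) for k, v in averaged.items()}}
        for params in client_params
    ]


# Toy example with two clients (all values are made up for illustration).
clients = [
    {"encoder": [1.0, 2.0], "predictor": [10.0]},
    {"encoder": [3.0, 4.0], "predictor": [20.0]},
]
updated = average_shared(clients, shared_keys=["encoder"], num_samples=[100, 300])
```

After the call, both clients hold the same `"encoder"` values, while each `"predictor"` is unchanged. In a real system the shared entries would be transformer weights and the averaging would be repeated over communication rounds.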

Paper Structure

This paper contains 22 sections, 8 equations, 5 figures, 6 tables.

Figures (5)

  • Figure 1: An overview of the setting with $N$ locations, where the feature spaces and treatment spaces of the locations are non-identical but partially overlap with one another.
  • Figure 2: The FedTransTEE framework, with $N$ site locations shown on the left. The covariate encoder, the treatment encoder, and the cross-attention module are shared across locations, while the predictor is personalised for each location. The right side gives a detailed view of the model architecture with the specifics of each component.
  • Figure 3: Visualization of the activations of the cross-attention head obtained while learning on the ICH dataset.
  • Figure 4: Visualization of the activations of the first self-attention head of the covariate encoder obtained while learning on the ICH dataset.
  • Figure 5: Visualization of the activations of the second self-attention head of the covariate encoder obtained while learning on the ICH dataset.