Causal Graph ODE: Continuous Treatment Effect Modeling in Multi-agent Dynamical Systems

Zijie Huang, Jeehyun Hwang, Junkai Zhang, Jinwoo Baik, Weitong Zhang, Dominik Wodarz, Yizhou Sun, Quanquan Gu, Wei Wang

TL;DR

CAG-ODE tackles the challenge of estimating continuous counterfactual outcomes in evolving multi-agent dynamical systems with multiple time-varying treatments. It extends Graph Ordinary Differential Equations by injecting time-aware treatment representations into two coupled ODEs governing nodes and edges, and it uses a treatment fusing module along with domain adversarial losses to counteract time-varying confounding and interference. The approach achieves state-of-the-art counterfactual prediction performance on both real (COVID-19) and simulated (tumor growth) datasets, with ablations confirming the importance of treatment balancing, interference balancing, and the attention-based treatment fusion. The work has practical implications for policy evaluation and medical decision-making in systems where agents and their interactions evolve continuously over time. Future directions include scalable edge inference for large graphs and modeling more complex, hierarchical, or competing treatment effects.
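The attention-based treatment fusion described above can be illustrated with a minimal sketch. It is not the paper's module. It assumes that each treatment (e.g., "stay-at-home" and "get-vaccine") has a fixed embedding vector, that the embedding is scaled by the treatment's intensity at the current time, and that a node's state serves as the attention query. The names `fuse_treatments`, `treatment_embs`, and `intensities` are hypothetical. For brevity, the scaled embeddings act as both keys and values in a generic scaled dot-product attention.

```python
import math


def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def fuse_treatments(query, treatment_embs, intensities):
    """Fuse K treatment embeddings into one vector via scaled dot-product attention.

    query:          node state used as the attention query (length d)
    treatment_embs: K hypothetical treatment embeddings (each length d)
    intensities:    K treatment values at the current time t (0.0 = not applied)
    """
    d = len(query)
    # Time dependence enters through the intensities: each embedding is
    # scaled by how strongly its treatment is applied at time t.
    values = [[c * e for e in emb] for c, emb in zip(intensities, treatment_embs)]
    # The scaled embeddings double as keys here (a simplifying assumption).
    scores = [sum(q * v for q, v in zip(query, val)) / math.sqrt(d) for val in values]
    weights = softmax(scores)
    # Attention-weighted sum of the scaled treatment vectors.
    fused = [sum(w * val[k] for w, val in zip(weights, values)) for k in range(d)]
    return fused, weights


# Two treatments, e.g. "stay-at-home" and "get-vaccine", both active at time t.
query = [1.0, 0.0]
embs = [[1.0, 0.0], [0.0, 1.0]]
fused, weights = fuse_treatments(query, embs, [1.0, 1.0])
print(fused, weights)
```

When no treatment is active, every value vector is zero, so the fused representation is also zero. The treatment whose embedding aligns with the node state receives the larger attention weight.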

Abstract

Real-world multi-agent systems are often dynamic and continuous: agents co-evolve, and their trajectories and interactions change over time. For example, COVID-19 transmission in the U.S. can be viewed as a multi-agent system in which states act as agents and daily population movements between them act as interactions. Estimating counterfactual outcomes in such systems enables accurate future predictions and effective decision-making, such as formulating COVID-19 policies. However, existing methods fail to model the continuous dynamic effects of treatments on outcomes, especially when multiple treatments (e.g., "stay-at-home" and "get-vaccine" policies) are applied simultaneously. To tackle this challenge, we propose Causal Graph Ordinary Differential Equations (CAG-ODE), a novel model that captures continuous interactions among agents by using a Graph Neural Network (GNN) as the ODE function. The key innovation of our model is to learn time-dependent representations of treatments and incorporate them into the ODE function, enabling precise predictions of potential outcomes. To mitigate confounding bias, we further propose two domain adversarial learning-based objectives, which enable our model to learn balanced continuous representations that are not affected by treatments or interference. Experiments on two datasets (i.e., COVID-19 and tumor growth) demonstrate the superior performance of our proposed model.
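The domain adversarial objectives can be pictured with the standard gradient reversal trick. The sketch below is a toy and not the paper's loss. It assumes a scalar encoder `h = a * x`, a logistic treatment discriminator `p = sigmoid(b * h)`, and hand-derived gradients, all of which are hypothetical. The discriminator descends its binary cross-entropy, while the encoder receives the reversed gradient, which pushes the representation toward carrying less information about treatment assignment.

```python
import math


def sigmoid(u):
    return 1.0 / (1.0 + math.exp(-u))


def disc_loss_and_grads(a, b, xs, ts):
    """Mean binary cross-entropy of a treatment discriminator.

    Hypothetical toy model: representation h = a * x, discriminator
    p = sigmoid(b * h) predicts the binary treatment t from h.
    Returns (loss, dL/da, dL/db).
    """
    loss = ga = gb = 0.0
    for x, t in zip(xs, ts):
        h = a * x
        p = sigmoid(b * h)
        loss -= t * math.log(p) + (1 - t) * math.log(1 - p)
        dlogit = p - t          # d(BCE)/d(logit) for a sigmoid output
        gb += dlogit * h        # logit = b * h
        ga += dlogit * b * x    # logit = b * a * x
    n = len(xs)
    return loss / n, ga / n, gb / n


def grad_reverse(grad, lam):
    # Gradient reversal layer: identity in the forward pass,
    # scales the incoming gradient by -lam in the backward pass.
    return -lam * grad


def adversarial_step(a, b, xs, ts, lr=0.1, lam=1.0):
    loss, ga, gb = disc_loss_and_grads(a, b, xs, ts)
    b_new = b - lr * gb                     # discriminator minimizes its loss
    a_new = a - lr * grad_reverse(ga, lam)  # encoder sees the reversed gradient
    return a_new, b_new, loss


# Confounded toy data: larger x makes treatment (t = 1) more likely.
xs = [1.0, 2.0, -1.0, -2.0]
ts = [1, 1, 0, 0]
a1, b1, loss0 = adversarial_step(1.0, 0.5, xs, ts)
print(a1, b1, loss0)
```

One step moves the encoder weight so that the discriminator's loss rises, while the discriminator's own update lowers it. Alternating such steps is the minimax game behind domain adversarial training. Under the same assumptions, an interference-balancing term could reuse this pattern with a discriminator that predicts neighbors' treatments instead.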

Paper Structure (31 sections, 14 equations, 3 figures, 3 tables, 1 algorithm)

Figures (3)

  • Figure 1: Overall Framework of CAG-ODE. The encoder first computes the latent initial states. Then the treatment-induced coupled ODE functions predict the continuous trajectories over time. Treatment representations learned through the fusing module are incorporated into the ODE functions to enable counterfactual prediction. Finally, the decoder outputs the predicted dynamics. Treatment and interference balancing losses are designed to ensure unbiased counterfactual predictions.
  • Figure 2: Case Study of Changing Different Policies on the COVID-19 Dataset.
  • Figure 3: Treatment Balancing Visualization on the COVID-19 Dataset.
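The pipeline in Figure 1 centers on treatment-induced coupled ODEs for node states and edge weights. The toy sketch below shows how injecting a time-varying treatment into both ODE functions produces a counterfactual trajectory that differs from the untreated one. It rests on several hypothetical choices: scalar node states, scalar edge weights, hand-picked dynamics, and forward Euler integration. The paper instead learns these functions with GNNs and integrates them with an ODE solver. The helper names (`node_derivative`, `edge_derivative`, `simulate`, `treat_node0`) are illustrative only.

```python
import math


def node_derivative(z, w, c, i):
    # One round of message passing: edge-weighted sum of neighbour states.
    message = sum(w[i][j] * z[j] for j in range(len(z)) if j != i)
    # Hypothetical dynamics: self-decay, neighbour influence, and an extra
    # decay proportional to the treatment applied to node i.
    return -0.1 * z[i] + 0.05 * message - c[i] * z[i]


def edge_derivative(z, w, c, i, j):
    # Hypothetical dynamics: the edge weight relaxes toward the similarity
    # of its endpoints and is damped by the treatments on both endpoints.
    target = math.exp(-abs(z[i] - z[j]))
    return 0.5 * (target - w[i][j]) - 0.5 * (c[i] + c[j]) * w[i][j]


def simulate(z0, w0, treatment_fn, t_end=10.0, dt=0.1):
    """Forward-Euler rollout of the coupled node/edge ODEs."""
    n = len(z0)
    z = list(z0)
    w = [row[:] for row in w0]
    for k in range(int(round(t_end / dt))):
        c = treatment_fn(k * dt)  # treatment vector at time t
        # Evaluate both derivatives on the current state before updating.
        dz = [node_derivative(z, w, c, i) for i in range(n)]
        dw = [[edge_derivative(z, w, c, i, j) if i != j else 0.0 for j in range(n)]
              for i in range(n)]
        z = [z[i] + dt * dz[i] for i in range(n)]
        w = [[w[i][j] + dt * dw[i][j] for j in range(n)] for i in range(n)]
    return z, w


z0 = [1.0, 0.5, 0.2]
w0 = [[0.0, 0.8, 0.1], [0.8, 0.0, 0.4], [0.1, 0.4, 0.0]]


def no_treatment(t):
    return [0.0, 0.0, 0.0]


def treat_node0(t):
    # Counterfactual: apply a treatment of strength 0.2 to node 0 from t = 2.
    return [0.2 if t >= 2.0 else 0.0, 0.0, 0.0]


z_untreated, _ = simulate(z0, w0, no_treatment)
z_treated, _ = simulate(z0, w0, treat_node0)
print(z_untreated, z_treated)
```

Rolling out the same initial state under different treatment schedules, then comparing the trajectories, is a toy analogue of the policy comparisons in Figure 2. Because edge weights also respond to treatment, the intervention on node 0 indirectly affects its neighbours, which is the interference effect the balancing losses target.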