Amortized Active Causal Induction with Deep Reinforcement Learning

Yashas Annadani, Panagiotis Tigas, Stefan Bauer, Adam Foster

TL;DR

The paper addresses sample-efficient causal structure learning from interventions when the data likelihood is unavailable. It proposes CAASL, a transformer-based amortized policy trained with reinforcement learning to design interventions in a hidden-parameter MDP (HiP-MDP). An AVICI-based reward guides intervention selection, and the trained policy adapts in real time and generalizes zero-shot to higher dimensions and unseen intervention types. Experiments on synthetic linear SCMs and the SERGIO single-cell gene-regulatory simulator show improved causal graph recovery and robust performance under distribution shift. The approach connects to sequential Bayesian experimental design through information-gain bounds and offers a practical, scalable path toward lab-in-the-loop experimentation in complex biological systems.

Abstract

We present Causal Amortized Active Structure Learning (CAASL), an active intervention design policy that selects interventions adaptively and in real time, without requiring access to the likelihood. This policy, an amortized network based on the transformer, is trained with reinforcement learning on a simulator of the design environment, with a reward function that measures how close the true causal graph is to a causal graph posterior inferred from the gathered data. On synthetic data and a single-cell gene expression simulator, we demonstrate empirically that the data acquired through our policy result in a better estimate of the underlying causal graph than alternative strategies. Our design policy successfully achieves amortized intervention design on the distribution of the training environment while also generalizing well to distribution shifts in test-time design environments. Further, our policy also demonstrates excellent zero-shot generalization to design environments with dimensionality higher than that seen during training, and to intervention types that it has not been trained on.
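
To make the reward idea concrete, here is a minimal sketch of one way to score how close a graph posterior is to the true graph. It is not necessarily the reward used in the paper, which this page does not spell out. It assumes the posterior is summarized by independent per-edge probabilities; the names `expected_correct_edges`, `step_reward`, and `edge_probs` are hypothetical, and `edge_probs` merely stands in for the output of a posterior network such as AVICI (no AVICI API is called).

```python
from typing import List

Matrix = List[List[float]]


def expected_correct_edges(true_adj: List[List[int]], edge_probs: Matrix) -> float:
    """Expected number of correctly classified off-diagonal adjacency entries.

    Assumption: each edge i -> j is an independent Bernoulli variable with
    probability edge_probs[i][j] under the (approximate) posterior.
    """
    d = len(true_adj)
    total = 0.0
    for i in range(d):
        for j in range(d):
            if i == j:
                continue  # self-loops are not part of a DAG
            p = edge_probs[i][j]
            # Probability that the posterior agrees with the true entry.
            total += p if true_adj[i][j] == 1 else 1.0 - p
    return total


def step_reward(true_adj: List[List[int]], probs_before: Matrix, probs_after: Matrix) -> float:
    """Improvement in expected correctness after one new intervention's data."""
    return expected_correct_edges(true_adj, probs_after) - expected_correct_edges(
        true_adj, probs_before
    )


# Toy two-variable example: the true graph has the single edge 0 -> 1.
true_adj = [[0, 1], [0, 0]]
probs_before = [[0.0, 0.5], [0.5, 0.0]]  # uninformed about both directions
probs_after = [[0.0, 0.9], [0.2, 0.0]]   # more confident in 0 -> 1

print(round(step_reward(true_adj, probs_before, probs_after), 3))  # 0.7
```

Because the true graph appears in the score, a reward of this kind can only be computed in a simulator, which is consistent with the abstract's statement that the policy is trained on a simulator of the design environment.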

Paper Structure (53 sections, 12 equations, 18 figures, 2 tables)


Figures (18)

  • Figure 1: Causal Amortized Structure Learning (CAASL) is an active intervention design method that directly proposes the next intervention to perform with just a forward pass of the transformer-based policy.
  • Figure 2: Visualization of a rollout of the trained CAASL policy on a randomly sampled environment with $n_0=50$ initial observational samples. Colored circles indicate nodes with a $\mathrm{do}$ intervention. The policy mostly selects interventions on variables that have a child in the ground truth graph. At $t=2$, the policy selects the only child $y_1$, which breaks all direct causal effects and therefore yields less information about the overall causal model. After this, $y_1$ is never chosen again. Initially, the policy is exploratory with respect to targets and exploitative with respect to values; this trend reverses as the episode progresses. The policy is trained on environments with $d=2$, so it has not seen any graphs with $d=3$ before.
  • Figure 3: Amortization results of various intervention strategies on 100 random test environments. CAASL significantly outperforms other intervention strategies. Shaded area represents 95% CI.
  • Figure 4: Zero-shot OOD returns of CAASL on 100 random environments with distribution shift coming from (a) graphs, (b) graphs and mechanisms, (c) graphs, mechanisms and noise, (d) noise changes from homoskedastic to heteroskedastic, and finally (e) intervention changes from $\mathrm{do}$ to a shift intervention. CAASL outperforms other intervention strategies. Shaded area represents 95% CI.
  • Figure 5: Zero-shot OOD generalization results when the dimensionality $d$ changes in the synthetic environment. For training, $d=10$. Left: zero-shot test returns with $d=20$. Right: mean zero-shot returns of CAASL relative to the random strategy for different $d$. Results on 100 random environments. Shaded area represents 95% CI.
  • ...and 13 more figures
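
The captions above distinguish $\mathrm{do}$ interventions from shift interventions. The following sketch illustrates the difference on a toy linear SCM with additive Gaussian noise. It is an illustration under stated assumptions, not the paper's simulator: the function `sample_linear_scm`, its parameters, and the three-node chain are all hypothetical.

```python
import random
from typing import List, Optional, Tuple

# (kind, target node, value); kind is "do" or "shift". Hypothetical encoding.
Intervention = Tuple[str, int, float]


def sample_linear_scm(
    weights: List[List[float]],
    order: List[int],
    noise_std: float,
    intervention: Optional[Intervention] = None,
    seed: int = 0,
) -> List[float]:
    """Draw one sample by ancestral sampling from a linear SCM.

    weights[i][j] is the linear effect of node i on node j, and `order`
    must be a topological order of the graph.
    """
    rng = random.Random(seed)
    d = len(weights)
    x = [0.0] * d
    for j in order:
        # Structural assignment: weighted parents plus additive noise.
        value = sum(weights[i][j] * x[i] for i in range(d)) + rng.gauss(0.0, noise_std)
        if intervention is not None and intervention[1] == j:
            kind, _, v = intervention
            if kind == "do":
                value = v       # replace the mechanism: parents no longer matter
            elif kind == "shift":
                value += v      # keep the mechanism, add a constant offset
            else:
                raise ValueError(f"unknown intervention kind: {kind}")
        x[j] = value
    return x


# Chain 0 -> 1 -> 2 with edge weights 2 and 3.
W = [[0.0, 2.0, 0.0], [0.0, 0.0, 3.0], [0.0, 0.0, 0.0]]
order = [0, 1, 2]

observational = sample_linear_scm(W, order, noise_std=1.0)
do_sample = sample_linear_scm(W, order, noise_std=1.0, intervention=("do", 1, 1.0))
shift_sample = sample_linear_scm(W, order, noise_std=1.0, intervention=("shift", 1, 1.0))
```

With the same seed, `do_sample[1]` equals the intervention value regardless of node 0, whereas `shift_sample[1]` is the observational value plus the offset, so the dependence on node 0 is preserved. In both cases the change propagates to node 2 through the edge weight.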