A Meta-Learning Approach to Bayesian Causal Discovery

Anish Dhir, Matthew Ashman, James Requeima, Mark van der Wilk

TL;DR

The paper tackles learning causal structure under uncertainty by proposing BCNP, a Bayesian meta-learning model that maps observational data to a posterior over DAGs and directly samples from this posterior. It uses a transformer-based encoder to extract a permutation-aware representation and a decoder that samples DAGs through a learned distribution over permutations and a lower-triangular edge mask, ensuring acyclicity by construction. The training objective minimizes the KL divergence between the true Bayesian causal posterior and the model’s posterior, effectively marginalizing over functional relationships and enabling efficient posterior sampling. Empirical results show BCNP achieves competitive or superior performance to explicit Bayesian methods and existing meta-learning approaches across synthetic, semi-synthetic, and Syntren data, demonstrating the practicality and robustness of Bayesian meta-learning for causal discovery.
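
The "acyclic by construction" claim can be illustrated with a minimal sketch. It is not the paper's implementation: the names `sample_dag` and `is_acyclic` are hypothetical, the permutation and edge mask are drawn uniformly at random rather than from a learned distribution, and the mask is assumed to be strictly lower-triangular (no self-loops). The point is only that relabelling a lower-triangular mask by any permutation, i.e. $G = Q A Q^\top$, always gives a DAG.

```python
import random


def sample_dag(num_nodes, edge_prob=0.5, rng=random):
    """Build a DAG adjacency matrix from a permutation and a lower-triangular mask.

    Convention (an assumption of this sketch): graph[i][j] == 1 means j -> i.
    """
    # Random node ordering, standing in for a sample from the learned
    # distribution over permutations.
    order = list(range(num_nodes))
    rng.shuffle(order)
    # Strictly lower-triangular binary mask: position i can only receive
    # edges from earlier positions j < i, so no cycle can form.
    mask = [[1 if j < i and rng.random() < edge_prob else 0
             for j in range(num_nodes)]
            for i in range(num_nodes)]
    # Relabel positions with node indices; equivalent to G = Q A Q^T.
    graph = [[0] * num_nodes for _ in range(num_nodes)]
    for i in range(num_nodes):
        for j in range(num_nodes):
            graph[order[i]][order[j]] = mask[i][j]
    return graph


def is_acyclic(graph):
    """Kahn's algorithm: a graph is acyclic iff every node can be peeled off."""
    n = len(graph)
    indegree = [sum(row) for row in graph]
    ready = [v for v in range(n) if indegree[v] == 0]
    visited = 0
    while ready:
        parent = ready.pop()
        visited += 1
        for child in range(n):
            if graph[child][parent]:
                indegree[child] -= 1
                if indegree[child] == 0:
                    ready.append(child)
    return visited == n
```

Every graph returned by `sample_dag` passes `is_acyclic`, whatever ordering or mask is drawn, so a model that outputs only these two objects never needs a separate acyclicity penalty.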

Abstract

Discovering a unique causal structure is difficult due to both inherent identifiability issues and the consequences of finite data. As such, uncertainty over causal structures, such as that obtained from a Bayesian posterior, is often necessary for downstream tasks. Finding an accurate approximation to this posterior is challenging, due to the large number of possible causal graphs, as well as the difficulty of the subproblem of finding posteriors over the functional relationships of the causal edges. Recent works have used meta-learning to view the problem of estimating the maximum a posteriori causal graph as supervised learning. Yet, these methods are limited when estimating the full posterior as they fail to encode key properties of the posterior, such as correlation between edges and permutation equivariance with respect to nodes. Further, these methods also cannot reliably sample from the posterior over causal structures. To address these limitations, we propose a Bayesian meta-learning model that allows for sampling causal structures from the posterior and encodes these key properties. We compare our meta-Bayesian causal discovery against existing Bayesian causal discovery methods, demonstrating the advantages of directly learning a posterior over causal structure.
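
The supervised-learning view of posterior estimation rests on a standard amortized-inference identity, sketched here in illustrative notation that may differ from the paper's. Assume graphs $G$ and datasets $\mathcal{D}$ are drawn jointly from a generative model $p(G, \mathcal{D})$, and write $q_\theta(G \mid \mathcal{D})$ for the model's posterior. Then

$$
\mathbb{E}_{p(\mathcal{D})}\Big[\mathrm{KL}\big(p(G \mid \mathcal{D}) \,\|\, q_\theta(G \mid \mathcal{D})\big)\Big]
= -\,\mathbb{E}_{p(G, \mathcal{D})}\big[\log q_\theta(G \mid \mathcal{D})\big] + C,
\qquad
C = \mathbb{E}_{p(G, \mathcal{D})}\big[\log p(G \mid \mathcal{D})\big],
$$

where $C$ does not depend on $\theta$. Minimizing the expected KL divergence is therefore the same as maximizing the log-probability the model assigns to the true graph on simulated (graph, dataset) pairs. If each dataset is generated by first sampling functional relationships and then sampling data, the pairs $(G, \mathcal{D})$ are already marginal over those functions, so no explicit posterior over functional relationships has to be computed.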

Paper Structure

This paper contains 36 sections, 22 equations, 2 figures, 24 tables.

Figures (2)

  • Figure 1: Each dataset contains $D$ nodes and $N$ samples, where each data point is embedded into a vector of size $H$, giving a $D \times N \times H$ tensor. A query vector of zeros is then appended along the sample axis. The data is passed through $L$ transformer layers which alternate between attention over samples and attention over nodes. The summary representation $R^0$ is constructed using an attention layer in which the samples of each node serve as the keys and values and the query vector serves as the query.
  • Figure 2: Computational graph of the decoder (see the paper's decoder section). The decoder takes the summary representation $R^0$ from the encoder as input. $T$ denotes a transformer layer, $\operatorname{MHPA}$ denotes multi-headed parameter attention (defined in the paper's multi-head parameter attention equation), and $\mathcal{GS}$ is the Gumbel-Sinkhorn distribution (Mena et al., 2018). The network outputs samples of permutation matrices $Q_s$ and lower-triangular binary matrices $A_s$ that can be used to construct samples of DAGs $G_s$.
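
For readers unfamiliar with the Gumbel-Sinkhorn distribution named in Figure 2, the following is a minimal sketch of the general technique, not the paper's decoder. The function names, the temperature `tau`, and the iteration count are all assumptions chosen for illustration. Gumbel noise is added to a square matrix of logits, the result is divided by the temperature, and rows and columns are normalised alternately in log space, which yields a relaxed (approximately doubly stochastic) permutation matrix. Rounding this matrix to a hard permutation is a separate step that is not shown.

```python
import math
import random


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    peak = max(values)
    return peak + math.log(sum(math.exp(v - peak) for v in values))


def sinkhorn(log_alpha, n_iters=20):
    """Alternately normalise rows and columns of exp(log_alpha) in log space."""
    n = len(log_alpha)
    log_p = [row[:] for row in log_alpha]
    for _ in range(n_iters):
        # Row step: each row of exp(log_p) now sums to 1.
        for i in range(n):
            z = logsumexp(log_p[i])
            log_p[i] = [v - z for v in log_p[i]]
        # Column step: each column of exp(log_p) now sums to 1.
        for j in range(n):
            z = logsumexp([log_p[i][j] for i in range(n)])
            for i in range(n):
                log_p[i][j] -= z
    return [[math.exp(v) for v in row] for row in log_p]


def gumbel_sinkhorn(logits, tau=1.0, n_iters=20, rng=random):
    """Relaxed permutation sample: Sinkhorn applied to (logits + Gumbel) / tau."""
    def gumbel():
        # Clamp away from 0 so that log(u) is always defined.
        u = max(rng.random(), 1e-12)
        return -math.log(-math.log(u))

    noisy = [[(v + gumbel()) / tau for v in row] for row in logits]
    return sinkhorn(noisy, n_iters)
```

Lower values of `tau` push the output closer to a hard permutation matrix, at the cost of slower Sinkhorn convergence, so the row sums may only be approximately 1 after a fixed number of iterations.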