
Use What You Know: Causal Foundation Models with Partial Graphs

Arik Reuter, Anish Dhir, Cristiana Diaconu, Jake Robertson, Ole Ossen, Frank Hutter, Adrian Weller, Mark van der Wilk, Bernhard Schölkopf

TL;DR

This work tackles the challenge of incorporating domain knowledge into Causal Foundation Models (CFMs) to improve causal effect estimation. It proposes conditioning CFMs on partial ancestral information via partially known ancestor matrices (PAMs) and presents architectures that combine attention biasing with graph convolution to leverage this structure. Empirical results show that partial graph conditioning significantly improves predictions, with soft attention biasing often outperforming other strategies and enabling a single CFM to approach the performance of specialized, graph-specific models. The findings advance the goal of all-in-one CFMs capable of answering causal queries in a data-driven, knowledge-informed manner, with practical benefits for complex and semi-synthetic benchmarks.
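A partially known ancestor matrix can be illustrated with a small sketch. The snippet assumes a NumPy representation and a hypothetical encoding: 1 for a known ancestor relation, 0 for a known non-ancestor relation, and `NaN` for an unknown entry. It derives the ancestor matrix of a DAG as the transitive closure of its adjacency matrix, then hides a random fraction of the off-diagonal entries. This parallels the hidden-entry fractions varied in Figure 4. The function names `ancestor_matrix` and `partial_ancestor_matrix` are illustrative and are not taken from the paper.

```python
import numpy as np

# Hypothetical PAM encoding (an assumption, not the paper's convention):
#   1.0 -> known ancestor, 0.0 -> known non-ancestor, NaN -> unknown


def ancestor_matrix(adj: np.ndarray) -> np.ndarray:
    """Transitive closure of a DAG adjacency matrix (Warshall's algorithm).

    adj[i, j] = 1 means i -> j; the result has anc[i, j] = 1 iff i is a
    proper ancestor of j.
    """
    n = adj.shape[0]
    reach = adj.astype(bool)
    for k in range(n):
        # Allow paths i -> ... -> k -> ... -> j through intermediate node k.
        reach = reach | (reach[:, [k]] & reach[[k], :])
    return reach.astype(float)


def partial_ancestor_matrix(anc: np.ndarray, hidden_frac: float,
                            rng: np.random.Generator) -> np.ndarray:
    """Hide a random fraction of off-diagonal entries by setting them to NaN."""
    pam = anc.astype(float).copy()
    off_diag = ~np.eye(anc.shape[0], dtype=bool)
    idx = np.flatnonzero(off_diag)              # flat indices of off-diagonal cells
    n_hide = int(round(hidden_frac * idx.size))
    hide = rng.choice(idx, size=n_hide, replace=False)
    pam.flat[hide] = np.nan
    return pam


# Usage: the chain 0 -> 1 -> 2, with half of the six off-diagonal entries hidden.
adj = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 0]])
anc = ancestor_matrix(adj)
pam = partial_ancestor_matrix(anc, 0.5, np.random.default_rng(0))
print(anc)
print(pam)
```

Whether the diagonal counts toward the hidden fraction is also an assumption here. The sketch always keeps it known, because a variable is never its own proper ancestor.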

Abstract

Estimating causal quantities traditionally relies on bespoke estimators tailored to specific assumptions. Recently proposed Causal Foundation Models (CFMs) promise a more unified approach by amortising causal discovery and inference in a single step. However, in their current state, they do not allow for the incorporation of any domain knowledge, which can lead to suboptimal predictions. We bridge this gap by introducing methods to condition CFMs on causal information, such as the causal graph or more readily available ancestral information. When access to complete causal graph information is too strict a requirement, our approach also effectively leverages partial causal information. We systematically evaluate conditioning strategies and find that injecting learnable biases into the attention mechanism is the most effective method to utilise full and partial causal information. Our experiments show that this conditioning allows a general-purpose CFM to match the performance of specialised models trained on specific causal structures. Overall, our approach addresses a central hurdle on the path towards all-in-one causal foundation models: the capability to answer causal queries in a data-driven manner while effectively leveraging any amount of domain expertise.
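Injecting biases into attention can be sketched for a single head over variable tokens. The formulation below rests on our own assumptions and is not the paper's exact one:

- Each PAM entry maps to one of three states: unknown, known non-ancestor, or known ancestor.
- Each state carries one scalar bias that is added to the attention logits (the "soft" variant).
- A "hard" variant additionally blocks attention to known non-ancestors.

The names `pam_state` and `biased_attention` and the bias values are hypothetical. In a trained model the per-state biases would be learnable parameters.

```python
import numpy as np


def softmax(z: np.ndarray) -> np.ndarray:
    """Row-wise softmax; tolerates -inf entries as long as each row has a finite one."""
    z = z - np.max(z, axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def pam_state(pam: np.ndarray) -> np.ndarray:
    """Map PAM entries to states: 0 = unknown (NaN), 1 = known non-ancestor, 2 = known ancestor."""
    state = np.zeros(pam.shape, dtype=int)
    state[pam == 0] = 1   # NaN compares unequal to everything, so it stays 0
    state[pam == 1] = 2
    return state


def biased_attention(q, k, v, pam, bias, hard=False):
    """Single-head attention over variable tokens with a PAM-dependent additive bias.

    pam[a, b] = 1 means variable a is an ancestor of variable b, so the relation
    relevant to query i attending to key j is pam[j, i]; hence the transpose.
    bias holds one scalar per state (learnable in a real model, fixed here).
    """
    n, d = q.shape
    state = pam_state(pam.T)
    logits = q @ k.T / np.sqrt(d) + bias[state]
    if hard:
        # Hard variant: block known non-ancestors but always keep self-attention.
        blocked = (state == 1) & ~np.eye(n, dtype=bool)
        logits = np.where(blocked, -np.inf, logits)
    weights = softmax(logits)
    return weights @ v, weights


# Usage: three variable tokens; 0 is a known ancestor of 1 and 1 of 2,
# while the relation between 0 and 2 is unknown.
rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(3, 4)) for _ in range(3))
pam = np.array([[0.0, 1.0, np.nan], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]])
bias = np.array([0.0, -2.0, 2.0])   # unknown, non-ancestor, ancestor
out, w_hard = biased_attention(q, k, v, pam, bias, hard=True)
_, w_soft = biased_attention(q, k, v, pam, bias, hard=False)
```

Keeping the diagonal unmasked guarantees that every row has at least one finite logit, so the hard variant never yields an all-masked softmax row. The soft variant only reweights attention, so unknown entries never cut off information flow.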
Paper Structure (63 sections, 1 theorem, 18 equations, 13 figures, 2 tables)

Key Result

Proposition A.1

Assume a causally sufficient SCM on the observed variables $\mathcal{N}=\{t,y,x_1,\dots,x_K\}$ with potential outcomes $y_0,y_1$ and a binary treatment $t \in \{0,1\}$, and let $\tilde{\mathbf{T}}$ be a PAM such that: Then the unconfoundedness condition holds:
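Unconfoundedness matters because of a standard identification result. Suppose the potential outcomes are independent of $t$ given an adjustment set $\bm{x}$, and both treatment arms have positive probability at every $\bm{x}$. Then the average treatment effect is identified as $\mathbb{E}_{\bm{x}}\left[\mathbb{E}[y \mid t=1,\bm{x}] - \mathbb{E}[y \mid t=0,\bm{x}]\right]$. The simulation below is a sketch under our own assumptions: a single observed confounder, a logistic treatment model, a linear outcome model, and hypothetical helper names. It does not check the PAM conditions of Proposition A.1. It only shows that a naive difference in means is biased, while regression adjustment on $x$ recovers the true effect `tau`.

```python
import numpy as np


def simulate(n: int, rng: np.random.Generator, tau: float = 2.0):
    """Hypothetical data in which x confounds both treatment assignment and outcome."""
    x = rng.normal(size=n)
    p_treat = 1.0 / (1.0 + np.exp(-1.5 * x))        # treatment more likely for large x
    t = (rng.uniform(size=n) < p_treat).astype(int)
    y = tau * t + 3.0 * x + rng.normal(scale=0.5, size=n)
    return x, t, y


def naive_difference(t, y):
    """Difference in means that ignores x (biased under confounding)."""
    return y[t == 1].mean() - y[t == 0].mean()


def regression_adjustment(x, t, y):
    """Fit E[y | t, x] per arm by least squares, then average the contrast over all x."""
    X = np.column_stack([np.ones_like(x), x])
    beta1, *_ = np.linalg.lstsq(X[t == 1], y[t == 1], rcond=None)
    beta0, *_ = np.linalg.lstsq(X[t == 0], y[t == 0], rcond=None)
    return float(np.mean(X @ beta1 - X @ beta0))


rng = np.random.default_rng(0)
x, t, y = simulate(20_000, rng)
print("naive:", naive_difference(t, y))
print("adjusted:", regression_adjustment(x, t, y))
```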

Figures (13)

  • Figure 1: Causal Foundation Models based on PFNs for predicting the effect of causal interventions: A model takes as input observational data $\mathcal{D}$ and a query point comprising an intervention $do(t)$ together with a feature vector $\bm{x}$. The model, implemented as a transformer, outputs the posterior of the causal effect $p(y|\text{do}(t),\bm{x}, \mathcal{D})$. This can be seen as the model marginalising over causal structures (SCMs) $\psi$, weighted by the posterior probability that $\psi$ generated the observational data $\mathcal{D}$. (a) The model can only leverage observational data to compute its posterior belief $p(\psi|\mathcal{D})$ over possible causal structures (visualized as the region shaded in blue). Using this posterior to compute the posterior predictive ${p(y|\text{do}(t),\bm{x}, \mathcal{D}) = \int p(y|t, \bm{x}, \psi_{do(t)})p(\psi|\mathcal{D})d\psi}$ leads to high uncertainty and possibly imprecise predictions. (b) In contrast, providing the model with partial graph information $\widetilde{\mathcal{G}}$ substantially narrows down the set of causal structures that have high posterior mass $p(\psi|\mathcal{D}, \widetilde{\mathcal{G}})$. This yields a more concentrated posterior predictive $p(y|\text{do}(t),\bm{x}, \mathcal{D}, \widetilde{\mathcal{G}})$, and thus more accurate predictions of the outcome.
  • Figure 2: Posterior predictive samples in the two-node case where $T$ causes $Y$ ($T \rightarrow Y$); specifically, treatment and outcome are related via $Y = 0.25\cdot T + \epsilon$ with $\epsilon \sim \mathcal{N}(0, 0.1)$. Here, the causal direction is not identifiable from observational data. Left: If the model $Q_\theta$, trained on arbitrary causal structures (as, e.g., in robertson2025pfn or dhirestimating), is not told the correct causal direction, it outputs a mixture of the posteriors for the two causal directions: $Q_\theta(Y|\text{do}(T)) = 0.5 \cdot P(Y|\text{do}(T), Y \rightarrow T) + 0.5 \cdot P(Y|\text{do}(T), T \rightarrow Y)$. Right: When conditioned on the correct causal direction, the model's output aligns with the true causal effect, $Q_\theta(Y|\text{do}(T), T \rightarrow Y) = P(Y|\text{do}(T), T \rightarrow Y)$, leading to more precise predictions.
  • Figure 3: Performance comparison of different graph-conditioning methods on linear-Gaussian data. Results are shown as the improvement in negative log-likelihood (NLL), mean squared error (MSE) and coefficient of determination ($R^2$) on held-out test data relative to a baseline without graph conditioning. The error bars represent $95$-percent bootstrap confidence intervals for the median. We compare soft and hard attention biases as well as a GCN-based conditioning approach, using either ground-truth adjacency matrices (Adj.) or ancestor matrices (Anc.) as graph input. All methods yield significant benefits, with the Soft Attention-based methods performing best.
  • Figure 4: Performance on complex synthetic data for different fractions of hidden entries in the ancestor matrix, for the Soft Attention (Ancestor) and GCN + Soft Attention (Ancestor) graph-conditioning approaches. Results are measured as the absolute difference to a baseline that does not support graph conditioning but is otherwise trained identically. The error bars represent 95-percent bootstrap confidence intervals for the mean. Providing full graph information significantly improves performance, and GCN + Soft Attention performs better than Soft Attention alone.
  • Figure 5: Predictive performance on small-scale datasets of a PFN trained on our causal prior compared to standard baselines. We subsample datasets with up to 1000 training points from the TabArena regression benchmark and consider out-of-the-box performance, i.e., with default hyperparameters. A PFN trained on our prior performs strongly relative to the baselines but is outperformed by TabPFN v2.5.
  • ...and 8 more figures
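The mixture described in Figure 2 can be reproduced analytically. This sketch makes two assumptions: $T \sim \mathcal{N}(0, 1)$, since the caption does not state the distribution of $T$, and $\mathcal{N}(0, 0.1)$ read as having standard deviation $0.1$.

- Under $T \rightarrow Y$, $Y \mid \text{do}(T=t) \sim \mathcal{N}(0.25t, 0.1^2)$.
- Under $Y \rightarrow T$, intervening on $T$ leaves $Y$ at its observational marginal $\mathcal{N}(0, 0.25^2 + 0.1^2)$.

This is also a discrete instance of the integral over $\psi$ in Figure 1, with two candidate structures of equal posterior mass. The function names are hypothetical.

```python
import numpy as np

SLOPE = 0.25     # from the caption: Y = 0.25 * T + eps
NOISE_SD = 0.1   # assumption: N(0, 0.1) read as a standard deviation
T_SD = 1.0       # assumption: T ~ N(0, 1); not stated in the caption
DIRECTIONS = ("Y->T", "T->Y")


def interventional(t: float, direction: str):
    """Mean and variance of Y under do(T = t) for one causal direction."""
    if direction == "T->Y":
        return SLOPE * t, NOISE_SD ** 2
    # Y -> T: intervening on T leaves Y at its observational marginal.
    return 0.0, SLOPE ** 2 * T_SD ** 2 + NOISE_SD ** 2


def mixture_moments(t: float, weights=(0.5, 0.5)):
    """Mean and variance of the equal-weight mixture over both directions."""
    comps = [interventional(t, d) for d in DIRECTIONS]
    mean = sum(w * m for w, (m, _) in zip(weights, comps))
    second = sum(w * (var + m ** 2) for w, (m, var) in zip(weights, comps))
    return mean, second - mean ** 2


def sample_mixture(t: float, n: int, rng: np.random.Generator, weights=(0.5, 0.5)):
    """Draw n samples: first pick a causal direction, then draw Y from it."""
    comps = [interventional(t, d) for d in DIRECTIONS]
    pick = rng.choice(len(comps), size=n, p=weights)
    means = np.array([m for m, _ in comps])[pick]
    sds = np.sqrt(np.array([var for _, var in comps]))[pick]
    return rng.normal(means, sds)


print("unconditioned:", mixture_moments(1.0))
print("conditioned on T->Y:", interventional(1.0, "T->Y"))
```

At $t = 1$ the unconditioned mixture has mean $0.125$ and variance $0.0725 - 0.125^2 = 0.056875$. Conditioning on $T \rightarrow Y$ gives mean $0.25$ and variance $0.01$, matching the more concentrated prediction described for the right panel.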

Theorems & Definitions (2)

  • Proposition A.1: Unconfoundedness from partial graph knowledge
  • Proof (sketch)