Demystifying amortized causal discovery with transformers

Francesco Montagna, Max Cairney-Leeming, Dhanya Sridhar, Francesco Locatello

TL;DR

The paper ties identifiability theory to amortized causal discovery by analyzing CSIvA, a transformer-based model trained on synthetic data to infer causal graphs from observational data. It shows that the training-data distribution implicitly imposes a prior on test graphs, and that identifiability governs when the model can reliably recover the true causal structure. Empirically, CSIvA demonstrates strong in-distribution generalization but limited out-of-distribution generalization across unseen mechanism types and noise distributions; training on mixtures of identifiable SCMs markedly improves generalization across a broader class of models. The findings advocate for explicit incorporation of identifiability-inspired priors in learning-based causal discovery and propose mixture training as a practical route to broaden recoverable causal structures, while highlighting the remaining gap between theory and robust OOD performance.

Abstract

Supervised learning approaches for causal discovery from observational data often achieve competitive performance despite seemingly avoiding explicit assumptions that traditional methods make for identifiability. In this work, we investigate CSIvA (Ke et al., 2023), a transformer-based model promising to train on synthetic data and transfer to real data. First, we bridge the gap with existing identifiability theory and show that constraints on the training data distribution implicitly define a prior on the test observations. Consistent with classical approaches, good performance is achieved when we have a good prior on the test data, and the underlying model is identifiable. At the same time, we find new trade-offs. Training on datasets generated from different classes of causal models, unambiguously identifiable in isolation, improves the test generalization. Performance is still guaranteed, as the ambiguous cases resulting from the mixture of identifiable causal models are unlikely to occur (which we formally prove). Overall, our study finds that amortized causal discovery still needs to obey identifiability theory, but it also differs from classical methods in how the assumptions are formulated, trading more reliance on assumptions on the noise type for fewer hypotheses on the mechanisms.
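
To make the idea of training on a mixture of identifiable causal models concrete, here is a minimal sketch of a synthetic data generator for the two-variable case. It is not the CSIvA pipeline: the function names, coefficients, the uniform noise, and the specific nonlinearities are all assumptions chosen for illustration. Each dataset is drawn from one mechanism class (linear, nonlinear additive, or post-nonlinear), and the label records the causal direction.

```python
import numpy as np

MECHANISMS = ("linear", "nonlinear", "pnl")  # hypothetical class names


def sample_scm(mechanism, n_samples, rng):
    """Sample a two-variable SCM X -> Y; returns an array of shape (n_samples, 2)."""
    # Uniform (non-Gaussian) cause and noise: an illustrative choice that
    # avoids the non-identifiable linear-Gaussian case.
    x = rng.uniform(-1.0, 1.0, n_samples)
    noise = rng.uniform(-0.5, 0.5, n_samples)
    if mechanism == "linear":
        y = 1.5 * x + noise
    elif mechanism == "nonlinear":
        y = np.tanh(2.0 * x) + noise  # additive noise model
    elif mechanism == "pnl":
        y = (np.tanh(2.0 * x) + noise) ** 3  # invertible post-nonlinearity
    else:
        raise ValueError(f"unknown mechanism: {mechanism}")
    return np.stack([x, y], axis=1)


def sample_mixture_dataset(n_datasets, n_samples, mechanisms=MECHANISMS, seed=0):
    """Draw datasets from a uniform mixture over mechanism classes.

    Returns (data, labels, kinds), where labels[i] == 1 means that column 0
    causes column 1 in data[i], and 0 means the reverse direction.
    """
    rng = np.random.default_rng(seed)
    data, labels, kinds = [], [], []
    for _ in range(n_datasets):
        kind = mechanisms[rng.integers(len(mechanisms))]
        sample = sample_scm(kind, n_samples, rng)
        forward = rng.random() < 0.5
        if not forward:
            sample = sample[:, ::-1]  # swap columns so the effect comes first
        data.append(sample)
        labels.append(int(forward))
        kinds.append(kind)
    return np.stack(data), np.array(labels), kinds
```

Calling `sample_mixture_dataset(8, 100)` yields an array of shape `(8, 100, 2)` together with direction labels that a supervised model could be trained to predict. Restricting `mechanisms` to a single class gives the single-SCM training setting for comparison.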

Paper Structure

This paper contains 59 sections, 3 theorems, 35 equations, 16 figures, and 1 table.

Key Result

Theorem 1

Let $X$ be generated by a restricted additive noise model with graph $\mathcal{G}$, and assume that the causal mechanisms $f_j$ are not constant in any of the input arguments, i.e. for $X_i \in X_{\operatorname{PA}^\mathcal{G}_j}$, there exist $x_i \neq x'_i$ such that $f_j(x_{\operatorname{PA}^\mat

Figures (16)

  • Figure 1: In-distribution generalization of CSIvA trained and tested on data generated according to the same structural causal models, with mechanisms and noise distributions fixed between training and testing. As baselines for comparison, we use DirectLiNGAM on linear SCMs and NoGAM on nonlinear ANMs, with the implementations from https://causal-learn.readthedocs.io/en/latest/index.html and https://www.pywhy.org/dodiscover/dev/generated/dodiscover.toporder.NoGAM.html respectively. CSIvA's performance is clearly non-trivial, and the model generalizes well in distribution.
  • Figure 2: Out-of-distribution generalization. We train three CSIvA models on data sampled from SCMs with linear, nonlinear additive, and post-nonlinear mechanisms, each with a fixed mlp noise distribution. In Figure \ref{fig:ood-mechanisms} we test across different mechanism types, with mlp-distributed noise terms in both training and testing. In Figure \ref{fig:ood-noise} we test across different noise distributions, with the test mechanism type matching the one used in training. CSIvA struggles to generalize to unseen causal mechanisms and often shows degraded performance on new noise distributions.
  • Figure 3: Experiments on identifiability theory. Figure \ref{fig:invertible} shows the SHD of models trained on different ratios of linear and nonlinear invertible data from Example \ref{example:invertible_gumbel}. In Figure \ref{fig:linear_gauss} we test performance on linear-Gaussian data; the models are trained with different ratios of samples from linear and nonlinear SCMs with Gaussian noise terms. The validation results show that the networks were trained successfully. In both cases CSIvA behaves according to identifiability theory, failing to predict on invertible data (50:50 ratio) and on linear-Gaussian models.
  • Figure 4: Mixture of causal mechanisms. We train four models on samples from structural causal models with different mechanism types. We compare their test SHD (the lower, the better) against networks trained on datasets generated according to a single type of mechanism. The dashed line indicates the test SHD of a model trained on samples with the same mechanisms as the test SCM. Training on multiple causal models with different mechanisms (mixed bars) always improves performance compared to training on single SCMs.
  • Figure 5: Mixture of noise distributions. We train three networks on samples from SCMs with different noise-term distributions and fixed mechanism types: linear, nonlinear, and post-nonlinear. We present their test SHD (the lower, the better) on data from SCMs whose mechanisms match those used in training, with the noise terms changing between datasets. Training on multiple causal models with different noise distributions (all distributions bars) always improves performance compared to training on single SCMs with fixed mlp noise (only mlp bars).
  • ...and 11 more figures
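
The captions above report results as SHD (structural Hamming distance) between the predicted and the true graph. As a rough illustration, the sketch below computes one common variant on binary adjacency matrices. It assumes that entry `[i, j] = 1` encodes an edge `i -> j`, that both graphs are DAGs, and that a reversed edge counts as a single error; conventions differ between libraries, and this summary does not state which one the paper uses.

```python
import numpy as np


def structural_hamming_distance(a_true, a_pred):
    """Count node pairs whose edge status (absent, i -> j, j -> i) differs.

    Assumed convention: a[i, j] == 1 means an edge i -> j, and a reversed
    edge counts as one error rather than two.
    """
    a_true = np.asarray(a_true, dtype=bool)
    a_pred = np.asarray(a_pred, dtype=bool)
    if a_true.ndim != 2 or a_true.shape[0] != a_true.shape[1]:
        raise ValueError("adjacency matrices must be square")
    if a_true.shape != a_pred.shape:
        raise ValueError("adjacency matrices must have the same shape")
    diff = a_true != a_pred
    # Symmetrize so that each unordered pair {i, j} is counted at most once.
    pair_diff = diff | diff.T
    return int(np.triu(pair_diff, k=1).sum())


# True graph: 0 -> 1 -> 2. Prediction reverses the first edge.
true_graph = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
pred_graph = [[0, 0, 0], [1, 0, 1], [0, 0, 0]]
print(structural_hamming_distance(true_graph, pred_graph))  # 1
```

Under this convention, a prediction with no edges scores 2 against the same true graph, one for each missing edge.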

Theorems & Definitions (9)

  • Definition 1: Identifiable causal model
  • Example 1
  • Example 2
  • Definition 2: Definition 27 of peters_2014_identifiability
  • Theorem 1: Theorem 28 of peters_2014_identifiability
  • Proposition 1: Corollary of Theorem 1 of hoyer08_anm
  • Theorem 2: Theorem 1 of zhang2009pnl
  • Proof: Proof of Theorem \ref{thm:zhang09}
  • proof