Sample, estimate, aggregate: A recipe for causal discovery foundation models

Menghua Wu, Yujia Bao, Regina Barzilay, Tommi Jaakkola

TL;DR

This work introduces Sea, a foundation-model-inspired framework for causal discovery that predicts a global graph $G=(V,E)$ from marginal estimates computed over subsets of variables together with global statistics. The model is trained on large-scale synthetic data so that it generalizes across misspecification and data regimes. It uses axial-attention networks to reconcile partial, local causal information into a coherent global structure, and theoretical results show that the architecture is well-specified and can recover the skeleton and v-structures. Empirically, Sea achieves state-of-the-art performance on synthetic and real datasets, scales to graphs with hundreds of nodes, and supports fast inference with zero or minimal finetuning when adapting to new data assumptions. The paper also outlines practical considerations and avenues for extending causal-discovery foundation models to more diverse assumptions and domains.

Abstract

Causal discovery, the task of inferring causal structure from data, has the potential to uncover mechanistic insights from biological experiments, especially those involving perturbations. However, causal discovery algorithms over larger sets of variables tend to be brittle against misspecification or when data are limited. For example, single-cell transcriptomics measures thousands of genes, but the nature of their relationships is not known, and there may be as few as tens of cells per intervention setting. To mitigate these challenges, we propose a foundation model-inspired approach: a supervised model trained on large-scale, synthetic data to predict causal graphs from summary statistics -- like the outputs of classical causal discovery algorithms run over subsets of variables and other statistical hints like inverse covariance. Our approach is enabled by the observation that typical errors in the outputs of a discovery algorithm remain comparable across datasets. Theoretically, we show that the model architecture is well-specified, in the sense that it can recover a causal graph consistent with graphs over subsets. Empirically, we train the model to be robust to misspecification and distribution shift using diverse datasets. Experiments on biological and synthetic data confirm that this model generalizes well beyond its training set, runs on graphs with hundreds of variables in seconds, and can be easily adapted to different underlying data assumptions.
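
The "sample, estimate, aggregate" recipe named in the title can be illustrated with a toy sketch. Everything here is an assumption made for illustration: `sample_subsets` draws node subsets uniformly, `toy_marginal_estimate` is a hypothetical oracle standing in for a classical discovery algorithm run on a subset, and `aggregate_by_vote` is a simple majority vote. The paper's aggregator is a learned model, so the vote only shows how local estimates can be pooled into a single global edge set.

```python
import random
from collections import defaultdict
from itertools import combinations


def sample_subsets(num_nodes, subset_size, num_batches, seed=0):
    """Sample: draw random node subsets (uniform sampling is an assumption)."""
    rng = random.Random(seed)
    return [
        tuple(sorted(rng.sample(range(num_nodes), subset_size)))
        for _ in range(num_batches)
    ]


def toy_marginal_estimate(subset, true_edges):
    """Estimate: hypothetical stand-in for a discovery algorithm run on `subset`.

    It returns the true edges whose endpoints both lie in the subset.
    A real estimator works from data and makes errors.
    """
    nodes = set(subset)
    return {(u, v) for (u, v) in true_edges if u in nodes and v in nodes}


def aggregate_by_vote(subsets, estimates, threshold=0.5):
    """Aggregate: keep a directed edge if it is predicted in more than
    `threshold` of the subsets that contain both of its endpoints.

    This vote is a placeholder, not the learned aggregator from the paper.
    """
    seen = defaultdict(int)   # how often both endpoints were sampled together
    votes = defaultdict(int)  # how often the directed edge was predicted
    for subset, edges in zip(subsets, estimates):
        for u, v in combinations(subset, 2):
            seen[(u, v)] += 1
            seen[(v, u)] += 1
        for edge in edges:
            votes[edge] += 1
    return {edge for edge, n in votes.items() if n / seen[edge] > threshold}


# Usage on a hypothetical 4-node chain 0 -> 1 -> 2 -> 3.
true_edges = {(0, 1), (1, 2), (2, 3)}
subsets = sample_subsets(num_nodes=4, subset_size=3, num_batches=20)
estimates = [toy_marginal_estimate(s, true_edges) for s in subsets]
print(sorted(aggregate_by_vote(subsets, estimates)))
```

An edge can only be recovered if some sampled subset contains both of its endpoints, which is one reason the number of sampled batches matters.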

Paper Structure (55 sections, 8 theorems, 40 equations, 12 figures, 22 tables, 1 algorithm)

Key Result

Theorem 3.1

Let $G=(V,E)$ be a directed acyclic graph with maximum degree $d$. For $S\subseteq V$, let $E'_S$ denote the marginal estimate over $S$. Let $\mathcal{S}_d$ denote the superset that contains all subsets $S\subseteq V$ of size at most $d$. Given $\{ E'_S \}_{S \in \mathcal{S}_{d+2}}$, a stack of $L$
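
To make the notation concrete, the collection $\mathcal{S}_k$ of all subsets of size at most $k$ can be enumerated and counted as in the sketch below. The function names are hypothetical, and the example sizes ($N=100$, $d=2$) are chosen only for illustration. The count grows quickly with $N$, which suggests why exhaustive enumeration becomes costly on larger graphs.

```python
from itertools import combinations
from math import comb


def subsets_up_to(nodes, max_size):
    """Yield every subset of `nodes` with between 0 and `max_size` elements."""
    nodes = list(nodes)
    for k in range(max_size + 1):
        yield from combinations(nodes, k)


def num_subsets_up_to(n, max_size):
    """Count the subsets of an n-element set with at most `max_size` elements."""
    return sum(comb(n, k) for k in range(max_size + 1))


# Hypothetical example: N = 100 nodes, maximum degree d = 2, so k = d + 2 = 4.
print(num_subsets_up_to(100, 4))
```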

Figures (12)

  • Figure 1: Overview of our goals and approach. (A) Criteria we aim to fulfill. (B-C) Inference and training procedure. Green: raw data. Blue: graph / features. Yellow: Learned. Gray: Stochastic, but not learned.
  • Figure 2: Aggregator architecture. Marginal graph estimates and global statistics are embedded into the model dimension. 1D positional embeddings are added along both rows and columns. Embedded features pass through a series of axial attention blocks, which attend to the marginal and global features. Final layer global features pass through a feedforward network to predict the causal graph.
  • Figure 3: Few-shot learning behavior emerges as training set increases. "Tiny" Sea trained on 1/4 of the data is comparable to the full model on $N=10$ datasets when given $T=50$ batches, but is less robust with only $T=10$.
  • Figure 4: Ablations with Sea (Gies) for estimation parameters on $N=100, E=100$. Error bars indicate 95% confidence interval across the 5 i.i.d. datasets of each setting. All parameters are set to the defaults (Section \ref{subsec:model}) unless otherwise noted. (A) Dashed: inverse covariance at $M=500$. (C) Variance is unusually high for Sigmoid $b=300$ until $T=100$, indicating that larger batches result in more stable results.
  • Figure 5: Resolving marginal graphs. Subsets of nodes revealed to the PC algorithm (circled in row 1) and its outputs (row 2).
  • ...and 7 more figures

Theorems & Definitions (21)

  • Theorem 3.1
  • Definition A.1
  • Definition A.2: Theorem 3.4 from causation-prediction-search
  • Theorem A.3: Theorem 5.1 from causation-prediction-search
  • Definition A.4: Marginal estimate
  • Proposition A.5
  • Remark A.6
  • Lemma A.7
  • proof
  • Lemma A.8
  • ...and 11 more