CauScale: Neural Causal Discovery at Scale
Bo Peng, Sirui Chen, Jiaguo Tian, Yu Qiao, Chaochao Lu
TL;DR
CauScale tackles the scalability bottlenecks of causal discovery with a neural, amortized approach whose two-stream architecture couples data-driven relational evidence with graph priors. A reduction unit compresses data embeddings to cut compute, tied attention weights cut memory, and a DataGraph block preserves essential signals despite the compression. Across synthetic and semi-synthetic gene networks, CauScale achieves near-perfect in-distribution accuracy (mAP up to 99.6%), generalizes well to larger graphs and out-of-distribution (OOD) mechanisms, and delivers inference speedups of up to 13,000× over prior methods. The work demonstrates practical scalability to graphs with up to 1000 nodes and points to pre-training as a promising direction for efficient causal-discovery models at scale.
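The mAP figures above are mean average-precision scores over predicted edges. As a hedged illustration, AP for one graph can be computed by ranking all candidate edges by predicted probability and averaging the precision at the rank of each true edge. mAP then averages AP over test graphs. The paper's exact protocol is an assumption here, including whether self-loops are excluded and how scores are averaged. The function names `average_precision`, `graph_ap`, and `mean_ap` are hypothetical.

```python
from typing import List, Sequence

Matrix = List[List[float]]


def average_precision(scores: Sequence[float], labels: Sequence[int]) -> float:
    """AP: mean precision at the rank of each true edge, ranking candidates by
    descending score (ties keep input order because sorted() is stable)."""
    order = sorted(range(len(scores)), key=lambda i: -scores[i])
    hits, precisions = 0, []
    for rank, i in enumerate(order, start=1):
        if labels[i]:
            hits += 1
            precisions.append(hits / rank)
    return sum(precisions) / len(precisions) if precisions else 0.0


def graph_ap(prob: Matrix, adj: List[List[int]]) -> float:
    """AP over the d*(d-1) off-diagonal candidate edges of one graph
    (excluding self-loops is an assumption)."""
    d = len(adj)
    pairs = [(a, b) for a in range(d) for b in range(d) if a != b]
    return average_precision([prob[a][b] for a, b in pairs],
                             [adj[a][b] for a, b in pairs])


def mean_ap(probs: List[Matrix], adjs: List[List[List[int]]]) -> float:
    """Assumed mAP: the per-graph AP averaged over a set of test graphs."""
    aps = [graph_ap(p, a) for p, a in zip(probs, adjs)]
    return sum(aps) / len(aps)


# Toy chain 0 -> 1 -> 2 with hypothetical predicted edge probabilities.
adj = [[0, 1, 0], [0, 0, 1], [0, 0, 0]]
prob = [[0.0, 0.9, 0.2], [0.1, 0.0, 0.4], [0.3, 0.5, 0.0]]
print(round(graph_ap(prob, adj), 4))  # 0.8333
```

In the toy example, the two true edges land at ranks 1 and 3. That gives AP = (1/1 + 2/3) / 2 ≈ 0.833, while a perfect ranking gives 1.0.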
Abstract
Causal discovery is essential for advancing data-driven fields such as scientific AI and data analysis, yet existing approaches face significant time and space bottlenecks when scaling to large graphs. To address this challenge, we present CauScale, a neural architecture for efficient causal discovery that scales inference to graphs with up to 1000 nodes. CauScale improves time efficiency with a reduction unit that compresses data embeddings, and improves space efficiency by tying attention weights so that it need not maintain axis-specific attention maps. To maintain high accuracy, CauScale adopts a two-stream design: a data stream extracts relational evidence from high-dimensional observations, while a graph stream integrates statistical graph priors and preserves key structural signals. CauScale scales to 500-node graphs during training, where prior work fails due to memory limitations. Across test data with varying graph scales and causal mechanisms, CauScale achieves 99.6% mAP on in-distribution data and 84.4% mAP on out-of-distribution data, while delivering inference speedups of 4 to 13,000 times over prior methods. Our project page is at https://github.com/OpenCausaLab/CauScale.
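The abstract attributes CauScale's efficiency to two components. The pure-Python sketch below illustrates one plausible reading of each, under stated assumptions. `reduce_samples` is a hypothetical reduction unit that mean-pools groups of samples to shrink the sample axis. `tied_variable_attention` shares a single variable-by-variable attention map across all samples instead of keeping one map per sample, in the spirit of the tied row attention used in MSA Transformer. All function names, the mean-pooling rule, the identity query/key/value projections, and the 1/sqrt(n·h) scaling are assumptions for illustration, not CauScale's actual layers.

```python
import math
from typing import List

Matrix = List[List[float]]  # [d][h]: per-variable embeddings of one sample
Tensor = List[Matrix]       # [n][d][h]: embeddings of n samples


def reduce_samples(x: Tensor, r: int) -> Tensor:
    """Hypothetical reduction unit: mean-pool consecutive groups of r samples,
    shrinking the sample axis from n to ceil(n / r)."""
    d, h = len(x[0]), len(x[0][0])
    out: Tensor = []
    for start in range(0, len(x), r):
        group = x[start:start + r]  # the last group may hold fewer than r samples
        out.append([[sum(s[j][k] for s in group) / len(group) for k in range(h)]
                    for j in range(d)])
    return out


def softmax_rows(m: Matrix) -> Matrix:
    """Numerically stable row-wise softmax."""
    result = []
    for row in m:
        top = max(row)
        e = [math.exp(v - top) for v in row]
        z = sum(e)
        result.append([v / z for v in e])
    return result


def tied_variable_attention(x: Tensor) -> Tensor:
    """Self-attention over the variable axis using ONE d x d map for all samples.

    Logits Q_i K_i^T are summed over samples i before the softmax, so only a
    single d x d map is stored instead of one per sample. Queries, keys and
    values are the raw embeddings (identity projections, a simplification).
    """
    n, d, h = len(x), len(x[0]), len(x[0][0])
    scale = 1.0 / math.sqrt(n * h)  # assumed scaling, analogous to 1/sqrt(h)
    logits = [[scale * sum(x[i][a][k] * x[i][b][k]
                           for i in range(n) for k in range(h))
               for b in range(d)]
              for a in range(d)]
    attn = softmax_rows(logits)  # the shared [d][d] attention map
    # Every sample mixes its own value vectors with the same weights.
    return [[[sum(attn[a][b] * x[i][b][k] for b in range(d)) for k in range(h)]
             for a in range(d)]
            for i in range(n)]


def attention_map_entries(n: int, d: int, tied: bool) -> int:
    """Size of the variable-axis attention maps: n*d*d untied vs d*d tied."""
    return d * d if tied else n * d * d


if __name__ == "__main__":
    # 8 samples, 5 variables, 4 hidden dims of toy embeddings.
    x = [[[float(i + j + k) for k in range(4)] for j in range(5)] for i in range(8)]
    x_red = reduce_samples(x, r=4)  # 8 samples -> 2
    y = tied_variable_attention(x_red)
    print(len(x_red), len(y), len(y[0]), len(y[0][0]))   # 2 2 5 4
    print(attention_map_entries(1000, 1000, tied=False))  # 1000000000
    print(attention_map_entries(1000, 1000, tied=True))   # 1000000
```

With 1000 samples and 1000 variables, untied per-sample maps would hold 10^9 entries versus 10^6 for a single tied map. Under this reading, that gap is the kind of space saving the abstract attributes to tied attention. Pooling the sample axis before attention likewise shrinks the work done by every subsequent layer.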
