CaTs and DAGs: Integrating Directed Acyclic Graphs with Transformers for Causally Constrained Predictions
Matthew J. Vowels, Mathieu Rochat, Sina Akbari
TL;DR
The paper addresses the brittleness and opacity of standard neural networks in causal settings by introducing Causal Transformers (CaTs) and Causal Fully Connected Networks (CFCNs) that enforce DAG-based causal constraints via a novel masking mechanism. CaT employs causally masked cross-attention, while CFCN uses MADE-like masks, both ensuring that each node’s prediction depends only on its causal parents, enabling honest interventions and aligning with the g-formula. Theoretical results establish that these architectures realize a Causal Bayesian Network with a DAG-consistent factorization and invariance properties under covariate shift, while experiments across simulated, benchmark, and real-world data show competitive causal-effect estimation and robustness, even without task-specific hyperparameter tuning. The work offers a general, DAG-guided modeling framework that enhances robustness, interpretability, and transportability of neural models in domains where causal structure is essential.
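The masking idea described above can be illustrated with a minimal NumPy sketch. It is not the authors' implementation: the function names (`parent_mask`, `masked_attention`), the use of one token per DAG node, the way queries, keys, and values are passed in directly, and the choice to return zeros for root nodes are all assumptions made for illustration. It only shows how a DAG adjacency matrix can restrict attention so that each node's output is built from its causal parents alone.

```python
import numpy as np

def parent_mask(adj):
    """adj[i, j] = 1 encodes the edge i -> j.
    Returns mask[j, i] = True when node i is a parent of node j."""
    return adj.T.astype(bool)

def masked_attention(queries, keys, values, mask):
    """Scaled dot-product attention where node j may attend only to its parents.
    queries, keys, values: arrays of shape (num_nodes, dim), one row per node.
    Root nodes (no parents) get an all-zero output; this is an assumption."""
    scores = queries @ keys.T / np.sqrt(queries.shape[-1])
    scores = np.where(mask, scores, -np.inf)           # block non-parents
    has_parent = mask.any(axis=1, keepdims=True)
    scores = np.where(has_parent, scores, 0.0)         # avoid all -inf rows
    scores = scores - scores.max(axis=1, keepdims=True)
    weights = np.exp(scores) * mask                    # non-parents get weight 0
    denom = weights.sum(axis=1, keepdims=True)
    weights = np.divide(weights, denom,
                        out=np.zeros_like(weights), where=denom > 0)
    return weights @ values

# Hypothetical three-node DAG: X0 -> X1, X0 -> X2, X1 -> X2.
adj = np.array([[0, 1, 1],
                [0, 0, 1],
                [0, 0, 0]])
mask = parent_mask(adj)

rng = np.random.default_rng(0)
q, k, v = (rng.normal(size=(3, 4)) for _ in range(3))
out = masked_attention(q, k, v, mask)

# Perturb node 1's value vector, loosely mimicking an intervention on X1.
v_do = v.copy()
v_do[1] += 10.0
out_do = masked_attention(q, k, v_do, mask)
```

In this sketch, perturbing node 1 leaves the outputs for X0 and X1 unchanged and alters only the output for X2, the sole descendant of X1, which is the behaviour the parent-only constraint is meant to guarantee.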
Abstract
Artificial Neural Networks (ANNs), including fully-connected networks and transformers, are highly flexible and powerful function approximators, widely applied in fields like computer vision and natural language processing. However, they do not inherently respect causal structure, which can limit their robustness, leaving them vulnerable to covariate shift and difficult to interpret or explain. This poses significant challenges for their reliability in real-world applications. In this paper, we introduce Causal Transformers (CaTs), a general model class designed to operate under predefined causal constraints, as specified by a Directed Acyclic Graph (DAG). CaTs retain the powerful function approximation abilities of traditional neural networks while adhering to the underlying structural constraints, improving robustness, reliability, and interpretability at inference time. This approach opens new avenues for deploying neural networks in more demanding, real-world scenarios where robustness and explainability are critical.
