CaTs and DAGs: Integrating Directed Acyclic Graphs with Transformers for Causally Constrained Predictions

Matthew J. Vowels, Mathieu Rochat, Sina Akbari

TL;DR

The paper addresses the brittleness and opacity of standard neural networks in causal settings by introducing Causal Transformers (CaTs) and Causal Fully Connected Networks (CFCNs) that enforce DAG-based causal constraints via a novel masking mechanism. CaT employs causally masked cross-attention, while CFCN uses MADE-like masks, both ensuring that each node’s prediction depends only on its causal parents, enabling honest interventions and aligning with the g-formula. Theoretical results establish that these architectures realize a Causal Bayesian Network with a DAG-consistent factorization and invariance properties under covariate shift, while experiments across simulated, benchmark, and real-world data show competitive causal-effect estimation and robustness, even without task-specific hyperparameter tuning. The work offers a general, DAG-guided modeling framework that enhances robustness, interpretability, and transportability of neural models in domains where causal structure is essential.
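The TL;DR links honest interventions to the g-formula. The sketch below illustrates that link under stated assumptions. It uses a hypothetical three-node DAG (C → D, C → Y, D → Y). A known structural function `f_y` stands in for a trained CaT or CFCN node predictor that sees only Y's parents. The plug-in g-formula clamps the treatment and then re-predicts only its descendants, in topological order, each from its parents alone. All names (`f_y`, `simulate`, `descendants`, `g_formula_ate`) are illustrative and are not the authors' code.

```python
import numpy as np

# Hypothetical three-node DAG: C -> D, C -> Y, D -> Y (C confounds D -> Y).
PARENTS = {"C": [], "D": ["C"], "Y": ["C", "D"]}
ORDER = ["C", "D", "Y"]  # a topological order of PARENTS


def f_y(c, d):
    """Stand-in for a fitted node predictor that sees only Y's parents (true effect of D: 2.0)."""
    return 1.5 * c + 2.0 * d


def simulate(n, rng):
    c = rng.normal(size=n)
    d = (c + rng.normal(size=n) > 0).astype(float)  # treatment depends on the confounder
    y = f_y(c, d) + rng.normal(size=n)
    return {"C": c, "D": d, "Y": y}


def descendants(node, parents):
    """All nodes reachable from `node` along directed edges."""
    children = {v: [w for w, ps in parents.items() if v in ps] for v in parents}
    found, stack = set(), [node]
    while stack:
        for child in children[stack.pop()]:
            if child not in found:
                found.add(child)
                stack.append(child)
    return found


def g_formula_ate(data, predictors, parents, order, treatment, outcome):
    """Plug-in g-formula: clamp the treatment, then re-predict its descendants in topological order."""
    affected = descendants(treatment, parents)
    means = {}
    for value in (0.0, 1.0):
        values = dict(data)  # non-descendants keep their observed values
        values[treatment] = np.full_like(data[treatment], value)
        for node in order:
            if node in affected:
                values[node] = predictors[node](*(values[p] for p in parents[node]))
        means[value] = values[outcome].mean()
    return means[1.0] - means[0.0]


rng = np.random.default_rng(0)
data = simulate(10_000, rng)
ate = g_formula_ate(data, {"Y": f_y}, PARENTS, ORDER, treatment="D", outcome="Y")
naive = data["Y"][data["D"] == 1].mean() - data["Y"][data["D"] == 0].mean()
print(f"g-formula ATE: {ate:.3f}, naive difference: {naive:.3f}")
```

`f_y` depends only on Y's parents, so clamping D changes the prediction for Y only through the D → Y edge. The estimate therefore recovers the simulated effect of 2.0. The naive difference in means, by contrast, is biased upward by the confounding through C.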

Abstract

Artificial Neural Networks (ANNs), including fully-connected networks and transformers, are highly flexible and powerful function approximators, widely applied in fields like computer vision and natural language processing. However, their inability to inherently respect causal structures can limit their robustness, making them vulnerable to covariate shift and difficult to interpret and explain. This poses significant challenges for their reliability in real-world applications. In this paper, we introduce Causal Transformers (CaTs), a general model class designed to operate under predefined causal constraints, as specified by a Directed Acyclic Graph (DAG). CaTs retain the powerful function approximation abilities of traditional neural networks while adhering to the underlying structural constraints, improving robustness, reliability, and interpretability at inference time. This approach opens new avenues for deploying neural networks in more demanding, real-world scenarios where robustness and explainability are critical.

Paper Structure

This paper contains 21 sections, 6 theorems, 33 equations, 8 figures, 8 tables, and 1 algorithm.

Key Result

Proposition 1

Under Assumptions asm:mask through asm:tokenwise, for any block $\ell \geq 1$ and node $i \in Z$, there exists a measurable function $F_i^{(\ell)}$, taking only the inputs of $\mathrm{pa}(i)$, such that [equation not recovered]. Consequently, for any $S \subseteq Z \setminus \mathrm{pa}(i)$, [equation not recovered].
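Proposition 1 can be checked numerically on a toy model. The sketch below is a single-head version of causally masked cross-attention in NumPy. It rests on several assumptions of ours, none taken from the paper's implementation:

- The adjacency convention is adj[j, i] = 1 for j → i.
- Scalar inputs are embedded token-wise.
- The same parent mask is applied in every block, with keys and values always taken from the input embeddings (per the Figure 2 description).
- Parentless nodes produce a zero attention output.

The function names (`causal_cross_attention`, `cat_forward`) are hypothetical. Perturbing inputs outside $\mathrm{pa}(X_2)$ should leave the $X_2$ prediction unchanged. The $X_3$ prediction, whose parents include $X_2$, should move.

```python
import numpy as np


def causal_cross_attention(gamma, x_emb, adj, w_q, w_k, w_v):
    """Single causally masked cross-attention head (toy sketch).

    gamma: (N, C) per-node query tokens; x_emb: (N, C) input embeddings;
    adj: (N, N) with adj[j, i] = 1 iff j -> i. Row i of the result mixes
    value vectors of pa(i) only; nodes without parents get zeros (our assumption).
    """
    q, k, v = gamma @ w_q, x_emb @ w_k, x_emb @ w_v
    scores = q @ k.T / np.sqrt(k.shape[1])     # scores[i, j]: query of node i vs key of node j
    allowed = adj.T.astype(bool)               # allowed[i, j] iff j is a parent of i
    rows = allowed.any(axis=1)                 # nodes that have at least one parent
    out = np.zeros_like(v)
    s = np.where(allowed[rows], scores[rows], -np.inf)
    s = s - s.max(axis=1, keepdims=True)       # numerically stable softmax over parents only
    w = np.exp(s)
    w /= w.sum(axis=1, keepdims=True)
    out[rows] = w @ v
    return out


def cat_forward(x, adj, params, n_blocks=3):
    """Toy CaT-style stack: gamma is refined block by block with token-wise residual updates."""
    x_emb = x[:, None] * params["embed"]      # node-specific embedding of each scalar input
    gamma = params["gamma0"].copy()
    for _ in range(n_blocks):
        attn = causal_cross_attention(gamma, x_emb, adj, *params["heads"])
        gamma = gamma + np.tanh(attn)          # residual update, applied row by row
    return gamma @ params["readout"]           # one scalar prediction per node


rng = np.random.default_rng(1)
N, C = 3, 8
adj = np.zeros((N, N))
adj[0, 1] = adj[1, 2] = adj[0, 2] = 1          # X1 -> X2 -> X3 plus X1 -> X3
params = {
    "embed": rng.normal(size=(N, C)),
    "gamma0": rng.normal(size=(N, C)),          # initially random, as in Figure 2
    "heads": [rng.normal(size=(C, C)) / np.sqrt(C) for _ in range(3)],
    "readout": rng.normal(size=C),
}
x = rng.normal(size=N)
x_perturbed = x.copy()
x_perturbed[1:] += 5.0                         # shift X2 and X3, neither of which is a parent of X2
y = cat_forward(x, adj, params)
y_perturbed = cat_forward(x_perturbed, adj, params)
print("X2 change:", abs(y[1] - y_perturbed[1]))  # X2's parents are untouched, so this should vanish
print("X3 change:", abs(y[2] - y_perturbed[2]))  # X2 is a parent of X3, so this should not
```

Because every update to a node's token is row-wise and its attention row is masked to its parents, no path exists from a non-parent input to that node's output. This is the property the proposition formalizes.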

Figures (8)

  • Figure 1: Illustrating the lack of covariate-shift robustness of conventional machine learning models such as random forests (Breiman, 2001), multilayer perceptrons (MLPs), and transformers (Vaswani et al., 2017), compared with our CaT and CFCN.
  • Figure 2: A top-level depiction of how masking and routing are applied in CaT to a batch of $B$ inputs, each consisting of $\vert Z \vert = 3$ embeddings of dimension $C$, together with a corresponding causal DAG. The input parameter $\boldsymbol{\gamma}$, which is initially random, is fed recursively to the causal heads, 'extracting' information from the input embeddings via the causal cross-attention operation, subject to the constraints imposed by the adjacency matrix $\mathbf{A}$. Best viewed in color. See main text for details.
  • Figure 3: Motivating example. The data for the experiment are generated from DAG (a), whilst (b) and (c) represent two misspecified DAGs used for estimation. Dashed curved lines in (b) emphasise the possible endogeneity of $D$, $L_1$ and $L_2$, and the lack of independence constraints in typical non-causal machine learning approaches. The thick edges depict the target causal relationship of interest.
  • Figure 4: Simple example sequence of matrix transformations used to generate masks for CFCN in a three-variable, three-layer case, where the DAG includes $X^1 \rightarrow X^2 \rightarrow X^3$ mediation as well as a direct path $X^1 \rightarrow X^3$, and the number of neurons in each layer is 3, 6, and 3. The transition from Layer 1 to Layer 2 introduces the diagonal 'pass-through', which is absent in the first layer.
  • Figure 5: Two examples demonstrating how introducing the identity diagonal only from layer 2 onwards (a) prevents inputs from attending to themselves at the first layer and (b) allows intermediate predictions of these inputs to be used in predicting other variables. Without the identity, signals from, e.g., $\mathbf{x}_1$ used for predicting $\mathbf{x}_2$ would be blocked, and $\mathbf{x}_2$ itself could not be predicted. Applies to both CaT and CFCN.
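The CFCN mask construction in Figure 4 can be approximated in the MADE style. The sketch below is one conservative, hypothetical reading, not the paper's exact matrix sequence. Hidden units are assigned to variables in equal blocks. The first transition uses the DAG adjacency without a diagonal, so no variable sees its own input. Every later transition uses only the identity 'pass-through', so units feed units of the same variable. The layer widths (a 3-unit input, then layers of 3, 6 and 3) are our assumption, loosely based on Figure 4. The helpers `unit_labels`, `cfcn_masks` and `effective_connectivity` are illustrative names. The final check confirms that, under these assumptions, the composed masks connect each output only to its parents' inputs.

```python
import numpy as np


def unit_labels(n_vars, width):
    """Assign `width` units to variables in equal contiguous blocks (assumes width % n_vars == 0)."""
    return np.repeat(np.arange(n_vars), width // n_vars)


def cfcn_masks(adj, widths):
    """masks[l][a, b] = 1 lets unit a of layer l feed unit b of layer l + 1.

    adj[j, i] = 1 iff j -> i. First transition: parents only (no diagonal).
    Later transitions: identity pass-through between units of the same variable.
    """
    n = adj.shape[0]
    masks = []
    for layer, (w_in, w_out) in enumerate(zip(widths[:-1], widths[1:])):
        var_mask = adj if layer == 0 else np.eye(n)
        lab_in, lab_out = unit_labels(n, w_in), unit_labels(n, w_out)
        masks.append(var_mask[np.ix_(lab_in, lab_out)].astype(int))
    return masks


def effective_connectivity(masks, n_vars, widths):
    """Variable-level reachability from inputs to outputs through the masked layers."""
    reach = masks[0]
    for m in masks[1:]:
        reach = (reach @ m > 0).astype(int)
    lab_in, lab_out = unit_labels(n_vars, widths[0]), unit_labels(n_vars, widths[-1])
    out = np.zeros((n_vars, n_vars), dtype=int)
    for a, i in enumerate(lab_in):
        for b, j in enumerate(lab_out):
            out[i, j] |= reach[a, b]
    return out


adj = np.zeros((3, 3), dtype=int)
adj[0, 1] = adj[1, 2] = adj[0, 2] = 1      # X1 -> X2 -> X3 plus X1 -> X3
widths = [3, 3, 6, 3]                      # input, then 3, 6, 3 units (our assumption)
masks = cfcn_masks(adj, widths)
conn = effective_connectivity(masks, adj.shape[0], widths)
print(masks[0])                            # first-layer mask: no diagonal
print(np.array_equal(conn, adj))           # outputs reach only their parents' inputs
```

In a forward pass, each weight matrix would be multiplied elementwise by its mask, for example `h = np.maximum(0.0, h @ (W * mask))` with `W` shaped like `mask`. The masked connections then carry no gradient or signal.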

Theorems & Definitions (8)

  • Proposition 1: Parental sufficiency
  • Remark 1: Strict causal dependency through masking
  • Lemma 1: Local Markov property
  • Theorem 1: Structural soundness and DAG factorization
  • Definition 1
  • Proposition 2: Truncated product via CMCA
  • Theorem 2: Structural robustness under covariate shift
  • Corollary 1: Transportability of identified causal queries
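Theorem 1 and Proposition 2 concern the DAG factorization and its truncated, interventional form. For reference, the standard identities are written below in the document's notation, with $Z$ for the node set and $\mathrm{pa}(i)$ for the parents of node $i$. The intervened set $X$ and the three-node example are our own illustration, not the paper's statements. A model whose per-node outputs depend only on parents, as in Proposition 1, is the kind of model that can plug an estimate into each factor.

```latex
% DAG (Markov) factorization over the node set Z
p(z) = \prod_{i \in Z} p\big(z_i \mid z_{\mathrm{pa}(i)}\big)

% Truncated product for an intervention do(X = x), with X \subseteq Z
p\big(z_{Z \setminus X} \mid \mathrm{do}(X = x)\big)
  = \prod_{i \in Z \setminus X} p\big(z_i \mid z_{\mathrm{pa}(i)}\big)\Big|_{z_X = x}

% Example: C -> D, C -> Y, D -> Y; drop the factor for D and marginalise the confounder C
p\big(y \mid \mathrm{do}(D = d)\big) = \sum_{c} p(c)\, p(y \mid c, d)
```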