
Towards Learning and Explaining Indirect Causal Effects in Neural Networks

Abbavaram Gowtham Reddy, Saketh Bachu, Harsharaj Pathak, Benin L Godfrey, Vineeth N. Balasubramanian, Varshaneya V, Satya Narayanan Kar

TL;DR

The paper tackles a limitation of standard neural networks: they typically capture only the direct causal effects of inputs on outputs (which, with independent inputs, coincide with the total effects) and neglect indirect pathways. It introduces AHCE, an ante-hoc framework that augments NNs with lateral connections among input neurons to learn and preserve direct, indirect, and total causal effects during training. The approach formalizes the average total, direct, and indirect causal effects of an input $X_i$ on the prediction $\hat{Y}$ ($ACE^{\hat{Y}}_{X_i}$, $ADCE^{\hat{Y}}_{X_i}$, and $AICE^{\hat{Y}}_{X_i}$). It trains the augmented network $\mathcal{N}^{Ind}$ in two phases, uses a second-order Taylor expansion to estimate interventional expectations, and relies on a binning-based strategy for scalable computation. Experiments on synthetic and real-world datasets show that AHCE approximates ground-truth causal effects better than baselines, and the efficiency techniques make it applicable to high-dimensional data. Overall, AHCE offers a principled, scalable way to learn and explain indirect causal effects in neural models, with implications for reliability and fairness in safety-critical AI systems.
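To make the Taylor-expansion step concrete, here is a minimal Python sketch of the standard second-order approximation $\mathbb{E}[f(X)] \approx f(\mu) + \tfrac{1}{2}\operatorname{tr}(\nabla^2 f(\mu)\,\Sigma)$ applied to an interventional distribution. It is not the authors' implementation. The helper names, the finite-difference Hessian, pinning the intervened input to zero variance while leaving the other inputs' moments unchanged, and the averaged-over-$\alpha$ baseline used for the effect are all assumptions for illustration. In particular, leaving the other moments unchanged ignores exactly the input-to-input edges this paper adds; with those edges, the mean and covariance of $X_i$'s descendants would also need updating.

```python
import numpy as np

def hessian_fd(f, x, eps=1e-3):
    """Central-difference estimate of the Hessian of a scalar function f at x."""
    n = x.size
    steps = np.eye(n) * eps
    H = np.zeros((n, n))
    for j in range(n):
        for k in range(n):
            H[j, k] = (f(x + steps[j] + steps[k]) - f(x + steps[j] - steps[k])
                       - f(x - steps[j] + steps[k]) + f(x - steps[j] - steps[k])) / (4 * eps ** 2)
    return H

def taylor_expectation(f, mean, cov):
    """Second-order approximation: E[f(X)] ~= f(mean) + 0.5 * tr(H(mean) @ cov)."""
    return f(mean) + 0.5 * np.trace(hessian_fd(f, mean) @ cov)

def interventional_expectation(f, i, alpha, mean, cov):
    """Approximate E[f(X) | do(X_i = alpha)] by pinning X_i to alpha with zero variance.

    Assumption: the moments of the other inputs are left unchanged,
    i.e. X_i is treated as having no descendants among the inputs.
    """
    m = np.array(mean, dtype=float)
    c = np.array(cov, dtype=float)
    m[i] = alpha
    c[i, :] = 0.0
    c[:, i] = 0.0
    return taylor_expectation(f, m, c)

def ace(f, i, alpha, mean, cov, alphas):
    """Interventional expectation at alpha minus its average over a grid of
    intervention values (an assumed baseline, not necessarily the paper's)."""
    baseline = np.mean([interventional_expectation(f, i, a, mean, cov) for a in alphas])
    return interventional_expectation(f, i, alpha, mean, cov) - baseline
```

For the toy function $f(x) = x_0^2 + x_0 x_1 + 3x_1$ with mean $(1, 2)$ and covariance $\begin{pmatrix}0.5 & 0.1\\0.1 & 0.2\end{pmatrix}$, `taylor_expectation` returns 9.6, which equals the exact expectation because a quadratic has no terms beyond second order. For a trained network the approximation is only as good as the local quadratic fit around the mean.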

Abstract

Recently, there has been a growing interest in learning and explaining causal effects within Neural Network (NN) models. By virtue of NN architectures, previous approaches consider only direct and total causal effects assuming independence among input variables. We view an NN as a structural causal model (SCM) and extend our focus to include indirect causal effects by introducing feedforward connections among input neurons. We propose an ante-hoc method that captures and maintains direct, indirect, and total causal effects during NN model training. We also propose an algorithm for quantifying learned causal effects in an NN model and efficient approximation strategies for quantifying causal effects in high-dimensional data. Extensive experiments conducted on synthetic and real-world datasets demonstrate that the causal effects learned by our ante-hoc method better approximate the ground truth effects compared to existing methods.
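The abstract mentions efficient approximation strategies for high-dimensional data, and the TL;DR attributes part of the speedup to binning. The summary does not say what is binned, so the Python sketch below assumes one plausible reading: instead of evaluating an expensive effect estimate at every observed value of $X_i$, split those values into quantile bins, evaluate once per bin center, and reuse that result for every value in the bin. `binned_effect_curve` and `effect_fn` are hypothetical names; `effect_fn` stands for any callable mapping an intervention value to an estimated effect.

```python
import numpy as np

def binned_effect_curve(effect_fn, values, n_bins=10):
    """Evaluate an expensive effect function once per bin instead of once per value.

    effect_fn : callable mapping a scalar intervention value alpha -> float (assumed)
    values    : 1-D array of observed values of the intervened input X_i
    Returns (per-value effects, bin centers, per-bin effects).
    """
    # Quantile edges give bins holding roughly equal numbers of samples.
    edges = np.quantile(values, np.linspace(0.0, 1.0, n_bins + 1))
    centers = 0.5 * (edges[:-1] + edges[1:])
    per_bin = np.array([effect_fn(c) for c in centers])
    # Digitizing against the interior edges yields bin indices in [0, n_bins - 1].
    idx = np.digitize(values, edges[1:-1])
    return per_bin[idx], centers, per_bin
```

With this design the expensive computation runs `n_bins` times regardless of the number of samples; the price is that every value in a bin shares the effect computed at the bin center.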

Paper Structure (18 sections, 13 equations, 8 figures, 9 tables, 3 algorithms)

Figures (8)

  • Figure 1: (a) A marginalized NN whose inputs $S, E, R$ are not causally related. (b) A marginalized NN whose inputs are connected through feedforward connections (e.g., $S \rightarrow E$) that encode underlying causal relationships (e.g., $S$ causes $E$), so that the network can learn indirect causal effects of inputs on the output (e.g., the effect of $S$ on $I$ via $E$).
  • Figure 2: Comparison of the proposed architecture $\mathcal{N}^{Ind}$ with a traditional NN architecture $\mathcal{N}$. $\mathcal{G}$ is the ground truth causal graph. $\mathcal{N}$ and $\mathcal{N}^{Ind}$ differ in the input layer: the inputs in $\mathcal{N}^{Ind}$ are connected (shown in blue) according to the causal edges in $\mathcal{G}$, whereas the inputs in $\mathcal{N}$ are independent. Both $\mathcal{N}$ and $\mathcal{N}^{Ind}$ may contain edges absent from $\mathcal{G}$, because the NN architecture connects the input layer feedforward to the prediction (e.g., $X_1\rightarrow \hat{Y}$ is present in $\mathcal{N}$ and $\mathcal{N}^{Ind}$ but not in $\mathcal{G}$).
  • Figure 3: Synthetic DAG
  • Figure 4: (i) Estimated causal graph of the Auto-MPG dataset, where C: number of cylinders, D: displacement, W: weight, H: horsepower, A: acceleration, and M: miles per gallon. (ii) True causal graph of the lung cancer dataset, where A: visit to Asia, T: tuberculosis, S: smoking, L: lung cancer, B: bronchitis, E: either T or L, X: X-ray, and D: dyspnea. (iii) True causal graph of the Sachs dataset.
  • Figure 5: Convergence of RMSE/accuracy values at the end of the two training phases outlined in the paper's algorithm for training $\mathcal{N}^{Ind}$ with input-layer edges.
  • ...and 3 more figures
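Figures 1 and 2 describe the key architectural change: an input such as $E$ receives a feedforward connection from its causal parent $S$, so an intervention on $S$ reaches the output both directly and through $E$. The Python sketch below illustrates that idea on a toy model with inputs $S, E, R$ and output $I$, mirroring Figure 1(b). Everything in it is hypothetical: the linear lateral edge $E = w_{SE} S + U_E$, the random tanh output network built by `make_mlp`, and the convention of measuring the direct effect by holding $E$ at the value it takes under the reference intervention, with the indirect effect defined as total minus direct. The paper's own $ADCE$ and $AICE$ definitions may differ.

```python
import numpy as np

def e_from_s(s, u_e, w_se=0.8):
    """Hypothetical lateral input edge S -> E: E = w_se * S + U_E."""
    return w_se * s + u_e

def make_mlp(hidden=8, seed=1):
    """Hypothetical output network g(S, E, R) with fixed random weights."""
    gen = np.random.default_rng(seed)
    W1 = gen.normal(size=(3, hidden))
    b1 = gen.normal(size=hidden)
    w2 = gen.normal(size=hidden)

    def g(s, e, rr):
        x = np.stack([s, e, rr], axis=-1)  # shape (n, 3)
        return np.tanh(x @ W1 + b1) @ w2   # shape (n,)
    return g

def effects(g, s0, s1, u_e, r, w_se=0.8):
    """Monte Carlo total, direct, and indirect effect of do(S=s1) versus do(S=s0).

    total:    the intervention on S propagates to E through the lateral edge.
    direct:   E is held at the value it takes under do(S=s0).
    indirect: total - direct (one convention among several).
    """
    n = u_e.shape[0]
    s0v, s1v = np.full(n, s0), np.full(n, s1)
    e0 = e_from_s(s0v, u_e, w_se)
    e1 = e_from_s(s1v, u_e, w_se)
    total = np.mean(g(s1v, e1, r) - g(s0v, e0, r))
    direct = np.mean(g(s1v, e0, r) - g(s0v, e0, r))
    return total, direct, total - direct

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    u_e, r = rng.normal(size=1000), rng.normal(size=1000)
    print(effects(make_mlp(), 0.0, 1.0, u_e, r))
```

As a sanity check, for a linear output $I = 2S + 3E + R$ with $w_{SE} = 0.8$, moving $S$ from 0 to 1 gives a direct effect of 2 and an indirect effect of $3 \times 0.8 = 2.4$, the product of the coefficients along $S \rightarrow E \rightarrow I$. A network without the lateral edge would report only the direct part.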

Theorems & Definitions (8)

  • Definition 3.1
  • Definition 3.2
  • Definition 3.3
  • Definition 4.1
  • Definition A.1
  • Definition A.2
  • Definition A.3
  • Definition A.4