Enhancing the Performance of Neural Networks Through Causal Discovery and Integration of Domain Knowledge

Xiaoge Zhang, Xiao-Lin Wang, Fenglei Fan, Yiu-Ming Cheung, Indranil Bose

TL;DR

This work tackles spurious correlations and generalization gaps in deep learning by introducing CINN, a framework that explicitly encodes hierarchical causal DAGs into neural network architectures while preserving edge orientations. The methodology proceeds in three steps: (1) causal discovery from observational data, formulated as a continuous optimization problem with an acyclicity constraint; (2) encoding of the resulting DAG into a layered CINN in which root, intermediate, and leaf nodes map to network components; and (3) training with a multi-task loss, using PCGrad to harmonize learning across causal groups. The authors report substantial predictive gains and reduced variance across five UCI regression datasets, supported by ablation studies and robustness checks, and show that incorporating domain knowledge further improves performance and stability. The work provides a generic, human-in-the-loop interface for fusing data-driven causal discovery with expert priors, advancing trustworthy, interpretable, and intervention-capable neural models for practical applications.
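The acyclicity constraint in the first step can be illustrated with a short sketch. The summary above does not state which constraint CINN uses; the function below assumes the trace-of-matrix-exponential form popularized by NOTEARS, h(W) = tr(exp(W ∘ W)) − d, which equals zero exactly when the weighted adjacency matrix W describes a DAG. The function name, the series truncation, and the toy matrices are illustrative and not taken from the paper.

```python
def acyclicity_penalty(W, terms=30):
    """Approximate h(W) = tr(exp(W * W)) - d with a truncated power series.

    W is a d x d weighted adjacency matrix (list of lists), where W[i][j] != 0
    means an edge i -> j. The NOTEARS-style form is an assumption here.
    """
    d = len(W)
    # Element-wise (Hadamard) square makes every edge weight non-negative.
    A = [[W[i][j] ** 2 for j in range(d)] for i in range(d)]
    # `power` holds A^k / k!, starting from the identity (k = 0).
    power = [[1.0 if i == j else 0.0 for j in range(d)] for i in range(d)]
    trace = float(d)  # trace of the k = 0 term
    for k in range(1, terms + 1):
        power = [
            [sum(power[i][m] * A[m][j] for m in range(d)) / k for j in range(d)]
            for i in range(d)
        ]
        trace += sum(power[i][i] for i in range(d))
    return trace - d


# Hypothetical toy graphs: a chain 0 -> 1 -> 2 (acyclic) and a 2-cycle 0 <-> 1.
W_dag = [[0.0, 0.8, 0.0],
         [0.0, 0.0, -1.2],
         [0.0, 0.0, 0.0]]
W_cyc = [[0.0, 1.0],
         [1.0, 0.0]]

h_dag = acyclicity_penalty(W_dag)  # zero: no closed walks exist
h_cyc = acyclicity_penalty(W_cyc)  # positive: the cycle contributes to the trace
```

Because h is a smooth function of W, a gradient-based solver can push it toward zero (for example with a penalty or augmented Lagrangian term) rather than searching over graph structures directly.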

Abstract

In this paper, we develop a generic methodology to encode the hierarchical causal structure among observed variables into a neural network in order to improve its predictive performance. The proposed methodology, called causality-informed neural network (CINN), leverages three coherent steps to systematically map the structural causal knowledge into the layer-to-layer design of a neural network while strictly preserving the orientation of every causal relationship. In the first step, CINN discovers causal relationships from observational data via directed acyclic graph (DAG) learning, where causal discovery is recast as a continuous optimization problem to avoid a combinatorial search over graph structures. In the second step, the discovered hierarchical causal structure among observed variables is systematically encoded into a neural network through a dedicated architecture and a customized loss function. By categorizing variables in the causal DAG as root, intermediate, and leaf nodes, the hierarchical causal DAG is translated into CINN with a one-to-one correspondence between nodes in the causal DAG and units in the CINN while maintaining the relative order among these nodes. Regarding the loss function, both intermediate and leaf nodes in the DAG are treated as target outputs during CINN training so as to drive co-learning of causal relationships among different types of nodes. As multiple loss components emerge in CINN, we leverage the projection of conflicting gradients (PCGrad) to mitigate gradient interference among the multiple learning tasks. Computational experiments across a broad spectrum of UCI data sets demonstrate substantial advantages of CINN in predictive performance over other state-of-the-art methods. In addition, an ablation study underscores the value of integrating structural and quantitative causal knowledge in incrementally enhancing the neural network's predictive performance.
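The node categorization described in the abstract can be sketched as follows. This is a minimal illustration that assumes the usual graph-theoretic definitions (a root has no parents, a leaf has no children, an intermediate node has both, and a node with no edges is isolated); the function name and the toy graph are hypothetical and not taken from the paper.

```python
def categorize_nodes(nodes, edges):
    """Split DAG nodes into isolated, root, intermediate, and leaf groups.

    `edges` is an iterable of (cause, effect) pairs, so edge orientation is
    what decides each node's group.
    """
    has_parent = {v: False for v in nodes}
    has_child = {v: False for v in nodes}
    for cause, effect in edges:
        has_child[cause] = True
        has_parent[effect] = True

    groups = {"isolated": [], "root": [], "intermediate": [], "leaf": []}
    for v in nodes:
        if not has_parent[v] and not has_child[v]:
            groups["isolated"].append(v)      # no causal links at all
        elif not has_parent[v]:
            groups["root"].append(v)          # only outgoing edges
        elif not has_child[v]:
            groups["leaf"].append(v)          # only incoming edges
        else:
            groups["intermediate"].append(v)  # both incoming and outgoing
    return groups


# Hypothetical toy DAG: 0 -> 2, 1 -> 2, 2 -> 3, 2 -> 4; node 5 has no edges.
toy_nodes = [0, 1, 2, 3, 4, 5]
toy_edges = [(0, 2), (1, 2), (2, 3), (2, 4)]
toy_groups = categorize_nodes(toy_nodes, toy_edges)
```

In an architecture of the kind the abstract describes, the root group would feed the input side, while the intermediate and leaf groups would supply the training targets.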

Paper Structure (27 sections, 8 equations, 10 figures, 5 tables, 1 algorithm)

Figures (10)

  • Figure 1: Flowchart of the developed methodology
  • Figure 2: Demonstration of the orientation of causal relationships and node categorization. Nodes of the same color belong to the same group: green nodes are isolated nodes, yellow nodes are root nodes, blue nodes are intermediate nodes, and purple nodes are leaf nodes.
  • Figure 3: Proposed CINN architecture. Nodes in yellow, blue, and purple represent the sets of root nodes $\mathsf{V}_C$, intermediate nodes $\mathsf{V}_B$, and output nodes $\mathsf{V}_O$, respectively. Note that $\mathsf{V}_B$ might have multiple layers of nodes stacked together with each layer having its own set of features.
  • Figure 4: Demonstration of PCGrad. (a) There is no conflict between $\bm{\Delta}_{\mathcal{L}_{MSE}^B}$ and $\bm{\Delta}_{\mathcal{L}_{R}}$. (b) There is a high conflict between $\bm{\Delta}_{\mathcal{L}_{MSE}^B}$ and $\bm{\Delta}_{\mathcal{L}_{R}}$. (c) PCGrad projects the gradient $\bm{\Delta}_{\mathcal{L}_{MSE}^B}$ onto the normal vector of the gradient $\bm{\Delta}_{\mathcal{L}_{R}}$. (d) PCGrad projects the gradient $\bm{\Delta}_{\mathcal{L}_{R}}$ onto the normal vector of the gradient $\bm{\Delta}_{\mathcal{L}_{MSE}^B}$.
  • Figure 6: Discovery and refinement of causal relationships among observed variables for the BH dataset. The bold-faced number in each circle indicates the feature ID, while the associated text gives the feature name. The primary task is to predict the MEDV value (node 13) using other relevant features. Circles of the same color belong to the same node category. Specifically, yellow circles are the set of root nodes $\mathsf{V}_C = \{ {\mathsf{X}_{0}}, {\mathsf{X}_{3}}, {\mathsf{X}_{4}}, {\mathsf{X}_{5}}, {\mathsf{X}_{7}}, {\mathsf{X}_{8}}, {\mathsf{X}_{10}}, {\mathsf{X}_{12}} \}$, blue circles are the set of intermediate nodes $\mathsf{V}_B = \{ {\mathsf{X}_{1}}, {\mathsf{X}_{2}}, {\mathsf{X}_{13}} \}$, and purple circles are the set of leaf nodes $\mathsf{V}_O = \{ {\mathsf{X}_{6}}, {\mathsf{X}_{9}}, {\mathsf{X}_{11}}\}$. The red dotted lines denote edges removed from the discovered causal graph based on expert knowledge, while the blue edge denotes a link added to the discovered causal graph.
  • ...and 5 more figures
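
The projection shown in the Figure 4 caption can be sketched for two task gradients. This assumes the standard PCGrad rule: when two gradients have a negative inner product, each is replaced by its component orthogonal to the other; otherwise both are left unchanged. The function names and the toy vectors are illustrative, and the full method also handles more than two tasks, which this sketch omits.

```python
def dot(u, v):
    """Inner product of two equal-length gradient vectors."""
    return sum(a * b for a, b in zip(u, v))


def project_if_conflicting(g_i, g_j):
    """Return g_i with its component along g_j removed when they conflict.

    Two gradients conflict when their inner product is negative. In that case
    g_i is projected onto the plane orthogonal to g_j; otherwise a copy of
    g_i is returned unchanged.
    """
    inner = dot(g_i, g_j)
    if inner >= 0:
        return list(g_i)
    scale = inner / dot(g_j, g_j)
    return [a - scale * b for a, b in zip(g_i, g_j)]


# Hypothetical gradients of two loss terms with respect to shared parameters.
g_b = [1.0, 0.0]
g_r = [-1.0, 1.0]

g_b_proj = project_if_conflicting(g_b, g_r)  # g_b made orthogonal to g_r
g_r_proj = project_if_conflicting(g_r, g_b)  # g_r made orthogonal to g_b

# The shared parameters are then updated with the sum of the projected gradients.
update = [a + b for a, b in zip(g_b_proj, g_r_proj)]
```

Each projection uses the original, unprojected version of the other gradient, so the order in which the two calls are made does not matter in this two-task case.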