
Causal Neighbourhood Learning for Invariant Graph Representations

Simi Job, Xiaohui Tao, Taotao Cai, Haoran Xie, Jianming Yong

TL;DR

This work addresses spurious correlations in graph data that hamper GNN generalization by introducing CNL-GNN, a framework that imposes causal interventions on graph structure to learn invariant representations. It combines Counterfactual Neighbourhood Generation, Adaptive Causal Edge Generation, Edge Importance Estimation, and Causal Feature Disentanglement with a contrastive objective between original and counterfactual graphs. Experiments on four datasets, including Twitch variants, show state-of-the-art performance and strong robustness to distribution shifts, validating the approach. The results extend causal graph learning beyond feature-based methods and point to future work on dynamic graphs and more complex structural interventions.
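The TL;DR does not spell out the form of the contrastive objective. As a hedged illustration only, the sketch below assumes an InfoNCE-style loss, a common contrastive choice that may differ from the paper's. Each node's embedding in the original graph is treated as the positive match for the same node's embedding in the counterfactual graph, and every other node serves as a negative. The names `info_nce` and `cosine` and the temperature `tau` are hypothetical.

```python
import math
from typing import List

Vector = List[float]


def cosine(u: Vector, v: Vector) -> float:
    """Cosine similarity; the small epsilon guards against zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)


def info_nce(original: List[Vector], counterfactual: List[Vector], tau: float = 0.5) -> float:
    """Mean InfoNCE loss (assumed form, not necessarily CNL-GNN's objective):
    node i in the original graph should be closer to node i in the
    counterfactual graph than to any other node j."""
    if len(original) != len(counterfactual):
        raise ValueError("both views must embed the same nodes")
    n = len(original)
    total = 0.0
    for i in range(n):
        logits = [cosine(original[i], counterfactual[j]) / tau for j in range(n)]
        m = max(logits)  # subtract the max for a numerically stable log-sum-exp
        log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_sum_exp - logits[i]  # equals -log softmax(logits)[i]
    return total / n
```

Aligned views should yield a lower loss than mismatched ones: for `z = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]`, `info_nce(z, z)` is smaller than `info_nce(z, [z[1], z[2], z[0]])`.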

Abstract

Graph data often contain noisy and spurious correlations that mask the true causal relationships, which graph models need in order to base their predictions on the underlying causal structure of the data. Dependence on spurious connections makes it challenging for traditional Graph Neural Networks (GNNs) to generalize across different graphs. Furthermore, traditional aggregation methods tend to amplify these spurious patterns, limiting model robustness under distribution shifts. To address these issues, we propose Causal Neighbourhood Learning with Graph Neural Networks (CNL-GNN), a novel framework that performs causal interventions on graph structure. CNL-GNN identifies and preserves causally relevant connections and reduces spurious influences by generating counterfactual neighbourhoods and by adaptively perturbing edges, guided by learnable importance masking and an attention-based mechanism. In addition, by combining structural-level interventions with the disentanglement of causal features from confounding factors, the model learns invariant node representations that are robust and generalize well across different graph structures. Our approach extends causal graph learning beyond traditional feature-based methods, resulting in a robust classification model. Extensive experiments on four publicly available datasets, including multiple domain variants of one dataset, demonstrate that CNL-GNN outperforms state-of-the-art GNN models.
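To make the edge-masking idea concrete, here is a minimal sketch under stated assumptions: an edge's importance is a dot-product attention score between its endpoint features, softmax-normalised over each node's outgoing edges, and edges below a fixed threshold are masked out. In CNL-GNN the importance mask is learnable; here the scores are computed directly from the features with no training. The names `edge_importance`, `mask_low_importance` and `threshold` are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, List, Tuple

Edge = Tuple[int, int]


def edge_importance(features: List[List[float]], edges: List[Edge]) -> Dict[Edge, float]:
    """Assumed scoring: dot-product attention between endpoint features,
    softmax-normalised over each source node's outgoing edges."""
    by_source: Dict[int, List[Edge]] = defaultdict(list)
    for u, v in edges:
        by_source[u].append((u, v))
    scores: Dict[Edge, float] = {}
    for out_edges in by_source.values():
        logits = [sum(a * b for a, b in zip(features[u], features[v])) for u, v in out_edges]
        m = max(logits)  # stabilise the softmax
        exps = [math.exp(x - m) for x in logits]
        total = sum(exps)
        for edge, e in zip(out_edges, exps):
            scores[edge] = e / total
    return scores


def mask_low_importance(scores: Dict[Edge, float], threshold: float) -> List[Edge]:
    """Hard mask (hypothetical): keep only edges whose importance reaches the threshold."""
    return sorted(edge for edge, s in scores.items() if s >= threshold)
```

With features `[[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]` and edges `[(0, 1), (0, 2), (1, 0), (2, 0)]`, node 0 attends more strongly to its feature-aligned neighbour 1 (about 0.73) than to node 2 (about 0.27), so a threshold of 0.5 masks edge `(0, 2)`.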

Paper Structure

This paper contains 27 sections, 5 equations, 5 figures, 2 tables and 2 algorithms.

Figures (5)

  • Figure 1: Illustration of causal concepts using SCMs. (a) Cause $C$ influences effect $E$. (b) An intervention sets $E$ to a fixed value. (c) An intervention modifies the function or noise affecting $E$ ($N_E$: noise). (d) A counterfactual updates the noise to evaluate alternative outcomes without changing the causal structure ($T=1$, $B=1$: observed treatment and outcome; $N_B=1$: inferred latent condition).
  • Figure 2: CNL-GNN Architecture: (i) Enforces structure-invariant representations by perturbing neighbourhoods with dissimilar neighbours. (ii) Estimates edge relevance using attention to prioritize causal edges and guide structural perturbations. (iii) Selectively perturbs inter-group edges and masks low-importance edges to enhance causal robustness. (iv) Disentangles and fuses causal and non-causal features using gating for robust prediction.
  • Figure 3: Precision, Recall and F1 evaluation under distribution shifts across Twitch domains. Arrows indicate relative F1 drop compared to the DE domain.
  • Figure 4: F1 scores across datasets (CNL-GNN vs. ablated variants). CNL-GNN bars are shown in blue; ablations are white with hatch patterns.
  • Figure 5: Sensitivity analysis: $\Delta$F1 score vs. edge drop rate for Cora, Citeseer and Pubmed. Values indicate the deviation from the baseline at a drop rate of 0.1.
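Panel (i) of Figure 2 describes perturbing neighbourhoods with dissimilar neighbours. One hypothetical way to realise this, assumed here rather than taken from the paper, is to swap a node's `k` neighbours that are most similar to it (by cosine similarity of features) for the `k` non-neighbours least similar to it, which keeps the neighbourhood size fixed. The names `counterfactual_neighbourhood`, `cosine` and `k` are illustrative.

```python
import math
from typing import Dict, List, Set


def cosine(u: List[float], v: List[float]) -> float:
    """Cosine similarity; the small epsilon guards against zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / (norm + 1e-12)


def counterfactual_neighbourhood(
    features: List[List[float]],
    adjacency: Dict[int, Set[int]],
    node: int,
    k: int,
) -> Set[int]:
    """Hypothetical structural intervention: replace the k neighbours most
    similar to `node` with the k non-neighbours least similar to it.
    Returns a new set; `adjacency` is left unchanged."""
    neighbours = adjacency[node]
    candidates = sorted(v for v in adjacency if v != node and v not in neighbours)
    k = min(k, len(neighbours), len(candidates))

    def similarity(v: int) -> float:
        return cosine(features[node], features[v])

    # Sorting the ids first makes tie-breaking deterministic (sorted() is stable).
    drop = sorted(sorted(neighbours), key=similarity, reverse=True)[:k]
    add = sorted(candidates, key=similarity)[:k]
    return (neighbours - set(drop)) | set(add)
```

For node 0 with features `[1.0, 0.0]`, neighbours `{1, 2}` (features `[1.0, 0.1]` and `[0.9, 0.2]`) and `k = 1`, the most similar neighbour (node 1) is replaced by the most dissimilar non-neighbour (node 3, features `[-1.0, 0.0]`), giving `{2, 3}`.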