Causal Neighbourhood Learning for Invariant Graph Representations
Simi Job, Xiaohui Tao, Taotao Cai, Haoran Xie, Jianming Yong
TL;DR
This work addresses spurious correlations in graph data that hamper GNN generalization by introducing CNL-GNN, a framework that imposes causal interventions on graph structure to learn invariant representations. It combines Counterfactual Neighbourhood Generation, Adaptive Causal Edge Generation, Edge Importance Estimation, and Causal Feature Disentanglement with a contrastive objective between original and counterfactual graphs. Experiments on four datasets, including Twitch variants, show state-of-the-art performance and strong robustness to distribution shifts, validating the approach. The results extend causal graph learning beyond feature-based methods and point to future work on dynamic graphs and more complex structural interventions.
Abstract
Graph data often contain noisy and spurious correlations that mask the true causal relationships on which graph models should base their predictions. Dependence on spurious connections makes it difficult for traditional Graph Neural Networks (GNNs) to generalize across different graphs. Furthermore, traditional aggregation methods tend to amplify these spurious patterns, limiting model robustness under distribution shifts. To address these issues, we propose Causal Neighbourhood Learning with Graph Neural Networks (CNL-GNN), a novel framework that performs causal interventions on graph structure. CNL-GNN identifies and preserves causally relevant connections and reduces spurious influences by generating counterfactual neighbourhoods and applying adaptive edge perturbation guided by learnable importance masking and an attention-based mechanism. In addition, by combining structural-level interventions with the disentanglement of causal features from confounding factors, the model learns invariant node representations that are robust and generalize well across different graph structures. Our approach extends causal graph learning beyond traditional feature-based methods, resulting in a robust classification model. Extensive experiments on four publicly available datasets, including multiple domain variants of one dataset, demonstrate that CNL-GNN outperforms state-of-the-art GNN models.
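To make the idea of importance-guided structural intervention concrete, the following is a minimal NumPy sketch and not the CNL-GNN implementation. It assumes per-edge importance scores are already available (in the paper they are learned; here they are fixed illustrative numbers), builds a counterfactual neighbourhood by zeroing low-importance edges, and compares node representations from the original and counterfactual graphs with a simple cosine agreement score. The function names, the threshold of 0.5, and the cosine score standing in for a contrastive objective are all assumptions made for illustration.

```python
import numpy as np

def masked_mean_aggregate(x, edges, weights):
    """Importance-weighted mean of neighbour features.

    edges are (src, dst) pairs; messages flow from src to dst.
    Nodes with no weighted incoming edge keep a zero vector.
    """
    n, d = x.shape
    out = np.zeros((n, d))
    norm = np.zeros(n)
    for (src, dst), w in zip(edges, weights):
        out[dst] += w * x[src]
        norm[dst] += w
    has_msg = norm > 0
    out[has_msg] /= norm[has_msg][:, None]
    return out

def counterfactual_weights(weights, threshold=0.5):
    """Hypothetical intervention: drop edges whose importance is below threshold."""
    return np.where(weights >= threshold, weights, 0.0)

def cosine_agreement(a, b, eps=1e-8):
    """Mean per-node cosine similarity between two sets of node representations."""
    num = (a * b).sum(axis=1)
    den = np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1) + eps
    return float((num / den).mean())

# Toy graph: 3 nodes with 2-dimensional features (illustrative values only).
x = np.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edges = [(0, 2), (1, 2), (2, 0)]
importance = np.array([0.9, 0.1, 0.8])  # assumed, not learned here

h_orig = masked_mean_aggregate(x, edges, importance)
h_cf = masked_mean_aggregate(x, edges, counterfactual_weights(importance))
score = cosine_agreement(h_orig, h_cf)
```

In this toy graph, the edge from node 1 to node 2 has low importance, so removing it changes the representation of node 2 only slightly, while node 0 is unaffected. A representation that stays stable under such interventions is the kind of invariance the abstract describes; a trainable model would learn the importance scores and optimise a contrastive loss rather than use fixed values.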
