Causal-aware Graph Neural Architecture Search under Distribution Shifts
Peiwen Li, Xin Wang, Zeyang Zhang, Yijian Qin, Ziwei Zhang, Jialong Wang, Yang Li, Wenwu Zhu
TL;DR
This work tackles distribution shifts in Graph Neural Architecture Search (Graph NAS) by learning a stable causal relationship between graphs and their optimal architectures. It introduces CARNAS, a framework with three modules: Disentangled Causal Subgraph Identification, which extracts causal subgraphs; Graph Embedding Intervention, which intervenes on them in latent space; and Invariant Architecture Customization, which tailors generalized architectures. The optimization objective combines prediction performance with causal invariance, encouraging architectures that generalize across environments. Empirical results on synthetic and real-world datasets show strong out-of-distribution generalization, and ablations confirm the utility of each module.
Abstract
Graph neural architecture search (Graph NAS) has emerged as a promising approach for automatically designing graph neural network (GNN) architectures by leveraging the correlations between graphs and architectures. However, existing methods fail to generalize under the distribution shifts that are ubiquitous in real-world graph scenarios, mainly because the graph-architecture correlations they exploit may be spurious and may vary across distributions. We propose to handle distribution shifts in the graph architecture search process by discovering and exploiting the causal relationship between graphs and architectures, so as to search for optimal architectures that generalize under distribution shifts. This problem remains unexplored and poses the following challenges: how to discover a causal graph-architecture relationship that has stable predictive ability across distributions, and how to use the discovered relationship to search for generalized graph architectures under distribution shifts. To address these challenges, we propose Causal-aware Graph Neural Architecture Search (CARNAS), which captures the causal graph-architecture relationship during the architecture search process and discovers generalized graph architectures under distribution shifts. Specifically, we propose Disentangled Causal Subgraph Identification to capture the causal subgraphs that have stable predictive ability across distributions. Then, we propose Graph Embedding Intervention to intervene on causal subgraphs within the latent space, ensuring that these subgraphs encapsulate the features essential for prediction while excluding non-causal elements. Additionally, we propose Invariant Architecture Customization to reinforce the causally invariant nature of the causal subgraphs, which are then used to tailor generalized graph architectures. Extensive experiments demonstrate that CARNAS achieves strong out-of-distribution generalization.
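To make the three modules more concrete, the toy sketch below walks through one possible reading of the pipeline using only the Python standard library. It is not the CARNAS implementation, and every name in it is hypothetical: `split_by_score` stands in for causal subgraph identification (edge scores are given rather than learned), `mean_embedding` replaces a GNN readout with mean pooling, `intervene` blends a causal embedding with a non-causal one using an assumed mixing coefficient `mu`, and `invariance_penalty` assumes that invariance is encouraged by penalizing the variance of the architecture distribution across interventions, weighted by an assumed coefficient `theta`.

```python
import math

def split_by_score(edge_scores, keep_ratio):
    """Split edge indices into (causal, non_causal); top-scoring edges count as causal."""
    order = sorted(range(len(edge_scores)), key=lambda i: edge_scores[i], reverse=True)
    k = max(1, math.ceil(keep_ratio * len(edge_scores)))
    return sorted(order[:k]), sorted(order[k:])

def mean_embedding(edge_feats, idx):
    """Mean-pool the feature vectors of the selected edges (stand-in for a GNN readout)."""
    dim = len(edge_feats[0])
    if not idx:
        return [0.0] * dim
    return [sum(edge_feats[i][d] for i in idx) / len(idx) for d in range(dim)]

def intervene(h_causal, h_noncausal, mu):
    """Latent-space intervention: blend the causal embedding with a non-causal one."""
    return [(1.0 - mu) * c + mu * n for c, n in zip(h_causal, h_noncausal)]

def op_probs(h, op_prototypes):
    """Softmax over dot products between an embedding and one prototype per candidate op."""
    logits = [sum(a * b for a, b in zip(h, p)) for p in op_prototypes]
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def invariance_penalty(prob_rows):
    """Mean per-op variance of the architecture distribution across interventions."""
    n, k = len(prob_rows), len(prob_rows[0])
    means = [sum(r[j] for r in prob_rows) / n for j in range(k)]
    return sum((r[j] - means[j]) ** 2 for r in prob_rows for j in range(k)) / (n * k)

def total_loss(pred_loss, penalty, theta):
    """Combine the prediction loss with the weighted invariance penalty."""
    return pred_loss + theta * penalty

# Toy graph with four edges; scores and features are made up for illustration.
edge_scores = [0.9, 0.1, 0.8, 0.2]
edge_feats = [[1.0, 0.0], [0.0, 1.0], [1.0, 0.2], [0.1, 0.9]]
causal, non_causal = split_by_score(edge_scores, keep_ratio=0.5)
h_c = mean_embedding(edge_feats, causal)

# Non-causal embeddings: one from this graph, two invented ones standing in for other graphs.
pool = [mean_embedding(edge_feats, non_causal), [0.5, -0.5], [-1.0, 0.3]]
prototypes = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]  # one vector per candidate operation

rows = [op_probs(intervene(h_c, h_n, mu=0.3), prototypes) for h_n in pool]
penalty = invariance_penalty(rows)
loss = total_loss(pred_loss=0.7, penalty=penalty, theta=0.5)  # pred_loss is a placeholder
```

In this sketch, a causal embedding that truly determines the architecture would yield nearly identical rows in `rows` regardless of which non-causal embedding it is blended with, so `penalty` would be close to zero. Setting `mu=0` removes the intervention entirely and drives the penalty to zero, which is a quick sanity check on the logic.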
