Causal-aware Graph Neural Architecture Search under Distribution Shifts

Peiwen Li, Xin Wang, Zeyang Zhang, Yijian Qin, Ziwei Zhang, Jialong Wang, Yang Li, Wenwu Zhu

TL;DR

This work tackles the problem of distribution shifts in Graph Neural Architecture Search (Graph NAS) by learning a stable causal relationship between graphs and their optimal architectures. It introduces CARNAS, a three-module framework comprising Disentangled Causal Subgraph Identification, Graph Embedding Intervention, and Invariant Architecture Customization, which respectively extract causal subgraphs, intervene in the latent space, and tailor generalized architectures. The optimization balances prediction performance against causal invariance, encouraging architectures that generalize across environments. Empirical results on synthetic and real-world datasets show strong out-of-distribution generalization, and ablations confirm the contribution of each module.
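The first step, extracting a causal subgraph, can be pictured as splitting a graph's edges by importance. The sketch below is a minimal illustration under stated assumptions, not the authors' code: it assumes per-edge importance scores are already available (presumably derived from the disentangled GNN layers mentioned in Figure 1's caption) and that a fixed fraction of top-scoring edges is kept as causal. The helper `split_causal_subgraph` and the `causal_ratio` parameter are hypothetical.

```python
from typing import List, Sequence, Tuple

Edge = Tuple[int, int]


def split_causal_subgraph(
    edges: Sequence[Edge], scores: Sequence[float], causal_ratio: float
) -> Tuple[List[Edge], List[Edge]]:
    """Split a graph's edges into a causal part and a non-causal remainder.

    Hypothetical helper: `scores` stands in for learned per-edge importance
    scores, and `causal_ratio` is the assumed fraction of edges kept as causal.
    """
    if len(edges) != len(scores):
        raise ValueError("edges and scores must have the same length")
    if not 0.0 <= causal_ratio <= 1.0:
        raise ValueError("causal_ratio must lie in [0, 1]")
    k = round(causal_ratio * len(edges))  # number of edges treated as causal
    # Rank edge indices by score, highest first; sorted() is stable, so ties keep input order.
    ranked = sorted(range(len(edges)), key=lambda i: -scores[i])
    keep = set(ranked[:k])
    causal = [e for i, e in enumerate(edges) if i in keep]
    non_causal = [e for i, e in enumerate(edges) if i not in keep]
    return causal, non_causal


# Usage: a 4-cycle with made-up scores; keep the top half as the causal subgraph.
causal, non_causal = split_causal_subgraph(
    [(0, 1), (1, 2), (2, 3), (3, 0)], [0.9, 0.1, 0.7, 0.3], causal_ratio=0.5
)
print(causal)      # [(0, 1), (2, 3)]
print(non_causal)  # [(1, 2), (3, 0)]
```

The non-causal remainder is kept rather than discarded because, per the framework description, it is what the intervention step mixes back in.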

Abstract

Graph NAS has emerged as a promising approach for autonomously designing GNN architectures by leveraging the correlations between graphs and architectures. However, existing methods fail to generalize under the distribution shifts that are ubiquitous in real-world graph scenarios, mainly because the graph-architecture correlations they exploit may be spurious and vary across distributions. We propose to handle distribution shifts in the graph architecture search process by discovering and exploiting the causal relationship between graphs and architectures, and using it to search for optimal architectures that generalize under distribution shifts. This problem remains unexplored and poses two challenges: how to discover a causal graph-architecture relationship that has stable predictive ability across distributions, and how to use the discovered relationship to handle distribution shifts and search for generalized graph architectures. To address these challenges, we propose Causal-aware Graph Neural Architecture Search (CARNAS), which captures the causal graph-architecture relationship during the architecture search process and discovers generalized graph architectures under distribution shifts. Specifically, we propose Disentangled Causal Subgraph Identification to capture causal subgraphs that have stable predictive ability across distributions. Then, we propose Graph Embedding Intervention to intervene on causal subgraphs within the latent space, ensuring that these subgraphs encapsulate the features essential for prediction while excluding non-causal elements. Finally, we propose Invariant Architecture Customization to reinforce the causal invariance of the causal subgraphs, which are then used to tailor generalized graph architectures. Extensive experiments demonstrate that CARNAS achieves strong out-of-distribution generalization.
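The idea behind Graph Embedding Intervention, that a causal subgraph's embedding should still predict well no matter which non-causal part is attached to it, can be illustrated with a toy latent intervention. This is a minimal sketch under loud assumptions: the additive mixing `h_c + mu * h_s`, the random sampling of non-causal embeddings from other graphs, and the use of loss variance as an instability signal are illustrative choices, not confirmed details of CARNAS; `intervene`, `intervention_losses`, and `invariance_summary` are hypothetical names.

```python
import random
from statistics import mean, pvariance
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def intervene(h_causal: Sequence[float], h_noncausal: Sequence[float], mu: float) -> Vector:
    """Hypothetical latent intervention: add a scaled non-causal embedding."""
    return [c + mu * s for c, s in zip(h_causal, h_noncausal)]


def intervention_losses(
    h_causal: Sequence[float],
    noncausal_pool: Sequence[Sequence[float]],
    loss_fn: Callable[[Vector], float],
    num_interventions: int,
    mu: float,
    rng: random.Random,
) -> List[float]:
    """Loss of a prediction head on the causal embedding under sampled interventions."""
    losses = []
    for _ in range(num_interventions):
        h_s = rng.choice(noncausal_pool)  # non-causal embedding borrowed from another graph
        losses.append(loss_fn(intervene(h_causal, h_s, mu)))
    return losses


def invariance_summary(losses: Sequence[float]) -> Tuple[float, float]:
    """Mean loss (predictive quality) and its variance across interventions (instability)."""
    return mean(losses), pvariance(losses)


# Usage: the toy head only reads coordinate 0, and the non-causal embeddings are zero
# there, so every intervention yields the same loss and the variance is 0.
rng = random.Random(0)
losses = intervention_losses(
    [1.0, 0.0], [[0.0, 2.0], [0.0, -1.0]], lambda h: (h[0] - 1.0) ** 2,
    num_interventions=5, mu=0.5, rng=rng,
)
print(invariance_summary(losses))  # (0.0, 0.0)
```

A head that also read coordinate 1 would produce different losses for different interventions, which is the kind of non-causal dependence the module is meant to penalize.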

Paper Structure (56 sections, 2 theorems, 27 equations, 11 figures, 9 tables, 1 algorithm)

Key Result

Theorem 1

A generator $f_C(G)$ is the optimal generator that satisfies the paper's Assumption if and only if it is the maximal causal subgraph generator, i.e.,

$$f_C^* = \arg\max_{f_C \in \mathcal{F}_\mathcal{E}} I\big(A^*; f_C(G)\big),$$

where $\mathcal{F}_\mathcal{E}$ is the set of subgraph generators associated with the random vector of all environments, and $I(\cdot; \cdot)$ is the mutual information between the optimal architecture $A^*$ and the generated causal subgraph.
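Theorem 1 characterizes the optimal generator as the one whose output carries the most mutual information $I(A^*; f_C(G))$ about the optimal architecture. The sketch below illustrates that selection rule on a toy problem, under assumptions that are not from the paper: a small finite set of candidate generators (all assumed to already lie in $\mathcal{F}_\mathcal{E}$, since the sketch does not enforce the environment constraint), discrete generator outputs and architecture labels, and a plug-in count-based estimate of mutual information. `mutual_information` and `best_generator` are hypothetical helpers.

```python
import math
from collections import Counter
from typing import Callable, Dict, Hashable, Sequence


def mutual_information(xs: Sequence[Hashable], ys: Sequence[Hashable]) -> float:
    """Plug-in estimate of I(X; Y) in nats from paired discrete samples."""
    n = len(xs)
    joint = Counter(zip(xs, ys))
    px, py = Counter(xs), Counter(ys)
    # Sum over observed (x, y) of p(x, y) * log(p(x, y) / (p(x) p(y))), written with counts.
    return sum((c / n) * math.log(c * n / (px[x] * py[y])) for (x, y), c in joint.items())


def best_generator(
    generators: Dict[str, Callable[[object], Hashable]],
    graphs: Sequence[object],
    optimal_archs: Sequence[Hashable],
) -> str:
    """Name of the candidate generator whose output is most informative about A*."""
    scores = {
        name: mutual_information(optimal_archs, [gen(g) for g in graphs])
        for name, gen in generators.items()
    }
    return max(scores, key=scores.get)


# Toy data: each "graph" is (causal_feature, spurious_feature), and the optimal
# architecture depends only on the causal feature.
graphs = [(0, 0), (0, 1), (1, 0), (1, 1)]
archs = ["gcn", "gcn", "gin", "gin"]
candidates = {"causal": lambda g: g[0], "spurious": lambda g: g[1]}
print(best_generator(candidates, graphs, archs))  # causal
```

Here the "causal" generator attains $\log 2$ nats while the "spurious" one attains 0, so the argmax picks the generator that tracks what actually determines $A^*$.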

Figures (11)

  • Figure 1: The framework of our proposed method CARNAS. For an input graph $G$, the disentangled causal subgraph identification module extracts its causal subgraph $G_c$ with disentangled GNN layers. Then, in the graph embedding intervention module, we conduct several interventions on $G_c$ with non-causal subgraphs in the latent space, and meanwhile obtain $\mathcal{L}_{cpred}$ from the embedding of $G_c$. After that, the invariant architecture customization module handles distribution shifts by customizing an architecture from $G_c$ to obtain $\hat{Y}$ and $\mathcal{L}_{pred}$, and forms $\mathcal{L}_{arch}$ and $\mathcal{L}_{op}$ to further constrain the causal invariance of $G_c$. Blue lines represent the prediction path, grey lines show the other processes in the training stage, and green lines denote the updating process.
  • Figure 2: Results of ablation studies on synthetic datasets, where ‘w/o $\mathcal{L}_{arch}$’ removes $\mathcal{L}_{arch}$ from the overall loss equation, ‘w/o $\mathcal{L}_{cpred}$’ removes $\mathcal{L}_{cpred}$, and ‘w/o $\mathcal{L}_{arch}$ & $\mathcal{L}_{cpred}$’ removes both. Error bars report standard deviations. The average and standard deviation of the best-performing baseline on each dataset are shown as dark and light thick dashed lines, respectively.
  • Figure 3: Results of ablation studies on SIDER, where ‘w/o $\mathcal{L}_{arch}$’ removes $\mathcal{L}_{arch}$ from the overall loss equation, ‘w/o $\mathcal{L}_{cpred}$’ removes $\mathcal{L}_{cpred}$, and ‘w/o $\mathcal{L}_{arch}$ & $\mathcal{L}_{cpred}$’ removes both. Error bars report standard deviations. The average and standard deviation of the best-performing baseline on each dataset are shown as dark and light thick dashed lines, respectively.
  • Figure 4: Training process on the synthetic datasets.
  • Figure 5: Changes in the two parts of the loss.
  • ...and 6 more figures
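The captions above name four loss terms and describe ablations that switch off $\mathcal{L}_{arch}$ and $\mathcal{L}_{cpred}$. The sketch below shows one way such pieces could fit together; every concrete choice is an assumption, not the paper's formulation: operation weights come from a softmax over dot products between the causal-subgraph embedding and one prototype vector per candidate operation, $\mathcal{L}_{arch}$ is stood in for by the variance of those weights across interventions, the overall loss is a weighted sum with hypothetical coefficients `lam_*`, and $\mathcal{L}_{op}$ is passed in as a plain number because its form is not described here. `op_weights`, `arch_variance`, and `total_loss` are hypothetical names.

```python
import math
from typing import Dict, Sequence


def op_weights(h_causal: Sequence[float], prototypes: Dict[str, Sequence[float]]) -> Dict[str, float]:
    """Hypothetical customization step: softmax over dot products between the
    causal-subgraph embedding and one prototype vector per candidate operation."""
    logits = {op: sum(a * b for a, b in zip(h_causal, p)) for op, p in prototypes.items()}
    top = max(logits.values())  # subtract the max logit for numerical stability
    exps = {op: math.exp(z - top) for op, z in logits.items()}
    total = sum(exps.values())
    return {op: e / total for op, e in exps.items()}


def arch_variance(weight_sets: Sequence[Dict[str, float]]) -> float:
    """Stand-in for L_arch: per-operation variance of the customized weights across
    interventions, averaged over operations (0 when the architecture is invariant)."""
    ops = list(weight_sets[0])
    n = len(weight_sets)
    total = 0.0
    for op in ops:
        vals = [w[op] for w in weight_sets]
        mu = sum(vals) / n
        total += sum((v - mu) ** 2 for v in vals) / n
    return total / len(ops)


def total_loss(
    l_pred: float, l_cpred: float, l_arch: float, l_op: float,
    lam_cpred: float = 1.0, lam_arch: float = 1.0, lam_op: float = 1.0,
    use_cpred: bool = True, use_arch: bool = True,
) -> float:
    """Assumed weighted sum of the four losses; the flags mimic the ablations."""
    loss = l_pred + lam_op * l_op
    if use_cpred:
        loss += lam_cpred * l_cpred
    if use_arch:
        loss += lam_arch * l_arch
    return loss


# Usage: an all-zero embedding gives uniform weights; identical weights under every
# intervention give zero architecture variance.
w = op_weights([0.0, 0.0], {"gcn": [1.0, 0.0], "gin": [0.0, 1.0]})
print(w)                         # {'gcn': 0.5, 'gin': 0.5}
print(arch_variance([w, w, w]))  # 0.0
print(total_loss(1.0, 2.0, 3.0, 4.0, use_arch=False, use_cpred=False))  # 5.0
```

Setting `use_arch=False`, `use_cpred=False`, or both corresponds to the three ablation variants in Figures 2 and 3, under the weighted-sum assumption above.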

Theorems & Definitions (2)

  • Theorem 1: Optimal Generator of Causal Subgraphs
  • Theorem 2