
CGRL: Causal-Guided Representation Learning for Graph Out-of-Distribution Generalization

Bowen Lu, Liangqiang Yang, Teng Li

Abstract

Graph Neural Networks (GNNs) have achieved impressive performance on graph-related tasks. However, they generalize poorly to out-of-distribution (OOD) data because they tend to learn spurious correlations. These correlations manifest as a failure of GNNs to stably learn the mutual information between prediction representations and ground-truth labels under OOD settings. To address these challenges, we formulate a causal graph starting from the essence of node classification, adopt backdoor adjustment to block non-causal paths, and theoretically derive a lower bound for improving the OOD generalization of GNNs. To materialize these insights, we further propose a novel approach integrating causal representation learning with a loss replacement strategy. The former captures node-level causal invariance and reconstructs the graph posterior distribution; the latter replaces the original losses with asymptotic losses of the same order. Extensive experiments demonstrate the superiority of our method in OOD generalization and its effectiveness in alleviating unstable mutual information learning.
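
The abstract's central diagnostic is the mutual information (MI) between prediction representations and ground-truth labels, tracked across training (see Figures 1, 4, and 5). The paper's exact MI estimator is not given in this summary; the sketch below is a simple stand-in that uses scikit-learn's discrete mutual_info_score between hard (argmax) predictions and labels, and the helper name prediction_label_mi is hypothetical.

```python
# Hedged sketch: tracking MI between predictions and labels per epoch.
# The paper's actual estimator is unspecified here; this uses the
# discrete MI between argmax class assignments and ground-truth labels.
import numpy as np
from sklearn.metrics import mutual_info_score

def prediction_label_mi(logits: np.ndarray, labels: np.ndarray) -> float:
    """MI between hard (argmax) predictions and ground-truth labels."""
    preds = logits.argmax(axis=1)        # one class assignment per node
    return mutual_info_score(labels, preds)

# Usage: call once per epoch and plot the curve; a curve that rises and
# stays stable corresponds to the "stable MI learning" the paper targets.
```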


Paper Structure

This paper contains 32 sections, 4 theorems, 33 equations, 9 figures, and 4 tables.

Key Result

Theorem 3.1

Given the causal graph in Figure 2, we can obtain an equation that estimates the causal relationship from $\mathbf{H}_c$ to $\mathbf{Y}$, where $\theta$ denotes the parameters of the model.
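
The theorem's equation itself does not survive in this extract. Since the paper applies backdoor adjustment to block non-causal paths, one plausible, hedged reconstruction is the standard backdoor-adjusted estimate, with $s$ ranging over the strata of an assumed confounder $S$ (not confirmed by this summary):

```latex
% Hedged reconstruction: standard backdoor adjustment over a confounder S.
P_\theta\bigl(\mathbf{Y} \mid do(\mathbf{H}_c)\bigr)
  = \sum_{s} P_\theta\bigl(\mathbf{Y} \mid \mathbf{H}_c, s\bigr)\, P(s)
```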

Figures (9)

  • Figure 1: Mutual information (MI) between GCN-based prediction representations and ground-truth labels on the Cora dataset with and without feature shifts.
  • Figure 2: Causal graph.
  • Figure 3: The CGRL framework consists of two parts: causal representation learning and a loss replacement strategy. The former comprises re-weight representation learning (RRL) and energy-based reconstruction, which together capture node-level invariance and perform graph reconstruction to yield the reconstruction loss $\mathcal{L}_{rec}$. Specifically, the adjacency matrix $A$ and node features $\mathbf{X}$ are fed into a GNN encoder to learn the representation $\mathbf{Z}$ (i.e., $P_\theta(\mathcal{G}_s)$). In the testing phase, $\mathbf{Z}$ is sequentially processed by the softmax and the RRL module, producing $\mathbf{H}_r$ and the prediction representation $\mathbf{H}_c$, respectively. In contrast, the training phase applies a sequence of operations to $\mathbf{Z}$: the Gumbel trick, RRL, and energy-based reconstruction. The final $\mathbf{H}_c$ is fed into a classifier for prediction, yielding the supervised loss $\mathcal{L}_{sup}$. To achieve the optimization objectives of intra-class clustering and inter-class separation, the loss replacement strategy introduces asymptotic losses of the same order in place of the original losses, giving rise to the intra-class loss $\mathcal{L}_{intra}$ and the inter-class loss $\mathcal{L}_{inter}$ (see the sketch after this list).
  • Figure 4: Mutual information between CGRL's prediction representations and ground-truth labels on the Cora and Citeseer datasets with feature shifts. As the number of epochs grows, the mutual information tends to converge.
  • Figure 5: Mutual information between prediction representations and ground-truth labels on the GOODCora dataset under covariate shift in the degree domain. The four plots on the left are based on the GCN-based model; the four on the right are based on the GAT-based model.
  • ...and 4 more figures
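
To make the Figure 3 pipeline concrete, below is a minimal PyTorch sketch of the training and testing forward passes it describes. The internals of the RRL module and the energy-based reconstruction are not specified in this summary, so both are stubbed with simple stand-ins (a linear re-weighting layer and an inner-product graph decoder); all module names are hypothetical.

```python
# Hedged sketch of the CGRL forward pass from Figure 3. RRL and the
# energy-based reconstruction are replaced by simple stand-ins.
import torch
import torch.nn as nn
import torch.nn.functional as F

class CGRLSketch(nn.Module):
    def __init__(self, encoder: nn.Module, hidden_dim: int, num_classes: int):
        super().__init__()
        self.encoder = encoder                              # GNN encoder: (X, A) -> Z
        self.rrl = nn.Linear(hidden_dim, hidden_dim)        # stand-in for the RRL module
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x, adj, labels=None):
        z = self.encoder(x, adj)                            # Z, i.e. P_theta(G_s)
        if self.training:
            h_r = F.gumbel_softmax(z, tau=1.0, hard=False)  # Gumbel trick on Z
            h_c = self.rrl(h_r)                             # prediction representation H_c
            # Stand-in for energy-based reconstruction: an inner-product
            # graph decoder scored against a dense, float adjacency matrix.
            rec_loss = F.binary_cross_entropy_with_logits(h_c @ h_c.t(), adj)
        else:
            h_r = F.softmax(z, dim=-1)                      # testing phase: softmax, then RRL
            h_c = self.rrl(h_r)
            rec_loss = None
        logits = self.classifier(h_c)                       # classifier output for L_sup
        sup_loss = F.cross_entropy(logits, labels) if labels is not None else None
        # L_intra / L_inter are omitted: their asymptotic forms are not
        # given in this summary.
        return logits, sup_loss, rec_loss
```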

Theorems & Definitions (4)

  • Theorem 3.1
  • Theorem 3.2
  • Lemma 3.3
  • Theorem 3.4