
Out-of-Distribution Generalization on Graphs via Progressive Inference

Yiming Xu, Bin Shi, Zhen Peng, Huixiang Liu, Bo Dong, Chen Chen

TL;DR

GPro learns graph causal invariance by decomposing causal invariant learning into multiple intermediate inference steps ordered from easy to hard. Through this progressive inference process, GPro's perception is continuously strengthened so that it extracts causal features that are stable under distribution shifts.
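The summary does not say how the intermediate steps are constructed. The sketch below assumes one possible reading of "easy to hard": each step keeps a shrinking fraction of nodes and re-scores only the nodes that survived the previous step. Early steps therefore make coarse, easy cuts, and later steps make finer ones. The function `progressive_select`, the scorer `score_fn`, the parameters `num_steps` and `final_ratio`, and the toy scores are all hypothetical. This is not GPro's actual procedure.

```python
from typing import Callable, Dict, List, Sequence


def progressive_select(
    nodes: Sequence[int],
    score_fn: Callable[[List[int]], Dict[int, float]],
    num_steps: int = 3,
    final_ratio: float = 0.25,
) -> List[List[int]]:
    """Hypothetical easy-to-hard selection of a causal node subset.

    At step k (1..num_steps) the kept ratio is linearly interpolated from
    1.0 down to final_ratio, and nodes are re-scored on the subset kept
    so far. Returns the (sorted) kept node list after every step.
    """
    if num_steps < 1:
        raise ValueError("num_steps must be >= 1")
    kept = list(nodes)
    history: List[List[int]] = []
    for k in range(1, num_steps + 1):
        ratio = 1.0 - (1.0 - final_ratio) * k / num_steps
        budget = max(1, round(ratio * len(nodes)))
        scores = score_fn(kept)  # re-score only the surviving nodes
        kept = sorted(kept, key=lambda n: scores[n], reverse=True)[:budget]
        history.append(sorted(kept))
    return history


# Hypothetical toy scores: nodes 0 and 1 play the role of the causal part.
BASE = {0: 0.9, 1: 0.8, 2: 0.1, 3: 0.6, 4: 0.2, 5: 0.5, 6: 0.05, 7: 0.3}


def toy_score(kept: List[int]) -> Dict[int, float]:
    # Normalize over the surviving nodes, so scores depend on the current subset.
    total = sum(BASE[n] for n in kept)
    return {n: BASE[n] / total for n in kept}


history = progressive_select(range(8), toy_score, num_steps=3, final_ratio=0.25)
print(history)  # [[0, 1, 3, 4, 5, 7], [0, 1, 3, 5], [0, 1]]
```

Each step sees a smaller, cleaner subset than the one before it. This loosely mirrors the idea of refining the selected causal part step by step. Figure 3 studies how the number of such steps affects GPro.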

Abstract

The development and evaluation of graph neural networks (GNNs) generally follow the independent and identically distributed (i.i.d.) assumption. Yet this assumption is often untenable in practice because the data generation mechanism is uncontrollable. In particular, when the data distribution shifts significantly, most GNNs fail to produce reliable predictions and may even make near-random decisions. One of the most promising ways to improve model generalization is to identify the causal invariant parts of the input graph. Nonetheless, we observe a significant distribution gap between the causal parts learned by existing methods and the ground truth, which leads to undesirable performance. To address these issues, this paper presents GPro, a model that learns graph causal invariance through progressive inference. Specifically, the complex task of graph causal invariant learning is decomposed into multiple intermediate inference steps ordered from easy to hard, and GPro's perception is continuously strengthened through a progressive inference process to extract causal features that are stable under distribution shifts. We also enlarge the training distribution by creating counterfactual samples, which enhances GPro's capability to capture the causal invariant parts. Extensive experiments demonstrate that GPro outperforms state-of-the-art methods by 4.91% on average, and by up to 6.86% on datasets with more severe distribution shifts.
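The abstract states only that counterfactual samples enlarge the training distribution. The sketch below uses one common construction from graph OOD work; it is assumed here and not taken from the paper. The idea is to keep the causal part of one graph, attach the non-causal part of another graph, and keep the label of the causal part. The sketch works at the representation level. `make_counterfactuals` and its inputs (per-graph causal and non-causal feature vectors) are hypothetical, and concatenation is an illustrative choice of how to combine the two parts.

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]


def make_counterfactuals(
    causal: Sequence[Vector],
    spurious: Sequence[Vector],
    labels: Sequence[int],
    seed: int = 0,
) -> List[Tuple[Vector, int]]:
    """Hypothetical counterfactual augmentation by swapping non-causal parts.

    Graph i keeps its causal features and label but receives the non-causal
    (spurious) features of a randomly chosen graph perm[i]. The returned
    feature is causal_i concatenated with spurious_perm[i].
    """
    if not (len(causal) == len(spurious) == len(labels)):
        raise ValueError("causal, spurious and labels must have equal length")
    rng = random.Random(seed)
    perm = list(range(len(causal)))
    rng.shuffle(perm)  # may map some i to itself; no derangement is enforced
    return [
        (list(causal[i]) + list(spurious[j]), labels[i])
        for i, j in enumerate(perm)
    ]


causal = [[1.0], [2.0], [3.0]]
spurious = [[10.0], [20.0], [30.0]]
samples = make_counterfactuals(causal, spurious, labels=[0, 1, 2], seed=0)
for feature, label in samples:
    print(feature, label)
```

The label follows the causal part, so a classifier trained on these pairs gains nothing from the non-causal features. That is the intended pressure toward invariance. A stricter variant would always draw the non-causal part from a different graph or a different environment.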

Paper Structure

This paper contains 30 sections, 16 equations, 7 figures, 8 tables, and 1 algorithm.

Figures (7)

  • Figure 1: An illustration of the differences between existing methods and our proposed solution, GPro. (a) Standard methods incorporate a significant amount of non-causal information (the green part of the input) into the learned features, resulting in a deviation from the decision boundary. (b) Our method is continuously refined via progressive inference to approach the ground truth.
  • Figure 2: The pipeline and implementation details of GPro. The basic idea is to decompose the complex problem of causal invariant learning on graphs into multiple intermediate inference steps and, through progressive inference, finally extract causal features that generalize. In the toy input graph on the far left, the red and green parts denote the causal and non-causal substructures, respectively.
  • Figure 3: Quantitative sensitivity analysis of GPro for the number of progressive inference steps.
  • Figure 4: t-SNE visualization of the features generated for class-0 samples on the CMNIST-75sp dataset. There is generally a significant distribution gap between the features learned by existing methods (such as GCN, DisC, and CAL) and the ground-truth causal features. Through progressive inference, GPro learns causal features that are closer to the ground truth.
  • Figure 5: t-SNE visualization of the features learned by GCN, DisC, CAL, and GPro on the CMNIST-75sp dataset, with class labels indicated by color. The features learned by GPro form compact clusters within each class, while the distance between clusters is maximized.
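Figure 5 judges cluster quality visually. A simple numeric proxy can measure the same two properties: compactness, as the mean distance of each feature to its class centroid, and separation, as the mean distance between class centroids. This proxy is not reported in the paper and is assumed here only for illustration. The function `cluster_stats` is hypothetical. It is better applied to the learned features themselves than to the t-SNE projection, because t-SNE does not preserve global distances.

```python
import math
from collections import defaultdict
from itertools import combinations
from typing import Dict, List, Sequence, Tuple


def cluster_stats(
    points: Sequence[Sequence[float]], labels: Sequence[int]
) -> Tuple[float, float]:
    """Return (mean intra-class distance to centroid, mean inter-centroid distance)."""
    if not points or len(points) != len(labels):
        raise ValueError("need equally sized, non-empty points and labels")
    groups: Dict[int, List[Sequence[float]]] = defaultdict(list)
    for p, y in zip(points, labels):
        groups[y].append(p)
    # Class centroid = coordinate-wise mean of that class's features.
    centroids = {
        y: [sum(coord) / len(ps) for coord in zip(*ps)] for y, ps in groups.items()
    }
    # Compactness: average Euclidean distance of each point to its own centroid.
    intra = sum(math.dist(p, centroids[y]) for p, y in zip(points, labels)) / len(points)
    # Separation: average Euclidean distance over all pairs of class centroids.
    pairs = list(combinations(centroids.values(), 2))
    inter = sum(math.dist(a, b) for a, b in pairs) / len(pairs) if pairs else 0.0
    return intra, inter


points = [[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]]
intra, inter = cluster_stats(points, [0, 0, 1, 1])
print(intra, inter)  # 1.0 10.0
```

A small intra-class value combined with a large inter-centroid value matches the pattern the Figure 5 caption describes for GPro.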