DIVE: Subgraph Disagreement for Graph Out-of-Distribution Generalization

Xin Sun, Liang Wang, Qiang Liu, Shu Wu, Zilei Wang, Liang Wang

TL;DR

This work tackles graph out-of-distribution generalization by addressing the simplicity bias of SGD, which causes models to rely on simple, often spurious, subgraphs. It introduces DIVE, a framework that trains a collection of models to identify all label-predictive subgraphs by enforcing diversity on their subgraph masks through a Jaccard-like disagreement regularizer, and then selects the best model via OOD validation. The approach yields strong OOD performance across five datasets drawn from the GOOD and DrugOOD benchmarks, extracts subgraphs more accurately than prior methods, and shows that diversity regularization is crucial for robust generalization. By discovering both simple and complex predictive patterns and selecting a robust predictor, DIVE offers a practical path toward reliable graph learning under distribution shifts.
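The summary does not spell out the exact form of the Jaccard-like regularizer. As a minimal sketch, assuming each model outputs a soft node mask with entries in [0, 1], one differentiable Jaccard-style overlap is the soft Jaccard (Tanimoto) ratio sum(a*b) / (sum(a) + sum(b) - sum(a*b)), averaged over every pair of models in the collection. The names `soft_jaccard` and `diversity_penalty` are hypothetical and not the authors' code.

```python
from itertools import combinations
from typing import Sequence


def soft_jaccard(m_a: Sequence[float], m_b: Sequence[float], eps: float = 1e-8) -> float:
    """Soft Jaccard (Tanimoto) overlap between two node masks in [0, 1]."""
    if len(m_a) != len(m_b):
        raise ValueError("masks must cover the same nodes")
    inter = sum(a * b for a, b in zip(m_a, m_b))  # soft intersection
    union = sum(m_a) + sum(m_b) - inter           # soft union
    return inter / (union + eps)                  # eps guards all-zero masks


def diversity_penalty(masks: Sequence[Sequence[float]]) -> float:
    """Mean pairwise soft Jaccard over a collection; lower means more diverse."""
    pairs = list(combinations(masks, 2))
    if not pairs:  # a single model has no pairs to compare
        return 0.0
    return sum(soft_jaccard(a, b) for a, b in pairs) / len(pairs)


# Hypothetical masks from three models over the same four nodes of one graph.
masks = [[1, 1, 0, 0], [0, 0, 1, 1], [0.9, 0.8, 0.1, 0.0]]
print(diversity_penalty(masks))
```

For binary masks this reduces to the ordinary Jaccard index, so two models that select disjoint node sets incur zero penalty and identical selections incur a penalty of about 1. In an actual training loop the same arithmetic would presumably run on framework tensors, averaged over the graphs in a batch, so that gradients reach each model's mask generator.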

Abstract

This paper addresses the challenge of out-of-distribution (OOD) generalization in graph machine learning, a rapidly advancing field that still struggles with discrepancies between source and target data distributions. Traditional graph learning algorithms assume that training and test data are identically distributed; they falter in real-world scenarios where this assumption fails, resulting in suboptimal performance. A principal cause of this degradation is the inherent simplicity bias of neural networks trained with stochastic gradient descent (SGD): such networks prefer simpler features over more complex ones that are equally or more predictive. This bias leads to reliance on spurious correlations, harming OOD performance in tasks such as image recognition, natural language understanding, and graph classification. Current methods, including subgraph-mixup and information bottleneck approaches, have achieved partial success but struggle to overcome simplicity bias and often reinforce spurious correlations. To tackle this, we propose DIVE, which trains a collection of models to cover all label-predictive subgraphs by encouraging the models to diverge in their subgraph masks, circumventing the failure mode in which a single model attends only to the subgraph corresponding to simple structural patterns. Specifically, we employ a regularizer that penalizes overlap between the subgraphs extracted by different models, thereby encouraging each model to concentrate on a distinct structural pattern. For robust OOD performance, the final model is selected by validation accuracy. Across four datasets from the GOOD benchmark and one dataset from the DrugOOD benchmark, our approach demonstrates significant improvement over existing methods, effectively addressing simplicity bias and enhancing generalization in graph machine learning.
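The abstract pairs the overlap regularizer with model selection by validation accuracy. A minimal sketch of both steps, under stated assumptions: the collection's training objective is taken to be the sum of each model's supervised loss plus a weight `lam` times an overlap penalty (presumably the role of the $\lambda$ varied in Figure 5), and selection simply keeps the member with the highest accuracy on the OOD validation split. The helpers `collection_objective` and `select_model` are hypothetical.

```python
from typing import Sequence


def collection_objective(task_losses: Sequence[float], overlap: float, lam: float) -> float:
    """Assumed total loss: summed per-model task losses plus lam * overlap penalty."""
    if lam < 0:
        raise ValueError("lam should be non-negative")
    return sum(task_losses) + lam * overlap


def select_model(val_accuracies: Sequence[float]) -> int:
    """Index of the member with the highest OOD validation accuracy (first on ties)."""
    if not val_accuracies:
        raise ValueError("empty collection")
    return max(range(len(val_accuracies)), key=lambda i: val_accuracies[i])


# Hypothetical numbers for a two-model collection, for illustration only.
loss = collection_objective(task_losses=[0.52, 0.61], overlap=0.30, lam=0.5)
best = select_model(val_accuracies=[0.74, 0.69])
print(loss, best)
```

Keeping a single member rather than averaging the collection matches the description of selecting one robust predictor; a larger `lam` pushes the models further apart at the cost of the task term, which is the trade-off a $\lambda$ sweep would expose.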

Paper Structure (35 sections, 10 equations, 6 figures, 3 tables, 1 algorithm)

Figures (6)

  • Figure 1: Overall framework of our method when the collection size is two. The green subgraph (wheel pattern) and the blue subgraph (house pattern) are both label-predictive, and there is a strong spurious correlation between these two structural patterns. We train two models and use diversity regularization to make them attend to different label-predictive subgraph patterns.
  • Figure 2: Visualization of the subgraph masks generated by different models in the collection. We train two models using our algorithm on the GOODMotif dataset (basis-concept setting) and visualize the subgraph extracted by each model on the test set. Nodes colored pink are ground-truth subgraph nodes, and each column represents a graph class. Subfigures (a) and (b) correspond position by position: panels at the same location show the same graph instance. Model 0 attends to the correct subgraph, while model 1 attends to the spurious one.
  • Figure 3: F1 curve of the subgraph mask prediction. For each method, we run the experiment 5 times; the shaded area represents the standard deviation.
  • Figure 4: Distribution of the subgraph mask precision and recall of different models in the collection.
  • Figure 5: Performance using different $\lambda$ on GOODZINC and GOODSST2. We conduct the experiment 5 times for each $\lambda$ and the grey shaded area represents standard deviation.
  • ...and 1 more figure
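Figures 3 and 4 report F1, precision, and recall of the predicted subgraph masks against the ground-truth subgraph nodes. A minimal sketch of those metrics for a single graph, assuming a soft mask is binarized at a threshold of 0.5 (the paper's actual convention is not stated here); `mask_prf1` is a hypothetical helper.

```python
from typing import Sequence, Tuple


def mask_prf1(pred: Sequence[float], truth: Sequence[int],
              threshold: float = 0.5) -> Tuple[float, float, float]:
    """Precision, recall and F1 of a node mask against 0/1 ground-truth labels.

    A node counts as selected when its mask score is >= threshold (assumed).
    Undefined ratios (no selected nodes, no true nodes) are reported as 0.0.
    """
    if len(pred) != len(truth):
        raise ValueError("pred and truth must cover the same nodes")
    selected = [p >= threshold for p in pred]
    tp = sum(1 for s, t in zip(selected, truth) if s and t)
    fp = sum(1 for s, t in zip(selected, truth) if s and not t)
    fn = sum(1 for s, t in zip(selected, truth) if not s and t)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


print(mask_prf1([0.9, 0.8, 0.2, 0.7], [1, 1, 1, 0]))
```

Here the scores `[0.9, 0.8, 0.2, 0.7]` against ground truth `[1, 1, 1, 0]` select three nodes, two of them correct, so precision, recall, and F1 all equal 2/3. Computing these per graph and pooling them would give distributions like those summarized in Figure 4.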