
Rethinking Graph Generalization through the Lens of Sharpness-Aware Minimization

Yang Qiu, Yixiong Zou, Jun Wang

TL;DR

Graph neural networks (GNNs) are vulnerable to distribution shifts, which cause Minimal Shift Flip (MSF): test samples that lie close to the in-distribution (ID) training data are nonetheless misclassified. The authors recast MSF through Sharpness-Aware Minimization (SAM). They introduce the Local Robust Radius $r(x)$ and link it to an energy score $E(x)$ so that flatness and stability can be optimized tractably, and they prove a monotonic relationship between $E(x)$ and $r(x)$. They then propose Energy-driven Dual-stage Augmentation (E2A). E2A uses a conditional variational autoencoder (cVAE) to generate pseudo-ID samples and energy-guided latent perturbations to create pseudo-OOD samples, followed by an energy-based calibration objective that enlarges the local robust radius. Across the GraphOOD and DrugOOD benchmarks, E2A achieves state-of-the-art or competitive OOD generalization while preserving in-distribution accuracy, without requiring environment labels. The work connects loss-landscape sharpness, energy-based measures, and latent-space perturbations in a principled, scalable framework for robust graph learning under distribution shifts.
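The TL;DR does not spell out how $E(x)$ is computed. The sketch below assumes the standard free-energy score from energy-based OOD detection, $E(x) = -T \log \sum_k \exp(f_k(x)/T)$, computed from classifier logits $f(x)$. The paper's exact definition and temperature $T$ may differ. Under this convention, confident predictions get lower energy, and higher energy signals more OOD-like inputs, which is consistent with Figure 3.

```python
import numpy as np


def energy_score(logits: np.ndarray, temperature: float = 1.0) -> np.ndarray:
    """Free-energy score E(x) = -T * logsumexp(f(x) / T) over the class axis.

    Assumed convention: lower energy ~ more in-distribution, higher ~ more OOD-like.
    """
    scaled = logits / temperature
    # Numerically stable logsumexp: subtract the per-row maximum before exponentiating.
    m = scaled.max(axis=-1, keepdims=True)
    lse = m.squeeze(-1) + np.log(np.exp(scaled - m).sum(axis=-1))
    return -temperature * lse


# A confident prediction receives lower energy than an uncertain one.
confident = np.array([[8.0, 0.0, 0.0]])
uncertain = np.array([[0.5, 0.4, 0.3]])
print(energy_score(confident), energy_score(uncertain))
```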

Abstract

Graph Neural Networks (GNNs) have achieved remarkable success across various graph-based tasks but remain highly sensitive to distribution shifts. In this work, we focus on a prevalent yet under-explored phenomenon in graph generalization, Minimal Shift Flip (MSF), where test samples that deviate only slightly from the training distribution are abruptly misclassified. To interpret this phenomenon, we revisit MSF through the lens of Sharpness-Aware Minimization (SAM), which characterizes the local stability and sharpness of the loss landscape and provides a theoretical foundation for modeling generalization error. To quantify loss sharpness, we introduce the Local Robust Radius, which measures the smallest perturbation required to flip a prediction and establishes a theoretical link between local stability and generalization. From this perspective, we further observe that the robust radius decreases steadily during training, indicating weakened local stability and an increasingly sharp loss landscape that gives rise to MSF. To address both the MSF phenomenon and the intractability of computing the radius, we develop an energy-based formulation that is theoretically proven to be monotonically correlated with the robust radius, offering a tractable and principled objective for modeling flatness and stability. Building on these insights, we propose an energy-driven generative augmentation framework (E2A) that leverages energy-guided latent perturbations to generate pseudo-OOD samples and enhance model generalization. Extensive experiments across multiple benchmarks demonstrate that E2A consistently improves graph OOD generalization, outperforming state-of-the-art baselines.
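The sketch below gives background for the SAM view of sharpness. It uses the standard first-order SAM ascent step $\epsilon = \rho \nabla L / \|\nabla L\|$ to estimate the worst-case loss increase within radius $\rho$ on a toy quadratic loss. The loss, the curvature matrices, and $\rho$ are illustrative assumptions, not taken from the paper. With the same $\rho$, the sharper minimum shows a larger loss increase.

```python
import numpy as np


def quadratic_loss(w: np.ndarray, H: np.ndarray) -> float:
    # Toy loss L(w) = 0.5 * w^T H w with its minimum at w = 0; H sets the curvature.
    return 0.5 * float(w @ H @ w)


def quadratic_grad(w: np.ndarray, H: np.ndarray) -> np.ndarray:
    return H @ w


def sam_sharpness(w: np.ndarray, H: np.ndarray, rho: float) -> float:
    """First-order SAM estimate of max_{||eps|| <= rho} L(w + eps) - L(w)."""
    g = quadratic_grad(w, H)
    eps = rho * g / (np.linalg.norm(g) + 1e-12)  # ascent direction scaled to radius rho
    return quadratic_loss(w + eps, H) - quadratic_loss(w, H)


w0 = np.array([0.1, 0.1])      # a point near the minimum
flat = np.diag([1.0, 1.0])     # low curvature (flat minimum)
sharp = np.diag([10.0, 10.0])  # high curvature (sharp minimum)
print(sam_sharpness(w0, flat, rho=0.05), sam_sharpness(w0, sharp, rho=0.05))
```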

Paper Structure (30 sections, 3 theorems, 27 equations, 7 figures, 3 tables, 1 algorithm)

Key Result

Proposition 1: Coupling Between SAM Perturbation Radius and Local Robust Radius

Under the paper's Lipschitz assumption, for any parameter-space perturbation $\|\delta_w\|\le\rho$ and any input-space perturbation $\|\delta_x\|\le r$, the local loss variation is bounded in both spaces. Equating the worst-case loss increase in the two domains yields a first-order correspondence between $\rho$ and $r$. Hence, the input-space local robust radius $r(x)$ can be viewed as an input-space counterpart of the SAM perturbation radius $\rho$.
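Proposition 1's bounds are not displayed in this summary. The block below is one plausible first-order reading. It assumes the Lipschitz assumption bounds the loss $\ell$ with constants $L_w$ in parameter space and $L_x$ in input space. The symbols $\ell$, $L_w$, and $L_x$ are our notation and not necessarily the paper's.

```latex
\begin{aligned}
|\ell(w+\delta_w;x)-\ell(w;x)| &\le L_w\|\delta_w\| \le L_w\,\rho,\\
|\ell(w;x+\delta_x)-\ell(w;x)| &\le L_x\|\delta_x\| \le L_x\,r,\\
L_w\,\rho = L_x\,r \;&\Longrightarrow\; r \approx \frac{L_w}{L_x}\,\rho .
\end{aligned}
```

Under this reading, a larger input-space radius $r$ corresponds to a larger parameter-space radius $\rho$ over which the worst-case loss increase stays controlled.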

Figures (7)

  • Figure 1: Minimal Shift Flip (MSF) (Left) refers to the phenomenon where test samples exhibiting minimal distributional shifts from training samples are still misclassified. The Local Robust Radius quantifies the maximum perturbation in representation space that preserves a model's prediction, reflecting the sharpness of the loss landscape in the sense of Sharpness-Aware Minimization (SAM). When the loss landscape around a minimum is sharp (i.e., narrow, with high curvature), the robust radius becomes small, making nearby test samples more likely to cross the decision boundary and be misclassified. In contrast, flatter regions provide wider robust margins and greater stability against distributional shifts. $d_{\text{emb}}$ denotes the cosine similarity between representations.
  • Figure 2: (a) During training, accuracy on the training set keeps rising, while test accuracy peaks and then declines. (b) During training, the robust radius shrinks accordingly, indicating an increasingly sharp loss landscape; test samples then fall outside the robust region and are misclassified, driving the drop in test accuracy.
  • Figure 3: (a) Kernel density estimates of energy scores on Motif-basis. Structural shifts arise because base graph types (e.g., stars, trees, ladders) seen in training rarely appear in the validation and test sets, leading to higher energy values on the unseen graphs. (b) Energy distribution on Motif-size. As graph size increases from the training set to the validation and test sets, energy scores rise consistently, indicating stronger distribution shifts.
  • Figure 4: Step 1 (Modeling): Train the GNN and classifier, and concurrently train the cVAE to generate pseudo-ID embeddings that resemble the GNN’s in-distribution representations. Step 2 (Exploration and Calibration): Generate class-conditioned pseudo-ID embeddings with the cVAE and apply energy-based gradient ascent to shift them toward out-of-distribution regions. Then fine-tune the model on the resulting pseudo-OOD embeddings to improve the flatness of its loss landscape.
  • Figure 5: Parameter sensitivity analysis of E2A on the Motif-Size dataset. The top row shows how the number of perturbation steps and the step size affect model performance, and the bottom row shows the effect of the parameter $\lambda$. The left and right panels display results on the in-distribution and out-of-distribution test sets, respectively.
  • ...and 2 more figures
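Figure 4 (Step 2) and the step-count and step-size parameters in Figure 5 describe iterative energy-based gradient ascent on embeddings. The sketch below is a minimal version under several assumptions: a linear classifier head on the embeddings, the free-energy score, and a per-sample normalized gradient step. The function names, the normalization, and the default values are hypothetical. The cVAE sampling and the calibration objective weighted by $\lambda$ are not shown.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Row-wise softmax with max-subtraction for numerical stability.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def energy(z: np.ndarray, W: np.ndarray, b: np.ndarray, T: float = 1.0) -> np.ndarray:
    """Free energy of embeddings z (n, d) under a linear head with logits z @ W.T + b."""
    f = (z @ W.T + b) / T
    m = f.max(axis=-1, keepdims=True)
    return -T * (m.squeeze(-1) + np.log(np.exp(f - m).sum(axis=-1)))


def energy_grad(z: np.ndarray, W: np.ndarray, b: np.ndarray, T: float = 1.0) -> np.ndarray:
    # dE/df_k = -softmax(f/T)_k and df_k/dz = W[k], so dE/dz = -softmax(f/T) @ W.
    return -softmax((z @ W.T + b) / T) @ W


def perturb_to_pseudo_ood(z, W, b, steps=5, step_size=0.1, T=1.0):
    """Hypothetical sketch: iteratively push embeddings toward higher energy."""
    z = z.copy()
    for _ in range(steps):
        g = energy_grad(z, W, b, T)
        # Per-sample normalized ascent step, so each move has L2 length step_size.
        z += step_size * g / (np.linalg.norm(g, axis=-1, keepdims=True) + 1e-12)
    return z


W = np.eye(2)                   # toy 2-class linear head (hypothetical)
b = np.zeros(2)
z_id = np.array([[3.0, 0.0]])   # a confidently classified pseudo-ID embedding
z_ood = perturb_to_pseudo_ood(z_id, W, b)
print(energy(z_id, W, b), energy(z_ood, W, b))
```

Normalizing each step makes `step_size` a direct bound on how far an embedding moves per iteration. This is one simple way to keep the perturbed samples near the data; the paper may control the displacement differently.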

Theorems & Definitions (8)

  • Definition 1
  • Definition 2: Perturbation Radius and Flatness in SAM
  • Definition 3
  • Proposition 1: Coupling Between SAM Perturbation Radius and Local Robust Radius
  • Definition 4: Classification Margin
  • Proposition 2: Margin–Radius Relation
  • Definition 5
  • Theorem 1
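Definition 4 and Proposition 2 are listed here without their statements. For intuition, consider the special case of a linear head $f(z) = Wz + b$ on a representation $z$. There, the smallest L2 perturbation that flips the prediction away from class $y$ has the closed form $r(z) = \min_{j \ne y} (f_y(z) - f_j(z)) / \|W_y - W_j\|$, which makes a margin–radius relation explicit. This is a standard linear-classifier fact used only as an illustration. The paper's definitions may use a different norm or a nonlinear model, and the function name below is hypothetical.

```python
import numpy as np


def margin_and_radius(z: np.ndarray, W: np.ndarray, b: np.ndarray) -> tuple:
    """Margin and exact L2 robust radius of a linear head f(z) = W @ z + b at one point z."""
    f = W @ z + b
    y = int(np.argmax(f))
    others = [j for j in range(len(f)) if j != y]
    # Classification margin: gap between the predicted logit and the runner-up.
    margin = f[y] - max(f[j] for j in others)
    # Distance from z to each pairwise boundary {f_y = f_j}; the nearest one sets the radius.
    radius = min((f[y] - f[j]) / np.linalg.norm(W[y] - W[j]) for j in others)
    return y, float(margin), float(radius)


W = np.array([[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]])  # toy 3-class head (hypothetical)
b = np.zeros(3)
z = np.array([2.0, 0.5])
print(margin_and_radius(z, W, b))
```

In this special case, the radius is a margin scaled by the inverse norm of a weight difference. Either a small margin or a large weight gap therefore shrinks the region in which nearby embeddings keep their label.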