Rethinking Graph Generalization through the Lens of Sharpness-Aware Minimization
Yang Qiu, Yixiong Zou, Jun Wang
TL;DR
Graph neural networks (GNNs) are vulnerable to distribution shifts, which give rise to Minimal Shift Flip (MSF): test samples that deviate only slightly from the in-distribution (ID) training data are abruptly misclassified. The authors recast MSF through the lens of Sharpness-Aware Minimization (SAM). They introduce the Local Robust Radius $r(x)$, the smallest perturbation that flips a prediction, and link it to an energy score $E(x)$ so that flatness and stability can be optimized tractably; they prove a monotonic relationship between $E(x)$ and $r(x)$. They then propose Energy-driven Dual-stage Augmentation (E2A). E2A uses a conditional variational autoencoder (cVAE) to generate pseudo-ID samples and energy-guided latent perturbations to create pseudo-out-of-distribution (pseudo-OOD) samples, then applies an energy-based calibration objective to enlarge the local robust radius. Across the GraphOOD and DrugOOD benchmarks, E2A achieves state-of-the-art or competitive OOD generalization while preserving in-distribution accuracy, without requiring environment labels. The work connects loss-landscape sharpness, energy-based measures, and latent-space perturbations into a principled, scalable framework for robust graph learning under distribution shifts.
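The two quantities $r(x)$ and $E(x)$ can be made concrete on a toy model. The sketch below is illustrative only and assumes a multiclass linear classifier with logits $l = Wx + b$ as a stand-in for a trained GNN. For such a model, the exact L2 radius is the distance to the nearest decision hyperplane, $\min_{j \ne \hat y} (l_{\hat y} - l_j) / \lVert w_{\hat y} - w_j \rVert$. $E(x)$ is taken to be the standard free-energy score $-T \log \sum_k e^{l_k/T}$. The paper's own definitions (choice of norm, input versus embedding space, exact energy form) may differ, and nothing here reproduces its monotonicity proof. All function names are hypothetical.

```python
import math

def logits(W, b, x):
    """Logits l = W x + b of a hypothetical linear classifier (stand-in for a GNN readout)."""
    return [sum(w * xi for w, xi in zip(row, x)) + bk for row, bk in zip(W, b)]

def predict(W, b, x):
    l = logits(W, b, x)
    return max(range(len(l)), key=l.__getitem__)

def energy(W, b, x, T=1.0):
    """Free-energy score E(x) = -T * logsumexp(l / T), computed stably."""
    l = logits(W, b, x)
    m = max(l)
    return -(m + T * math.log(sum(math.exp((lk - m) / T) for lk in l)))

def nearest_boundary(W, b, x):
    """Return (radius, rival): L2 distance to the closest decision hyperplane and its class."""
    l = logits(W, b, x)
    y = predict(W, b, x)
    best, rival = math.inf, None
    for j in range(len(l)):
        if j == y:
            continue
        d = [wy - wj for wy, wj in zip(W[y], W[j])]
        norm = math.sqrt(sum(di * di for di in d))
        if norm > 0 and (l[y] - l[j]) / norm < best:
            best, rival = (l[y] - l[j]) / norm, j
    return best, rival

def local_robust_radius(W, b, x):
    """Smallest L2 perturbation that changes the predicted class (exact for linear models)."""
    return nearest_boundary(W, b, x)[0]

def flipping_perturbation(W, b, x, overshoot=1e-4):
    """Move x along the nearest boundary's normal by (1 + overshoot) * radius."""
    r, j = nearest_boundary(W, b, x)
    if j is None:
        raise ValueError("no boundary exists: the prediction cannot be flipped")
    y = predict(W, b, x)
    d = [wy - wj for wy, wj in zip(W[y], W[j])]
    norm = math.sqrt(sum(di * di for di in d))
    return [xi - (1 + overshoot) * r * di / norm for xi, di in zip(x, d)]
```

With `W = [[1.0, 0.0], [0.0, 1.0]]` and `b = [0.0, 0.0]`, the point `[0.6, 0.4]` lies close to the boundary and has a radius of about 0.14. The point `[2.0, 0.0]` has a radius of $\sqrt{2}$. In this toy, the near-boundary point also has the higher (less negative) energy, and stepping just past its radius flips its prediction.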
Abstract
Graph Neural Networks (GNNs) have achieved remarkable success across various graph-based tasks but remain highly sensitive to distribution shifts. In this work, we focus on a prevalent yet under-explored phenomenon in graph generalization, Minimal Shift Flip (MSF), where test samples that deviate only slightly from the training distribution are abruptly misclassified. To interpret this phenomenon, we revisit MSF through the lens of Sharpness-Aware Minimization (SAM), which characterizes the local stability and sharpness of the loss landscape and provides a theoretical foundation for modeling generalization error. To quantify loss sharpness, we introduce the Local Robust Radius, which measures the smallest perturbation required to flip a prediction and establishes a theoretical link between local stability and generalization. From this perspective, we further observe that the robust radius decreases steadily during training, indicating weakened local stability and an increasingly sharp loss landscape that gives rise to MSF. To address both MSF and the intractability of computing the radius directly, we develop an energy-based formulation that is theoretically proven to be monotonically correlated with the robust radius, offering a tractable and principled objective for modeling flatness and stability. Building on these insights, we propose an energy-driven generative augmentation framework (E2A) that leverages energy-guided latent perturbations to generate pseudo-out-of-distribution (pseudo-OOD) samples and enhance model generalization. Extensive experiments across multiple benchmarks demonstrate that E2A consistently improves graph OOD generalization, outperforming state-of-the-art baselines.
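E2A's pseudo-OOD stage perturbs latent codes under energy guidance. The abstract does not state the update rule, so what follows is only a minimal sketch under explicit assumptions:

- The latent code `z` would come from a cVAE encoder, which is not modeled here.
- A hypothetical linear head maps `z` to logits.
- "Energy-guided" is read as normalized gradient ascent on the energy score. This pushes `z` toward higher energy, which is the OOD side under the usual energy-score convention.

For a linear head, the gradient has the closed form $\nabla_z E = -W^\top \mathrm{softmax}(l/T)$. The step size, the optional `target` energy, and every function name are illustrative choices, not the authors'.

```python
import math

def logits(W, b, z):
    """Hypothetical linear classifier head on a latent code z: l = W z + b."""
    return [sum(w * zi for w, zi in zip(row, z)) + bk for row, bk in zip(W, b)]

def softmax(l, T=1.0):
    m = max(l)
    e = [math.exp((lk - m) / T) for lk in l]
    s = sum(e)
    return [ek / s for ek in e]

def energy(W, b, z, T=1.0):
    """Energy score E(z) = -T * logsumexp(l / T), computed stably."""
    l = logits(W, b, z)
    m = max(l)
    return -(m + T * math.log(sum(math.exp((lk - m) / T) for lk in l)))

def energy_grad(W, b, z, T=1.0):
    """dE/dz = -W^T softmax(l / T), since dE/dl_k = -p_k for this energy."""
    p = softmax(logits(W, b, z), T)
    return [-sum(p[k] * W[k][i] for k in range(len(W))) for i in range(len(z))]

def energy_guided_perturb(W, b, z, step=0.1, max_steps=10, target=None, T=1.0):
    """Push z toward higher energy with normalized gradient-ascent steps.

    Stops after max_steps, or earlier once E(z) >= target if a target is given.
    Each step moves z by exactly `step`, so the total shift is at most step * max_steps.
    """
    z = list(z)
    for _ in range(max_steps):
        if target is not None and energy(W, b, z, T) >= target:
            break
        g = energy_grad(W, b, z, T)
        norm = math.sqrt(sum(gi * gi for gi in g))
        if norm == 0.0:
            break  # flat energy: no direction to follow
        z = [zi + step * gi / norm for zi, gi in zip(z, g)]
    return z
```

Normalizing each step bounds how far a pseudo-OOD code can drift from its source. This keeps the augmented sample within a controlled distance of the ID code. Decoding the perturbed `z` back into a graph and the energy-based calibration loss are outside the scope of this sketch.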
