Fast Graph Sharpness-Aware Minimization for Enhancing and Accelerating Few-Shot Node Classification

Yihong Luo, Yuhan Chen, Siya Qiu, Yiwei Wang, Chen Zhang, Yan Zhou, Xiaochun Cao, Jing Tang

TL;DR

Fast Graph Sharpness-Aware Minimization (FGSAM) is a novel algorithm that combines the fast training of Multi-Layer Perceptrons (MLPs) with the superior performance of GNNs, and reuses the gradient from the perturbation phase to incorporate graph topology into the minimization process at almost zero additional cost.

Abstract

Graph Neural Networks (GNNs) have shown superior performance in node classification. However, GNNs perform poorly in the Few-Shot Node Classification (FSNC) task, which requires robust generalization to make accurate predictions for unseen classes with limited labels. To tackle this challenge, we propose integrating Sharpness-Aware Minimization (SAM)--a technique designed to enhance model generalization by finding a flat minimum of the loss landscape--into GNN training. The standard SAM approach, however, consists of two forward-backward steps in each training iteration, doubling the computational cost compared to the base optimizer (e.g., Adam). To mitigate this drawback, we introduce a novel algorithm, Fast Graph Sharpness-Aware Minimization (FGSAM), that integrates the rapid training of Multi-Layer Perceptrons (MLPs) with the superior performance of GNNs. Specifically, we utilize GNNs for parameter perturbation while employing MLPs to minimize the perturbed loss, so that we can find a flat minimum with good generalization more efficiently. Moreover, our method reuses the gradient from the perturbation phase to incorporate graph topology into the minimization process at almost zero additional cost. To further enhance training efficiency, we develop FGSAM+, which executes exact perturbations periodically. Extensive experiments demonstrate that our proposed algorithm outperforms the standard SAM with lower computational costs in FSNC tasks. In particular, FGSAM+, as a SAM variant, offers faster optimization than the base optimizer in most cases. In addition to FSNC, our proposed methods also demonstrate competitive performance in the standard node classification task for heterophilic graphs, highlighting their broad applicability. The code is available at https://github.com/draym28/FGSAM_NeurIPS24.
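
To illustrate the two forward-backward steps that the abstract attributes to standard SAM, the following is a minimal sketch of a plain SAM update on a toy quadratic loss. It is not FGSAM and is not taken from the authors' repository: the loss, the function names (`loss`, `grad`, `sam_step`), and the values of `lr` and `rho` are assumptions chosen only for illustration. The counter shows that each SAM step evaluates the gradient twice, which is the cost that FGSAM aims to reduce.

```python
import math

# Counts gradient evaluations to make the 2x cost of SAM visible.
calls = {"grad": 0}

def loss(w):
    # Toy quadratic loss (hypothetical stand-in for a model's training loss).
    return 0.5 * (3.0 * w[0] ** 2 + 0.5 * w[1] ** 2)

def grad(w):
    # Analytic gradient of the toy loss above.
    calls["grad"] += 1
    return [3.0 * w[0], 0.5 * w[1]]

def sam_step(w, lr=0.1, rho=0.05):
    # Pass 1: gradient at the current weights, used only to build the perturbation.
    g = grad(w)
    norm = math.sqrt(sum(gi * gi for gi in g)) + 1e-12  # avoid division by zero
    w_adv = [wi + rho * gi / norm for wi, gi in zip(w, g)]
    # Pass 2: gradient at the perturbed weights, used for the actual update.
    g_adv = grad(w_adv)
    return [wi - lr * gi for wi, gi in zip(w, g_adv)]

w = [1.0, 1.0]
for _ in range(100):
    w = sam_step(w)

print(calls["grad"])  # two gradient evaluations per step
print(loss(w))
```

In a real training loop, each call to `grad` would correspond to one forward-backward pass through the network, so the second pass is where a cheaper model could reduce the cost.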

Paper Structure

This paper contains 28 sections, 1 theorem, 17 equations, 5 figures, 12 tables, 2 algorithms.

Key Result

Theorem 4.1

Consider a $K$-class contextual stochastic block model (CSBM). The optimal linear classifiers for the original features $\bm{x}_i$ and for the filtered features $\bm{h}_i$ are the same.

Figures (5)

  • Figure 1: Comparison of average accuracy and training time across datasets on different GNNs. The closer to the top left corner, the better.
  • Figure 2: (a): Loss landscape visualization of GNN across tasks and optimizers. (b): Loss of GNN, MLP and its PeerMLP on the test set over the training process. In these experiments, MLP and PeerMLP share the same weight space as GNN but are trained without message-passing.
  • Figure 3: Left (a): The solid line indicates that the gradient is computed on the corresponding model, while the dashed line indicates the opposite. Right (b): The difference between consecutive gradients (i.e., $\| \bm{g}_{t+1} - \bm{g}_{t} \|_2$). $\bm{g}_{v}$ and $\bm{g}_{\mathcal{G}}$ change much more slowly than $\bm{g}_{s}$ and $\bm{g}_{h}$ over the training process, so they can be reused in the intermediate steps.
  • Figure 4: Performance of GPN trained by Adam and FGSAM+ with different settings. Left: Results with various hidden channels. Middle Left: Results with various model depths. Middle Right: Results with features perturbed by noise of varying standard deviations. Right: Results with edges subjected to various noise ratios.
  • Figure 5: Training loss curves related to different $\rho$ across optimizers.

Theorems & Definitions (1)

  • Theorem 4.1: The effectiveness of removing MP in minimization