Provably Convergent Subgraph-wise Sampling for Fast GNN Training

Jie Wang, Zhihao Shi, Xize Liang, Defu Lian, Shuiwang Ji, Bin Li, Enhong Chen, Feng Wu

TL;DR

This work addresses scalable GNN training on large graphs by tackling the gradient bias and convergence issues of subgraph-wise sampling. It introduces Local Message Compensation (LMC), which retrieves the messages discarded in backward passes through a message passing formulation of those passes and uses historical information to compensate both forward and backward passes. The resulting mini-batch gradients are accurate, and training provably converges to first-order stationary points for both ConvGNNs and RecGNNs. The theoretical results bound the gradient bias and show that the bound vanishes with appropriate learning rates as iterations progress. Extensive experiments show that LMC significantly accelerates training without compromising accuracy on large-scale datasets and remains robust under small batch sizes. Because it applies across MP-based GNN architectures, LMC is a practical option for fast, reliable GNN training in real-world applications that require long-range dependency modeling.
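To make the idea of compensating with historical information concrete, the following is a minimal sketch of the forward-pass side only. It is not the LMC algorithm: the scalar embeddings, the mean aggregator, the toy path graph, and the function names `aggregate_discarding` and `aggregate_with_history` are all assumptions introduced for illustration, and the backward-pass compensation that LMC adds is not shown.

```python
def aggregate_discarding(adj, batch, fresh):
    """Mean-aggregate neighbor embeddings, dropping out-of-batch neighbors."""
    in_batch = set(batch)
    out = {}
    for u in batch:
        msgs = [fresh[v] for v in adj[u] if v in in_batch]
        out[u] = sum(msgs) / len(msgs) if msgs else 0.0
    return out


def aggregate_with_history(adj, batch, fresh, history):
    """Mean-aggregate neighbor embeddings, using stored historical values
    for out-of-batch neighbors instead of dropping their messages."""
    in_batch = set(batch)
    out = {}
    for u in batch:
        msgs = [fresh[v] if v in in_batch else history[v] for v in adj[u]]
        out[u] = sum(msgs) / len(msgs) if msgs else 0.0
    return out


# Hypothetical toy path graph 0 - 1 - 2 - 3 with scalar embeddings.
adj = {0: [1], 1: [0, 2], 2: [1, 3], 3: [2]}
batch = [0, 1]                  # in-batch nodes
fresh = {0: 1.0, 1: 2.0}        # embeddings computed in this step
history = {2: 4.0, 3: 8.0}      # values stored from earlier steps

dropped = aggregate_discarding(adj, batch, fresh)
with_hist = aggregate_with_history(adj, batch, fresh, history)
print(dropped)    # {0: 2.0, 1: 1.0}
print(with_hist)  # {0: 2.0, 1: 2.5}
```

Node 1 has an out-of-batch neighbor (node 2), so its aggregate changes once the stored value is used; node 0 has only in-batch neighbors and is unaffected.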

Abstract

Subgraph-wise sampling -- a promising class of mini-batch training techniques for graph neural networks (GNNs) -- is critical for real-world applications. During the message passing (MP) in GNNs, subgraph-wise sampling methods discard messages outside the mini-batches in backward passes to avoid the well-known neighbor explosion problem, i.e., the exponentially increasing dependencies of nodes with the number of MP iterations. However, discarding messages may sacrifice the gradient estimation accuracy, posing significant challenges to their convergence analysis and convergence speeds. To address this challenge, we propose a novel subgraph-wise sampling method with a convergence guarantee, namely Local Message Compensation (LMC). To the best of our knowledge, LMC is the first subgraph-wise sampling method with provable convergence. The key idea is to retrieve the discarded messages in backward passes based on a message passing formulation of backward passes. By efficient and effective compensations for the discarded messages in both forward and backward passes, LMC computes accurate mini-batch gradients and thus accelerates convergence. Moreover, LMC is applicable to various MP-based GNN architectures, including convolutional GNNs (finite message passing iterations with different layers) and recurrent GNNs (infinite message passing iterations with a shared layer). Experiments on large-scale benchmarks demonstrate that LMC is significantly faster than state-of-the-art subgraph-wise sampling methods.
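The neighbor explosion problem mentioned above can be illustrated with a small sketch. The complete binary tree used here is a hypothetical toy graph, and the helper names (`receptive_field`, `count_discarded_messages`) are assumptions introduced for illustration only. The code counts how many nodes a single in-batch node depends on after k MP iterations, and how many incoming messages are dropped when message passing is restricted to a mini-batch.

```python
def complete_binary_tree(depth):
    """Adjacency lists of an undirected complete binary tree (toy graph)."""
    n = 2 ** (depth + 1) - 1
    adj = {i: [] for i in range(n)}
    for child in range(1, n):
        parent = (child - 1) // 2
        adj[parent].append(child)
        adj[child].append(parent)
    return adj


def receptive_field(adj, seeds, num_hops):
    """Return all nodes within num_hops of any seed node."""
    seen = set(seeds)
    frontier = set(seeds)
    for _ in range(num_hops):
        nxt = {v for u in frontier for v in adj[u] if v not in seen}
        seen |= nxt
        frontier = nxt
    return seen


def count_discarded_messages(adj, batch):
    """Count messages from out-of-batch neighbors into in-batch nodes."""
    in_batch = set(batch)
    return sum(1 for u in batch for v in adj[u] if v not in in_batch)


adj = complete_binary_tree(depth=10)

# Dependencies of the root node after k = 0..5 MP iterations.
sizes = [len(receptive_field(adj, [0], k)) for k in range(6)]
print(sizes)  # [1, 3, 7, 15, 31, 63]

# Restricting MP to the mini-batch {0, 1, 2} drops these incoming messages.
discarded = count_discarded_messages(adj, [0, 1, 2])
print(discarded)  # 4
```

On this toy graph the dependency set roughly doubles with each additional MP iteration, while restricting to the mini-batch keeps the computation fixed in size at the cost of the dropped messages.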

Paper Structure

This paper contains 70 sections, 28 theorems, 197 equations, 13 figures, 14 tables, 2 algorithms.

Key Result

Theorem 1

Suppose that a subgraph $\mathcal{V}_{\mathcal{B}}$ is uniformly sampled from $\mathcal{V}$ and the corresponding set of labeled nodes $\mathcal{V}_{L_{\mathcal{B}}} = \mathcal{V}_{\mathcal{B}} \cap \mathcal{V}_{L}$ is uniformly sampled from $\mathcal{V}_{L}$. Then the mini-batch gradients $\mathbf{g}_w(\m

Figures (13)

  • Figure 1: The architectures of ConvGNNs and RecGNNs. We denote message passing by MP.
  • Figure 2: Comparison of LMC4Conv with GAS. (a) shows the original graph with in-batch nodes, 1-hop out-of-batch nodes, and other out-of-batch nodes in orange, blue, and grey, respectively. (b) and (d) show the computation graphs of forward passes and backward passes of GAS, respectively. (c) and (e) show the computation graphs of forward passes and backward passes of LMC4Conv, respectively.
  • Figure 3: Message passing of backward SGD, Cluster-GCN, GAS, and LMC for RecGNNs.
  • Figure 4: Testing accuracy and training loss w.r.t. runtimes (s). The LMC in the figures refers to LMC4Conv.
  • Figure 5: The average relative estimated errors of mini-batch gradients computed by Cluster-GCN, GAS, and LMC4Conv for GCN.
  • ...and 8 more figures

Theorems & Definitions (52)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Remark 1
  • Theorem 4
  • Theorem 5
  • proof
  • Lemma 1
  • proof
  • Lemma 2
  • ...and 42 more