
FuseSampleAgg: Fused Neighbor Sampling and Aggregation for Mini-batch GNNs

Aleksandar Stanković

TL;DR

This work targets the memory and latency bottlenecks of mini-batch GraphSAGE training by fusing neighbor sampling and mean aggregation into a single CUDA operator, which eliminates block materialization. The FuseSampleAgg kernel supports 1-hop and 2-hop sampling, can save the sampled indices so that autograd replays them exactly, is deterministic, and integrates with PyTorch. Across Reddit, ogbn-arxiv, and ogbn-products, it delivers step-time speedups of up to 51x and peak-memory reductions of up to 100x, which makes larger batch sizes and fanouts practical. Fusing across the sampling/aggregation boundary reduces memory traffic and kernel-launch overhead, offering a practical path toward memory-efficient, high-throughput GNN training. Reproducible artifacts and open-source code are provided for adoption and extension.
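To make the fusion concrete, here is a minimal pure-Python sketch of a 1-hop fused sample-and-mean step over a graph stored in CSR form. It is not the FuseSampleAgg CUDA kernel: the function name `fused_sample_mean`, the CSR inputs (`indptr`, `indices`), sampling without replacement, and a zero output for seeds with no neighbors are all assumptions made for illustration. What it shows is that each seed's neighbors are sampled and averaged in the same loop, so no intermediate block is built, and the sampled ids are returned as saved indices for a later backward pass.

```python
import random
from typing import List, Tuple


def fused_sample_mean(
    indptr: List[int],
    indices: List[int],
    feats: List[List[float]],
    seeds: List[int],
    fanout: int,
    rng_seed: int = 0,
) -> Tuple[List[List[float]], List[List[int]]]:
    """Hypothetical reference: sample up to `fanout` neighbors per seed and
    average their features in the same loop (no materialized block).

    Returns the aggregated rows and the sampled neighbor ids ("saved
    indices") so that a backward pass can replay the exact selection.
    """
    rng = random.Random(rng_seed)  # fixed seed -> deterministic sampling
    dim = len(feats[0])
    out: List[List[float]] = []
    saved: List[List[int]] = []
    for v in seeds:
        nbrs = indices[indptr[v]:indptr[v + 1]]
        # Assumption: sample without replacement; keep all if degree <= fanout.
        picked = nbrs if len(nbrs) <= fanout else rng.sample(nbrs, fanout)
        acc = [0.0] * dim
        for u in picked:  # accumulate directly into the output row
            for d in range(dim):
                acc[d] += feats[u][d]
        k = max(len(picked), 1)  # seeds without neighbors aggregate to zeros
        out.append([a / k for a in acc])
        saved.append(list(picked))
    return out, saved
```

Because the sampler is seeded, repeated calls with the same `rng_seed` return identical outputs and saved indices, a simple stand-in for the deterministic behavior the operator reports.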

Abstract

We present FuseSampleAgg, a CUDA operator that fuses neighbor sampling and mean aggregation into a single pass for one- and two-hop GraphSAGE. By eliminating block materialization and extra kernel launches, FuseSampleAgg reduces memory traffic and overhead while preserving GraphSAGE mean semantics via saved-index replay. Across the Reddit, ogbn-arxiv, and ogbn-products benchmarks (batch size 1024, automatic mixed precision enabled), we observe step-time speedups of up to 51x on ogbn-products, about 4x on Reddit with fanouts 10-10 and 15-10, and about 3.3x on ogbn-arxiv at larger fanouts, with peak GPU memory reductions of up to 100x, 36x, and about 3.5x, respectively. The operator is deterministic, integrates with standard PyTorch optimizers, and ships with scripts that reproduce all tables and figures from CSV logs. Code and scripts are available at https://github.com/SV25-22/FuseSampleAgg.
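Saved-index replay can be illustrated with the backward pass of mean aggregation. If seed i averages the features x_u over its sampled set S_i of size k_i, then each sampled neighbor u receives (1/k_i) times the gradient of seed i's output, summed over every seed that picked u. Replaying the recorded ids instead of resampling guarantees that the gradient matches the forward selection. The sketch below is hypothetical pure Python (the name `mean_agg_backward` and the list-of-lists layout are assumptions); it accumulates in a fixed sequential order and does not reflect how the CUDA operator orders its accumulations.

```python
from typing import List


def mean_agg_backward(
    grad_out: List[List[float]],
    saved: List[List[int]],
    num_nodes: int,
) -> List[List[float]]:
    """Hypothetical reference: gradient of mean aggregation with respect to
    node features, replaying neighbor ids saved during the forward pass."""
    dim = len(grad_out[0]) if grad_out else 0
    grad_feats = [[0.0] * dim for _ in range(num_nodes)]
    for g, picked in zip(grad_out, saved):
        if not picked:
            continue  # seed had no neighbors: its output used no features
        scale = 1.0 / len(picked)
        # Each picked neighbor had weight 1/k in the forward mean,
        # so it receives grad * 1/k here (a scatter-add).
        for u in picked:
            for d in range(dim):
                grad_feats[u][d] += g[d] * scale
    return grad_feats
```

Node ids that no seed sampled keep a zero gradient, which is the expected behavior for features that did not contribute to the mini-batch.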

Paper Structure

This paper contains 25 sections, 5 figures, 3 tables, 2 algorithms.

Figures (5)

  • Figure 1: Median step-time speedup of FuseSampleAgg over the best baseline for B=1024, AMP=on. Each panel is a dataset; bars vary fanout. The dashed line marks parity (1.0x). Higher is better. Call-out: On Reddit at 25-10, FuseSampleAgg is slower, consistent with profiler evidence of higher atomic contention and weaker cache locality at that fanout.
  • Figure 2: Throughput scaling with batch size on ogbn-products (fanout 15-10, AMP=on). FuseSampleAgg scales better with larger batches than the baseline (higher is better).
  • Figure 3: Median step time vs. fanout on ogbn-arxiv (B=1024, AMP=on). Larger fanouts amplify FuseSampleAgg’s advantage (lower is better).
  • Figure 4: Peak-memory reduction: ratio of DGL peak memory to FuseSampleAgg (FSA) peak memory (higher is better). Batch size = 1024, AMP on. Values are medians over 3 runs; peaks measured during the timed loop.
  • Figure 5: Absolute peak GPU memory (MB) on a log scale for DGL (left) and FSA (right) across fanouts (batch=1024, AMP on). Same runs as Figure 4.
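The captions above report medians over 3 runs and ratios against a baseline. One plausible way to compute these quantities from per-run measurements is sketched below. Whether the published scripts take a ratio of medians (as here) or a median of per-run ratios is an assumption, and the function names and dictionary layout are hypothetical, not taken from the repository's CSV schema.

```python
import statistics
from typing import Dict, List


def median_ratio(baseline: List[float], fused: List[float]) -> float:
    """Ratio of medians, baseline over fused.

    For step time this is a speedup; for peak memory it is a reduction
    factor. In both cases values above 1.0 favor the fused operator.
    """
    return statistics.median(baseline) / statistics.median(fused)


def speedup_over_best(baselines: Dict[str, List[float]], fused: List[float]) -> float:
    """Speedup relative to the fastest baseline (lowest median step time)."""
    best = min(statistics.median(times) for times in baselines.values())
    return best / statistics.median(fused)
```

Under this reading, a value below 1.0 from `speedup_over_best` corresponds to a slowdown such as the one called out for Reddit at fanout 25-10 in Figure 1.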