FuseSampleAgg: Fused Neighbor Sampling and Aggregation for Mini-batch GNNs
Aleksandar Stanković
TL;DR
This work targets the memory and latency bottlenecks of mini-batch GraphSAGE training by fusing neighbor sampling and mean aggregation into a single CUDA operator, which eliminates block materialization. The FuseSampleAgg kernel supports one- and two-hop GraphSAGE and can optionally save the sampled indices so that autograd replays them exactly; it is deterministic and integrates with PyTorch. Experiments on Reddit, ogbn-arxiv, and ogbn-products show step-time speedups of up to 51x and peak-memory reductions of up to 100x, which make larger batch sizes and fanouts practical. Fusing across the boundary between sampling and aggregation cuts memory traffic and kernel-launch overhead, offering a practical path toward memory-efficient, high-throughput GNN training. Reproducible artifacts and open-source code are provided for adoption and extension.
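To make the fused pattern concrete, here is a minimal pure-Python reference sketch of one-hop "sample, then average" over a CSR graph. It is not the FuseSampleAgg CUDA kernel or its API: the function name `fused_sample_mean_1hop`, the CSR inputs (`indptr`, `indices`), sampling without replacement, and the zero output for nodes with no neighbors are all illustrative assumptions. It shows two ideas: each seed's mean is accumulated directly from its sampled neighbors' feature rows, so no intermediate block of gathered features is built, and only the sampled neighbor IDs are recorded (as `saved_ptr`/`saved_idx`) so a backward pass can replay them.

```python
import random
from typing import List, Sequence, Tuple


def fused_sample_mean_1hop(
    indptr: Sequence[int],
    indices: Sequence[int],
    x: Sequence[Sequence[float]],
    seeds: Sequence[int],
    fanout: int,
    rng: random.Random,
) -> Tuple[List[List[float]], List[int], List[int]]:
    """Hypothetical sketch: sample up to `fanout` neighbors per seed and average
    their features in the same loop, returning CSR-style saved indices."""
    dim = len(x[0])
    out: List[List[float]] = []
    saved_ptr: List[int] = [0]
    saved_idx: List[int] = []
    for s in seeds:
        start, end = indptr[s], indptr[s + 1]
        # Assumption: sample without replacement; keep all neighbors if deg <= fanout.
        if end - start > fanout:
            picks = rng.sample(range(start, end), fanout)
        else:
            picks = list(range(start, end))
        acc = [0.0] * dim  # running sum; no gathered feature block is stored
        for p in picks:
            nbr = indices[p]
            saved_idx.append(nbr)  # record the draw for exact backward replay
            for d in range(dim):
                acc[d] += x[nbr][d]
        k = len(picks)
        # Assumption: a node with no neighbors aggregates to a zero vector.
        out.append([a / k for a in acc] if k else acc)
        saved_ptr.append(len(saved_idx))
    return out, saved_ptr, saved_idx


# Toy CSR graph (hypothetical): 0 -> {1, 2, 3}, 1 -> {0}, 2 -> {}, 3 -> {0, 1}
indptr = [0, 3, 4, 4, 6]
indices = [1, 2, 3, 0, 0, 1]
x = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [4.0, 0.0]]
out, saved_ptr, saved_idx = fused_sample_mean_1hop(
    indptr, indices, x, seeds=[1, 3], fanout=2, rng=random.Random(0)
)
# out == [[1.0, 0.0], [0.5, 0.5]]; saved_ptr == [0, 1, 3]; saved_idx == [0, 0, 1]
```

Seeding the generator makes the draw reproducible, and storing indices rather than features keeps the saved state proportional to the number of sampled edges instead of edges times feature width.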
Abstract
We present FuseSampleAgg, a CUDA operator that fuses neighbor sampling and mean aggregation into a single pass for one- and two-hop GraphSAGE. By eliminating block materialization and extra kernel launches, FuseSampleAgg reduces memory traffic and overhead while preserving GraphSAGE mean semantics through saved-index replay. Across the Reddit, ogbn-arxiv, and ogbn-products benchmarks (batch size 1024, automatic mixed precision enabled), we observe step-time speedups of up to 51x on ogbn-products, about 4x on Reddit with fanouts 10-10 and 15-10, and about 3.3x on ogbn-arxiv at larger fanouts, with peak GPU memory reductions of up to 100x, 36x, and about 3.5x, respectively. The operator is deterministic, integrates with standard PyTorch optimizers, and ships with scripts that reproduce all tables and figures from CSV logs. Code and scripts are available at https://github.com/SV25-22/FuseSampleAgg.
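Saved-index replay keeps the backward pass consistent with the forward pass: gradients must flow to exactly the neighbors that were sampled, not to a fresh random draw. The following is a minimal pure-Python sketch of that idea, assuming the sampled neighbor IDs are stored as a CSR-like pair (`saved_ptr`, `saved_idx`); these names and both helper functions are hypothetical and do not describe the FuseSampleAgg interface. Because each output row is the mean of its k sampled rows, the backward step scatters the seed's upstream gradient, scaled by 1/k, onto each saved neighbor.

```python
from typing import List, Sequence


def mean_agg_forward_from_saved(
    x: Sequence[Sequence[float]], saved_ptr: Sequence[int], saved_idx: Sequence[int]
) -> List[List[float]]:
    """Hypothetical helper: recompute the mean aggregation from saved indices."""
    dim = len(x[0])
    out: List[List[float]] = []
    for i in range(len(saved_ptr) - 1):
        lo, hi = saved_ptr[i], saved_ptr[i + 1]
        acc = [0.0] * dim
        for j in saved_idx[lo:hi]:
            for d in range(dim):
                acc[d] += x[j][d]
        k = hi - lo
        out.append([a / k for a in acc] if k else acc)  # empty row -> zeros
    return out


def mean_agg_backward(
    grad_out: Sequence[Sequence[float]],
    saved_ptr: Sequence[int],
    saved_idx: Sequence[int],
    num_nodes: int,
) -> List[List[float]]:
    """Hypothetical helper: d(mean)/d(x[j]) = 1/k for each saved neighbor j."""
    dim = len(grad_out[0])
    grad_x = [[0.0] * dim for _ in range(num_nodes)]
    for i in range(len(saved_ptr) - 1):
        lo, hi = saved_ptr[i], saved_ptr[i + 1]
        k = hi - lo
        if k == 0:
            continue  # a seed with no sampled neighbors sends no gradient
        for j in saved_idx[lo:hi]:  # replay the forward draw; never resample
            for d in range(dim):
                grad_x[j][d] += grad_out[i][d] / k
    return grad_x


# Saved indices for two seeds: seed 0 sampled {0}, seed 1 sampled {0, 1}.
grad_x = mean_agg_backward([[1.0, 1.0], [1.0, 1.0]], [0, 1, 3], [0, 0, 1], num_nodes=4)
# grad_x == [[1.5, 1.5], [0.5, 0.5], [0.0, 0.0], [0.0, 0.0]]
```

The sequential loops here are trivially deterministic. A parallel GPU version needs more care, because floating-point atomic additions applied in varying order are not bitwise reproducible; this sketch does not show how FuseSampleAgg obtains its deterministic behavior.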
