PRES: Toward Scalable Memory-Based Dynamic Graph Neural Networks
Junwei Su, Difan Zou, Chuan Wu
TL;DR
Training memory-based dynamic graph neural networks (MDGNNs) is hindered by temporal discontinuity in batch processing, which disrupts chronological memory updates and limits data parallelism. The authors introduce PRES, an iterative prediction-correction framework with a memory-coherence smoothing objective, which enables significantly larger temporal batches without sacrificing performance. They provide theoretical insights into how temporal batch size affects gradient variance and convergence, and show that PRES achieves up to a 4x increase in temporal batch size and an approximately 3.4x speed-up on benchmarks. In practice, PRES improves training efficiency while preserving accuracy, extending the scalability of MDGNNs toward industrial-scale dynamic graphs.
Abstract
Memory-based Dynamic Graph Neural Networks (MDGNNs) are a family of dynamic graph neural networks that leverage a memory module to extract, distill, and memorize long-term temporal dependencies, leading to superior performance compared to memory-less counterparts. However, training MDGNNs faces the challenge of handling entangled temporal and structural dependencies, which requires sequential, chronological processing of data sequences to capture accurate temporal patterns. During batch training, temporal data points within the same batch are processed in parallel, so their temporal dependencies on one another are neglected. This issue is referred to as temporal discontinuity; it restricts the effective temporal batch size, limiting data parallelism and reducing MDGNNs' flexibility in industrial applications. This paper studies the efficient training of MDGNNs at scale, focusing on temporal discontinuity when training MDGNNs with large temporal batch sizes. We first conduct a theoretical study of the impact of temporal batch size on the convergence of MDGNN training. Based on this analysis, we propose PRES, an iterative prediction-correction scheme combined with a memory coherence learning objective to mitigate the effect of temporal discontinuity, enabling MDGNNs to be trained with significantly larger temporal batches without sacrificing generalization performance. Experimental results demonstrate that our approach enables up to a 4x larger temporal batch (3.4x speed-up) during MDGNN training.
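Temporal discontinuity can be made concrete with a toy comparison. The sketch below is not the authors' code or part of any MDGNN library. It assumes a scalar memory per node, a hypothetical `update` rule standing in for a learned memory updater (such as a GRU), and a last-write-wins policy when a node appears more than once in a batch. It contrasts chronological, one-event-at-a-time memory updates with batched updates in which every event in a batch reads the memory snapshot taken at the start of that batch.

```python
import math

def update(mem_self, mem_other, feat):
    # Hypothetical stand-in for a learned memory updater (e.g. a GRU cell).
    return math.tanh(0.5 * mem_self + 0.5 * mem_other + feat)

def sequential_memory(events, num_nodes):
    # Chronological processing: each event sees all earlier updates.
    mem = [0.0] * num_nodes
    for src, dst, feat in events:  # events assumed sorted by timestamp
        new_src = update(mem[src], mem[dst], feat)
        new_dst = update(mem[dst], mem[src], feat)
        mem[src], mem[dst] = new_src, new_dst
    return mem

def batched_memory(events, num_nodes, batch_size):
    # Batch processing: events in a batch are handled "in parallel",
    # so they all read the same (possibly stale) memory snapshot.
    mem = [0.0] * num_nodes
    for start in range(0, len(events), batch_size):
        batch = events[start:start + batch_size]
        snapshot = list(mem)
        pending = {}
        for src, dst, feat in batch:
            # Assumed last-write-wins policy for nodes seen twice in a batch.
            pending[src] = update(snapshot[src], snapshot[dst], feat)
            pending[dst] = update(snapshot[dst], snapshot[src], feat)
        for node, value in pending.items():
            mem[node] = value
    return mem

def discrepancy(a, b):
    # Largest per-node gap between two memory states.
    return max(abs(x - y) for x, y in zip(a, b))

if __name__ == "__main__":
    # (src, dst, feature) tuples in time order; each event depends on the previous one.
    events = [(0, 1, 1.0), (1, 2, 1.0), (2, 3, 1.0)]
    seq = sequential_memory(events, num_nodes=4)
    for bs in (1, 2, 3):
        print(bs, discrepancy(seq, batched_memory(events, 4, bs)))
```

With `batch_size=1` the two procedures coincide. As the batch grows, more events read memory that has not yet absorbed earlier events in the same batch, so the batched memory drifts from the chronological one. This is the kind of error that a prediction-correction scheme such as PRES is designed to reduce.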
