Sampling-based Distributed Training with Message Passing Neural Network
Priyesh Kakka, Sheel Nidhan, Rishikesh Ranade, Jay Pathak, Jonathan F. MacArt
TL;DR
This work tackles the memory and scalability bottleneck of edge-based graph neural networks for PDE surrogates by introducing DS-MPNN, a domain-decomposition and Nyström-inspired sampling framework that distributes MPNN training across multiple GPUs. By constructing graph kernels around sampled centers with radius-based message passing and employing overlap regions for inter-GPU communication, DS-MPNN scales to $O(10^5)$ nodes while maintaining accuracy comparable to a single-GPU MPNN and outperforming node-based GCN baselines. Across Darcy flow, AirfRANS, and 3-D step flow experiments, DS-MPNN demonstrates robust performance and significant gains in training and inference speed, confirming the practicality of edge-based PDE surrogates on large, unstructured graphs. The approach lays the groundwork for integrating advanced domain partitioning (e.g., METIS) and modern distributed graph libraries to scale PDE-informed learning further.
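As a rough illustration of radius-based message passing on a sampled subset of nodes, the sketch below draws a random subset of nodes (a Nyström-style subsample), connects sampled nodes that lie within a fixed radius, and performs one mean-aggregation update. It is a minimal sketch under stated assumptions, not the authors' implementation: uniform random sampling, edges restricted to sampled nodes, scalar node values, a fixed Gaussian stand-in for the learned edge network, and the helper names `sample_nodes`, `radius_edges`, `message_passing_step`, and `gaussian_kernel` are all hypothetical. It also hints at why sampling helps: for $n$ nodes spread uniformly over the unit square, a radius-$r$ graph has roughly $\pi r^2 n^2$ directed edges (ignoring boundary effects), so keeping $m \ll n$ nodes shrinks edge memory by roughly $(m/n)^2$.

```python
import math
import random


def sample_nodes(n_nodes, n_samples, rng):
    """Pick a random subset of node indices (hypothetical Nyström-style subsample)."""
    return rng.sample(range(n_nodes), n_samples)


def radius_edges(points, idx, radius):
    """Directed edges (i, j) between sampled nodes no farther apart than `radius`.

    Brute force O(m^2) for clarity; a spatial tree would be used at scale.
    """
    edges = []
    for i in idx:
        for j in idx:
            if i != j and math.dist(points[i], points[j]) <= radius:
                edges.append((i, j))
    return edges


def message_passing_step(points, values, idx, edges, kernel):
    """One update v_i <- mean over neighbors j of kernel(x_i, x_j) * v_j.

    `kernel` stands in for a learned edge network (assumption).
    Sampled nodes with no neighbors keep their current value.
    """
    sums = {i: 0.0 for i in idx}
    counts = {i: 0 for i in idx}
    for i, j in edges:
        sums[i] += kernel(points[i], points[j]) * values[j]
        counts[i] += 1
    return {i: sums[i] / counts[i] if counts[i] else values[i] for i in idx}


def gaussian_kernel(xi, xj):
    """Fixed Gaussian weight; a placeholder for a trained MLP (assumption)."""
    return math.exp(-math.dist(xi, xj) ** 2)


# Usage: 2000 random nodes, 200 sampled, radius 0.1.
rng = random.Random(0)
points = [(rng.random(), rng.random()) for _ in range(2000)]
values = [math.sin(math.pi * x) for x, _ in points]
idx = sample_nodes(len(points), 200, rng)
edges = radius_edges(points, idx, radius=0.1)
updated = message_passing_step(points, values, idx, edges, gaussian_kernel)
print(len(edges), "edges among", len(idx), "sampled nodes")
```

Resampling `idx` at every training iteration, so that different subsets of the graph are seen over time, is one plausible way to use such a step; the sketch does not claim this is the paper's schedule.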
Abstract
In this study, we introduce a domain-decomposition-based distributed training and inference approach for message-passing neural networks (MPNNs). Our objective is to address the challenge of scaling edge-based graph neural networks as the number of nodes increases. By combining this distributed training approach with Nyström-approximation sampling, we obtain a scalable graph neural network, referred to as DS-MPNN (D and S standing for distributed and sampled, respectively), capable of scaling up to $O(10^5)$ nodes. We validate our sampling and distributed training approach on two cases: (a) a Darcy flow dataset and (b) steady RANS simulations of 2-D airfoils, providing comparisons with both a single-GPU implementation and node-based graph convolutional networks (GCNs). The DS-MPNN model achieves accuracy comparable to the single-GPU implementation, can accommodate a significantly larger number of nodes than the single-GPU variant (S-MPNN), and significantly outperforms the node-based GCN.
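The domain-decomposition side of the approach can be illustrated with a simple strip partition. The sketch below assigns each node of a 2-D point cloud on the unit square to one of `n_parts` vertical strips (for example, one per GPU) and, for each strip, collects halo nodes from other strips that fall within an overlap width of it. It is a hypothetical sketch, not the paper's partitioner: equal-width strips along $x$, the function name `partition_with_overlap`, and the rule for choosing the overlap width are all assumptions.

```python
import random


def partition_with_overlap(points, n_parts, overlap):
    """Split points on [0, 1]^2 into vertical strips with halo overlap.

    Returns one (owned, halo) pair of index lists per partition. Every node
    is owned by exactly one partition; halo nodes belong to other strips but
    lie within `overlap` of this strip's x-range [lo, hi].
    """
    # Strip index from the x-coordinate; x == 1.0 goes to the last strip.
    owner = [min(int(x * n_parts), n_parts - 1) for x, _ in points]
    parts = []
    for p in range(n_parts):
        lo, hi = p / n_parts, (p + 1) / n_parts
        owned = [i for i, o in enumerate(owner) if o == p]
        halo = [
            i for i, (x, _) in enumerate(points)
            if owner[i] != p and lo - overlap <= x <= hi + overlap
        ]
        parts.append((owned, halo))
    return parts


rng = random.Random(0)
points = [(rng.random(), rng.random()) for _ in range(1000)]
# Assumption: overlap >= (number of layers) * (message-passing radius), so that
# owned nodes see the same radius neighborhoods as in the undivided, unsampled graph.
parts = partition_with_overlap(points, n_parts=4, overlap=0.05)
for rank, (owned, halo) in enumerate(parts):
    print(f"partition {rank}: {len(owned)} owned, {len(halo)} halo nodes")
```

Equal-width strips keep the example short but can leave partitions unbalanced on nonuniform meshes; a graph partitioner such as METIS, which the TL;DR mentions as future work, addresses that.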
