FlashSchNet: Fast and Accurate Coarse-Grained Neural Network Molecular Dynamics

Pingzhi Li, Hongxuan Li, Zirui Liu, Xingcheng Lin, Tianlong Chen

TL;DR

The paper presents FlashSchNet, an efficient and accurate IO-aware SchNet-style GNN-MD framework built on four techniques. It achieves 1000 ns/day aggregate simulation throughput over 64 parallel replicas on a coarse-grained protein containing 269 beads, surpassing classical force fields while retaining SchNet-level accuracy and transferability.

Abstract

Graph neural network (GNN) potentials such as SchNet improve the accuracy and transferability of molecular dynamics (MD) simulation by learning many-body interactions, but they remain slower than classical force fields because fragmented kernels and memory-bound pipelines underutilize GPUs. We show that a missing principle is making GNN-MD IO-aware, that is, carefully accounting for reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. We present FlashSchNet, an efficient and accurate IO-aware SchNet-style GNN-MD framework built on four techniques: (1) flash radial basis, which fuses pairwise distance computation, Gaussian basis expansion, and cosine envelope into a single tiled pass, computing each distance once and reusing it across all basis functions; (2) flash message passing, which fuses cutoff, neighbor gather, filter multiplication, and reduction to avoid materializing edge tensors in HBM; (3) flash aggregation, which reformulates scatter-add as a CSR segment reduce, reducing atomic writes by a factor equal to the feature dimension and enabling contention-free accumulation in both forward and backward passes; and (4) channel-wise 16-bit quantization, which exploits the low per-channel dynamic range of SchNet MLP weights to further improve throughput with negligible accuracy loss. On a single NVIDIA RTX PRO 6000, FlashSchNet achieves 1000 ns/day aggregate simulation throughput over 64 parallel replicas on a coarse-grained (CG) protein containing 269 beads (6.5x faster than the CGSchNet baseline, with 80% lower peak memory), surpassing classical force fields (e.g., MARTINI) while retaining SchNet-level accuracy and transferability.
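The dataflow behind techniques (1) to (3) can be illustrated on the CPU. The pure-Python sketch below is not the paper's GPU kernels; it assumes a linear stand-in (`filt`) for the SchNet filter MLP, edges pre-sorted by destination, and hypothetical helper names (`rbf_features`, `csr_offsets`, `fused_interaction`, `unfused_interaction`). It only shows that computing each distance once, consuming the per-edge features immediately, and summing each node's contiguous CSR segment gives the same result as a baseline that materializes B, W, and M and then scatter-adds.

```python
import math
from typing import List, Sequence, Tuple

Vec3 = Tuple[float, float, float]
Matrix = List[List[float]]


def rbf_features(d: float, centers: Sequence[float], gamma: float, cutoff: float) -> List[float]:
    """Gaussian basis times a cosine envelope, all derived from a single distance d."""
    env = 0.5 * (math.cos(math.pi * d / cutoff) + 1.0) if d < cutoff else 0.0
    return [env * math.exp(-gamma * (d - mu) ** 2) for mu in centers]


def csr_offsets(dst_sorted: Sequence[int], num_nodes: int) -> List[int]:
    """Row pointers: edges whose destination is node i occupy [offsets[i], offsets[i + 1])."""
    offsets = [0] * (num_nodes + 1)
    for d in dst_sorted:
        offsets[d + 1] += 1
    for i in range(num_nodes):
        offsets[i + 1] += offsets[i]
    return offsets


def fused_interaction(pos: Sequence[Vec3], x: Matrix, src: Sequence[int], offsets: Sequence[int],
                      centers: Sequence[float], gamma: float, cutoff: float, filt: Matrix) -> Matrix:
    """One pass per destination node: distance -> basis -> filter -> gather -> multiply -> sum.
    No E x D_r or E x D edge tensor is stored, and each output row has a single owner,
    so the accumulation needs no atomic adds."""
    dim = len(filt[0])
    out = []
    for i in range(len(offsets) - 1):
        acc = [0.0] * dim
        for e in range(offsets[i], offsets[i + 1]):
            j = src[e]
            # Distance computed once and reused for every basis function.
            b = rbf_features(math.dist(pos[i], pos[j]), centers, gamma, cutoff)
            for k in range(dim):
                w_k = sum(b[r] * filt[r][k] for r in range(len(b)))  # hypothetical linear filter
                acc[k] += w_k * x[j][k]
        out.append(acc)
    return out


def unfused_interaction(pos: Sequence[Vec3], x: Matrix, dst: Sequence[int], src: Sequence[int],
                        centers: Sequence[float], gamma: float, cutoff: float, filt: Matrix) -> Matrix:
    """Baseline dataflow: materialize B (E x D_r), W and M (E x D), then scatter-add."""
    dim = len(filt[0])
    B = [rbf_features(math.dist(pos[d], pos[s]), centers, gamma, cutoff) for d, s in zip(dst, src)]
    W = [[sum(b[r] * filt[r][k] for r in range(len(b))) for k in range(dim)] for b in B]
    M = [[W[e][k] * x[s][k] for k in range(dim)] for e, s in enumerate(src)]
    out = [[0.0] * dim for _ in range(len(x))]
    for e, d in enumerate(dst):
        for k in range(dim):
            out[d][k] += M[e][k]  # on a GPU, typically E * D atomic adds into shared rows
    return out


pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.5, 0.0)]
x = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
dst = [0, 0, 1, 1, 2, 2]  # edges already sorted by destination
src = [1, 2, 0, 2, 0, 1]
centers, gamma, cutoff = [0.5, 1.0, 1.5], 4.0, 2.0
filt = [[0.1, 0.2], [0.3, -0.1], [0.05, 0.4]]  # hypothetical D_r = 3 -> D = 2 filter weights

fused = fused_interaction(pos, x, src, csr_offsets(dst, 3), centers, gamma, cutoff, filt)
ref = unfused_interaction(pos, x, dst, src, centers, gamma, cutoff, filt)
print(fused)
```

On a GPU, the benefit of the fused ordering is that per-edge intermediates can stay in registers or SRAM instead of round-tripping through HBM; in Python the sketch only demonstrates that the reordering does not change the result.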

Paper Structure (53 sections, 13 equations, 7 figures, 3 tables)

Figures (7)

  • Figure 1: Left: Memory-throughput trade-off for SchNet-style GNN-MD. FlashSchNet achieves a $5\times$ memory reduction while improving throughput by $6\times$ over the CGSchNet baseline. Right: Step-time breakdown showing that FlashSchNet eliminates scatter and element-wise bottlenecks via fused kernels and 16-bit quantization. All results are evaluated on a 269-bead protein (1ENH) with 64 replicas.
  • Figure 2: (a) SchNet model architecture for molecular dynamics: atom positions $\mathbf{r}$ and embeddings $\mathbf{X}$ are processed through neighbor list construction, radial basis expansion, and $T$ interaction blocks, followed by energy readout and autodiff for force computation. (b) FlashSchNet IO-aware execution model. The baseline pipeline (bottom, shaded) materializes intermediate edge tensors ($\mathbf{B} \in \mathbb{R}^{E \times D_r}$; $\mathbf{W}, \mathbf{X}_{\text{src}}, \mathbf{M} \in \mathbb{R}^{E \times D}$) in HBM and uses atomic scatter for aggregation. FlashSchNet (top, orange) fuses these operations into three kernels that keep intermediates in SRAM: Fused RBF computes distances, Gaussian basis expansion, and cosine envelope in one pass; Fused MP combines the FP16 filter MLP, neighbor gather, and element-wise multiplication; Segmented reduce replaces atomic scatter-add with contention-free CSR-style accumulation. Red crosses indicate eliminated HBM materializations. The FlashSchNet pipeline reduces memory traffic by a factor of ${\sim}E/N$ and removes all atomic contention.
  • Figure 3: Filter-network weights show a clear channel-wise magnitude distribution, motivating channel-wise quantization for lossless acceleration.
  • Figure 4: Trajectories of C$\alpha$ RMSD and fraction of native contacts ($Q$) for three fast-folding proteins simulated with FlashSchNet. The plots show multiple reversible folding/unfolding events with the expected anti-correlation between RMSD and $Q$. FlashSchNet captures the distinct folding timescales of Chignolin (nanosecond transitions) as well as the longer residence times of TRPcage and Villin.
  • Figure 5: Step-wise throughput comparison on the 1ENH protein during a 300k-step extended simulation across three batch sizes ($16$, $32$, and $64$ parallel replicas). FlashSchNet maintains consistent throughput throughout the simulation despite the evolving graph topology, while CGSchNet degrades as the neighbor graph becomes denser and less diagonal, as shown in the periodic-graph figure. The speedup gap widens with batch size, reaching $6.5\times$ at $64$ replicas.
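Channel-wise 16-bit quantization (technique (4) in the abstract, motivated by the Figure 3 caption) can be sketched as symmetric per-output-channel scaling. This is an assumption-laden illustration: it uses int16 integer codes with one scale per weight row, whereas the Figure 2 caption mentions an FP16 filter MLP, so the paper's actual 16-bit format may differ. The helper names and weight values are hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]
INT16_MAX = 32767


def quantize_per_channel(w: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Symmetric int16 quantization with one scale per output channel (row)."""
    q, scales = [], []
    for row in w:
        max_abs = max(abs(v) for v in row)
        scale = max_abs / INT16_MAX if max_abs > 0 else 1.0
        q.append([max(-INT16_MAX, min(INT16_MAX, round(v / scale))) for v in row])
        scales.append(scale)
    return q, scales


def quantize_per_tensor(w: Matrix) -> Tuple[List[List[int]], List[float]]:
    """Same scheme with a single shared scale, for comparison."""
    max_abs = max(abs(v) for row in w for v in row)
    scale = max_abs / INT16_MAX if max_abs > 0 else 1.0
    q = [[max(-INT16_MAX, min(INT16_MAX, round(v / scale))) for v in row] for row in w]
    return q, [scale] * len(w)


def dequantize(q: List[List[int]], scales: List[float]) -> Matrix:
    """Map integer codes back to floats using each row's scale."""
    return [[qi * s for qi in row] for row, s in zip(q, scales)]


# Hypothetical weights: one large-magnitude channel and one small-magnitude channel.
w = [[50.0, -100.0, 25.0], [1e-5, -3e-5, 2e-5]]
per_channel = dequantize(*quantize_per_channel(w))
per_tensor = dequantize(*quantize_per_tensor(w))
print(per_channel[1])  # close to the original small values
print(per_tensor[1])   # small channel collapses to zeros under the shared scale
```

With a single shared scale, the small-magnitude channel rounds to zero; a per-channel scale keeps each channel's quantization step proportional to its own magnitude, which is the property the abstract says FlashSchNet exploits.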