$\texttt{FlashSchNet}$: Fast and Accurate Coarse-Grained Neural Network Molecular Dynamics
Pingzhi Li ⋅ Hongxuan Li ⋅ Zirui Liu ⋅ Xingcheng Lin ⋅ Tianlong Chen
Abstract
Graph neural network (GNN) potentials such as SchNet improve the accuracy and transferability of molecular dynamics (MD) simulation by learning many-body interactions, but remain slower than classical force fields because fragmented kernels and memory-bound pipelines underutilize GPUs. We show that a missing principle is to make GNN-MD $\textit{IO-aware}$, i.e., to carefully account for reads and writes between GPU high-bandwidth memory (HBM) and on-chip SRAM. We present $\texttt{FlashSchNet}$, an efficient and accurate IO-aware SchNet-style GNN-MD framework built on four techniques: (1) $\textit{flash radial basis}$, which fuses pairwise distance computation, Gaussian basis expansion, and cosine envelope into a single tiled pass, computing each distance once and reusing it across all basis functions; (2) $\textit{flash message passing}$, which fuses cutoff, neighbor gather, filter multiplication, and reduction to avoid materializing edge tensors in HBM; (3) $\textit{flash aggregation}$, which reformulates scatter-add as a CSR segment reduction, reducing atomic writes by a factor equal to the feature dimension and enabling contention-free accumulation in both the forward and backward passes; and (4) channel-wise 16-bit quantization, which exploits the low per-channel dynamic range of SchNet MLP weights to further improve throughput with negligible accuracy loss. On a single NVIDIA RTX PRO 6000, $\texttt{FlashSchNet}$ achieves $\textbf{1000 ns/day}$ of aggregate simulation throughput across 64 parallel replicas of a coarse-grained (CG) protein with 269 beads ($\textbf{6.5}\mathbf{\times}$ faster than the CGSchNet baseline, with $\textbf{80\% less}$ peak memory), surpassing the throughput of widely used classical force fields ($\textit{e.g.}$, MARTINI) while retaining SchNet-level accuracy and transferability.
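The fusion idea behind the flash radial basis, flash message passing, and flash aggregation steps can be illustrated with a minimal, single-threaded Python sketch. It is not the authors' GPU kernel. It assumes a CSR neighbor list grouped by destination atom, the standard SchNet Gaussian expansion $e_k(d)=\exp(-\gamma (d-\mu_k)^2)$ multiplied by a cosine envelope $f_c(d)=\tfrac{1}{2}(\cos(\pi d/r_c)+1)$ for $d<r_c$ (zero otherwise), and a hypothetical linear filter `W` in place of SchNet's filter-generating MLP. All function names (`cosine_cutoff`, `radial_basis`, `linear_filter`, `edges_to_csr`, `fused_cfconv_csr`, `unfused_scatter_add`) are illustrative. The fused loop computes each distance once, reuses it across all basis functions, drops out-of-cutoff edges before the gather, and accumulates each atom's messages locally so every output row is written exactly once. The unfused reference instead materializes per-edge messages and scatter-adds them, the pattern that typically costs one atomic add per edge and feature on a GPU.

```python
import math


def cosine_cutoff(d, r_cut):
    """Cosine envelope: 1 at d = 0, smoothly 0 at d = r_cut, exactly 0 beyond."""
    if d >= r_cut:
        return 0.0
    return 0.5 * (math.cos(math.pi * d / r_cut) + 1.0)


def radial_basis(d, centers, gamma, r_cut):
    """Gaussian expansion times cosine envelope, all from one distance value d."""
    env = cosine_cutoff(d, r_cut)
    return [math.exp(-gamma * (d - mu) ** 2) * env for mu in centers]


def linear_filter(rbf, W):
    """Hypothetical stand-in for SchNet's filter MLP: rbf (K) @ W (K x F)."""
    return [sum(rbf[k] * W[k][f] for k in range(len(rbf))) for f in range(len(W[0]))]


def edges_to_csr(src, dst, n_atoms):
    """Group directed edges src -> dst by destination atom (CSR row = dst)."""
    order = sorted(range(len(dst)), key=lambda e: dst[e])
    row_ptr = [0] * (n_atoms + 1)
    for e in order:
        row_ptr[dst[e] + 1] += 1
    for i in range(n_atoms):
        row_ptr[i + 1] += row_ptr[i]  # prefix sum -> segment boundaries
    col_idx = [src[e] for e in order]
    return row_ptr, col_idx


def fused_cfconv_csr(pos, x, row_ptr, col_idx, W, centers, gamma, r_cut):
    """Fused distance -> basis -> filter -> gather -> segment reduce."""
    n_feat = len(x[0])
    out = []
    for i in range(len(row_ptr) - 1):
        acc = [0.0] * n_feat  # per-atom accumulator (registers/SRAM on a GPU)
        for e in range(row_ptr[i], row_ptr[i + 1]):
            j = col_idx[e]
            d = math.dist(pos[i], pos[j])  # each distance computed once
            if d >= r_cut:
                continue  # cutoff applied before the neighbor gather
            filt = linear_filter(radial_basis(d, centers, gamma, r_cut), W)
            for f in range(n_feat):
                acc[f] += filt[f] * x[j][f]
        out.append(acc)  # one write per output row, no contention
    return out


def unfused_scatter_add(pos, x, src, dst, W, centers, gamma, r_cut):
    """Reference: materialize every edge message, then scatter-add."""
    n_feat = len(x[0])
    messages = []  # edge-sized buffer that the fused version never stores
    for j, i in zip(src, dst):
        d = math.dist(pos[i], pos[j])
        filt = linear_filter(radial_basis(d, centers, gamma, r_cut), W)
        messages.append([filt[f] * x[j][f] for f in range(n_feat)])
    out = [[0.0] * n_feat for _ in x]
    for i, msg in zip(dst, messages):
        for f in range(n_feat):
            out[i][f] += msg[f]  # would be an atomic add per edge and feature on a GPU
    return out


# Toy system: 4 beads with 2 features; bead 3 lies beyond r_cut of all others.
pos = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0), (0.0, 1.5, 0.0), (3.0, 3.0, 3.0)]
x = [[1.0, 2.0], [0.5, -1.0], [2.0, 0.0], [1.0, 1.0]]
W = [[1.0, 0.0], [0.5, 1.0], [0.0, 0.25]]  # hypothetical filter weights
centers, gamma, r_cut = [0.0, 1.0, 2.0], 10.0, 2.5

n = len(pos)
pairs = [(j, i) for j in range(n) for i in range(n) if i != j]  # not sorted by dst
src = [j for j, _ in pairs]
dst = [i for _, i in pairs]
row_ptr, col_idx = edges_to_csr(src, dst, n)

fused = fused_cfconv_csr(pos, x, row_ptr, col_idx, W, centers, gamma, r_cut)
reference = unfused_scatter_add(pos, x, src, dst, W, centers, gamma, r_cut)
```

Both paths produce the same per-atom features; the difference is data movement, since the fused path never stores an edge-sized buffer and never has two edges write to the same output row. On a GPU, one could map the outer loop over destination atoms to thread blocks and the accumulator to registers or shared memory; that mapping is an assumption about how such a kernel might be organized, not a description of $\texttt{FlashSchNet}$'s implementation.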