msmJAX: Fast and Differentiable Electrostatics on the GPU in Python
Florian Buchner, Johannes Schörghuber, Nico Unglert, Jesús Carrete, Georg K. H. Madsen
TL;DR
msmJAX tackles the costly problem of evaluating long-range electrostatics in atomistic simulations by implementing the multilevel summation method on GPUs within the JAX framework, which makes both energies and forces differentiable. The approach splits the Coulomb kernel into level-specific partial kernels, evaluates the long-range parts on a hierarchy of grids, and computes the short-range part by direct summation, so that the theoretically expected $O(N)$ scaling is achieved in practice. Key contributions include a modular, Pythonic design with a high-level API, an implementation of B-spline interpolation, flexible handling of periodicity and cell geometry, and efficient derivative computations via custom Jacobian-vector products (JVPs). The results include accurately reproduced Madelung constants, scalable performance up to $N \approx 10^6$ particles, and stable molecular-dynamics (MD) simulations, showing msmJAX's potential as a differentiable electrostatics core for interatomic potentials interfaced with machine learning.
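To make the kernel splitting concrete, the sketch below decomposes $1/r$ into a short-range part, intermediate level kernels, and a smooth top-level kernel that telescope back to $1/r$. It assumes a particular softening function, the $C^2$-continuous even polynomial $\gamma(\rho) = 15/8 - 5\rho^2/4 + 3\rho^4/8$ for $\rho < 1$ (and $1/\rho$ otherwise), which is a common choice in the multilevel summation literature; the source does not say which softening msmJAX uses. The names `gamma`, `smoothed`, and `split_kernels` are hypothetical and are not part of msmJAX's API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def gamma(rho):
    """Softened 1/rho (assumed form): equals 1/rho for rho >= 1 and an even
    polynomial matching value, slope and curvature at rho = 1 otherwise."""
    rho_safe = jnp.where(rho >= 1.0, rho, 1.0)  # avoid 1/0 in the unused branch
    poly = 15.0 / 8.0 - 5.0 / 4.0 * rho**2 + 3.0 / 8.0 * rho**4
    return jnp.where(rho >= 1.0, 1.0 / rho_safe, poly)


def smoothed(r, a):
    """Smoothed Coulomb kernel g_a(r) = gamma(r / a) / a; equals 1/r for r >= a."""
    return gamma(r / a) / a


def split_kernels(r, a, levels):
    """Return [g_0, g_1, ..., g_L] whose sum is exactly 1/r.

    g_0 = 1/r - g_a is short-ranged (zero for r >= a) and would be summed directly.
    g_l = g_{2^(l-1) a} - g_{2^l a} for 1 <= l < L vanishes for r >= 2^l a.
    g_L = g_{2^(L-1) a} is smooth everywhere and carries the long-range tail.
    """
    parts = [1.0 / r - smoothed(r, a)]
    for l in range(1, levels):
        parts.append(smoothed(r, 2 ** (l - 1) * a) - smoothed(r, 2**l * a))
    parts.append(smoothed(r, 2 ** (levels - 1) * a))
    return parts


r = jnp.linspace(0.1, 20.0, 500)
parts = split_kernels(r, a=2.0, levels=3)
# The pieces telescope back to the bare Coulomb kernel.
assert jnp.allclose(sum(parts), 1.0 / r)
# The short-range piece vanishes beyond the cutoff a = 2.
assert jnp.allclose(parts[0][r >= 2.0], 0.0)
```

Because every kernel except $g_0$ is smooth, it can be sampled on progressively coarser grids, which is what makes grid-based interpolation (B-splines in msmJAX's case) applicable to the long-range parts.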
Abstract
We present msmJAX, a Python package that implements the multilevel summation method with B-spline interpolation. The multilevel summation method is a linear-scaling algorithm for efficiently evaluating electrostatic and other long-range interactions in particle-based simulations. Built on the JAX framework, msmJAX integrates naturally with the machine-learning methods that are transforming chemistry and materials science, while also serving as a powerful tool in its own right. It combines high performance with Python's accessibility, is easy to deploy on GPUs, and supports automatic differentiation. We outline the modular design of msmJAX, which lets users adapt or extend the code, and present benchmarks and examples, including a verification of linear scaling and demonstrations of its stability in molecular-dynamics simulations.
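As a minimal illustration of what automatic differentiation buys in this setting, the sketch below obtains forces as the negative gradient of an energy with `jax.grad`. It deliberately uses plain $O(N^2)$ direct summation for an isolated, non-periodic cluster in Gaussian-type units rather than the multilevel summation method, and the names `coulomb_energy` and `forces_fn` are hypothetical; it does not show msmJAX's actual API.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def coulomb_energy(positions, charges):
    """Direct-sum Coulomb energy 0.5 * sum_{i != j} q_i q_j / r_ij (no periodicity)."""
    n = positions.shape[0]
    diff = positions[:, None, :] - positions[None, :, :]
    eye = jnp.eye(n, dtype=bool)
    r2 = jnp.sum(diff**2, axis=-1)
    # Replace the zero self-distances before the sqrt so the gradient stays finite;
    # the self-pair terms are masked out of the energy below anyway.
    r = jnp.sqrt(jnp.where(eye, 1.0, r2))
    pair = jnp.where(eye, 0.0, charges[:, None] * charges[None, :] / r)
    return 0.5 * jnp.sum(pair)


# Forces are minus the gradient of the energy with respect to the positions.
forces_fn = jax.jit(lambda x, q: -jax.grad(coulomb_energy)(x, q))

x = jnp.array([[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]])
q = jnp.array([1.0, -1.0])
print(coulomb_energy(x, q))  # -0.5
print(forces_fn(x, q))       # attractive: +0.25 on particle 0, -0.25 on particle 1 (x components)
```

The same pattern applies unchanged when the energy function is replaced by a differentiable long-range solver, which is the property that makes such a solver usable inside gradient-based training of machine-learning potentials.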
