msmJAX: Fast and Differentiable Electrostatics on the GPU in Python

Florian Buchner, Johannes Schörghuber, Nico Unglert, Jesús Carrete, Georg K. H. Madsen

TL;DR

msmJAX tackles the costly problem of evaluating long-range electrostatics in atomistic simulations by implementing the multilevel summation method (MSM) on GPUs within the JAX framework, which makes energies and forces differentiable. The approach splits the Coulomb kernel into level-specific partial kernels, evaluates the shortest-range part by direct summation and the remaining parts on grids, and achieves the theoretical $O(N)$ scaling in practice. Key contributions include a modular, Pythonic design with a high-level API, a B-spline interpolation implementation, flexible handling of periodicity and cell geometry, and optimized derivative computations via custom JVPs. The results demonstrate accurate Madelung constants, scalable performance up to $N \approx 10^6$ particles, and stable MD simulations, showing msmJAX’s potential as a differentiable electrostatics core for ML-interfaced interatomic potentials.
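The kernel splitting described above can be illustrated with a short NumPy sketch. It assumes the classic MSM construction: a C²-continuous polynomial softening $\gamma(\rho) = 15/8 - (5/4)\rho^2 + (3/8)\rho^4$ for $\rho < 1$ (and exactly $1/\rho$ beyond), with cutoffs that double from level to level ($a$, $2a$, $4a$, and so on). msmJAX's actual softening functions and level conventions may differ, and the helper names `gamma`, `smoothed`, and `partial_kernels` are hypothetical. The usage lines at the end check that the partial kernels add back up to exactly $1/r$.

```python
import numpy as np


def gamma(rho):
    """Assumed C2 softening of 1/rho: a polynomial for rho < 1, exactly 1/rho for rho >= 1."""
    rho = np.asarray(rho, dtype=float)
    safe = np.where(rho < 1.0, 1.0, rho)  # keeps the unused 1/rho branch finite near 0
    poly = 15.0 / 8.0 - 1.25 * rho**2 + 0.375 * rho**4
    return np.where(rho < 1.0, poly, 1.0 / safe)


def smoothed(r, a):
    """g_a(r) = gamma(r / a) / a: equal to 1/r for r >= a, bounded and smooth below a."""
    return gamma(np.asarray(r, dtype=float) / a) / a


def partial_kernels(r, a, num_levels):
    """Split 1/r into partial kernels k_0, ..., k_L (L = num_levels >= 1).

    k_0 = 1/r - g_a is short-ranged (zero for r >= a); for 1 <= l < L,
    k_l = g_{2^(l-1) a} - g_{2^l a} is zero for r >= 2^l a; the top-level
    kernel k_L = g_{2^(L-1) a} carries the smoothest, longest-range remainder.
    """
    r = np.asarray(r, dtype=float)
    kernels = [1.0 / r - smoothed(r, a)]
    for level in range(1, num_levels):
        kernels.append(smoothed(r, 2 ** (level - 1) * a) - smoothed(r, 2**level * a))
    kernels.append(smoothed(r, 2 ** (num_levels - 1) * a))
    return kernels


# Usage: the partial kernels telescope back to the full Coulomb kernel.
r = np.linspace(0.1, 20.0, 500)
kernels = partial_kernels(r, a=1.0, num_levels=4)
assert np.allclose(sum(kernels), 1.0 / r)
```

Because each intermediate kernel is a difference of two smoothed functions that both equal $1/r$ beyond the larger cutoff, $k_l$ has compact support, which is what allows the lower levels to be handled with finite stencils while only the top level spans the whole system.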

Abstract

We present msmJAX, a Python package implementing the multilevel summation method with B-spline interpolation, a linear-scaling algorithm for efficiently evaluating electrostatic and other long-range interactions in particle-based simulations. Built on the JAX framework, msmJAX integrates naturally with the machine-learning methods that are transforming chemistry and materials science, while also serving as a powerful tool in its own right. It combines high performance with Python's accessibility, offers easy deployment on GPUs, and supports automatic differentiation. We outline the modular design of msmJAX, which enables users to adapt or extend the code, and present benchmarks and examples, including a verification of linear scaling and demonstrations of its stability in molecular-dynamics simulations.
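With automatic differentiation, forces need not be coded by hand: they follow as $-\nabla U$ from the energy. The JAX sketch below illustrates this on a plain all-pairs sum (no grids) over a partial-kernel decomposition of $1/r$, and shows how `jax.custom_jvp` can supply a hand-written derivative for a piecewise softening function instead of letting JAX trace derivatives through both branches of `jnp.where`. The softening polynomial, the hypothetical setup function `make_energy_fn` (loosely in the spirit of the setup-function design described for Figure 4), and its parameters are illustrative assumptions, not msmJAX's API. Because the partial kernels sum to exactly $1/r$, the resulting forces must coincide with plain Coulomb forces.

```python
import jax
import jax.numpy as jnp
import numpy as np

jax.config.update("jax_enable_x64", True)  # float64 so that comparisons can be tight


@jax.custom_jvp
def gamma(rho):
    """Assumed C2 softening of 1/rho: polynomial for rho < 1, exactly 1/rho otherwise."""
    safe = jnp.where(rho < 1.0, 1.0, rho)
    return jnp.where(rho < 1.0, 15 / 8 - 1.25 * rho**2 + 0.375 * rho**4, 1.0 / safe)


@gamma.defjvp
def gamma_jvp(primals, tangents):
    (rho,), (rho_dot,) = primals, tangents
    safe = jnp.where(rho < 1.0, 1.0, rho)
    # Hand-written derivative of both pieces; only the selected one is used.
    dgamma = jnp.where(rho < 1.0, -2.5 * rho + 1.5 * rho**3, -1.0 / safe**2)
    return gamma(rho), dgamma * rho_dot


def make_energy_fn(a, num_levels):
    """Hypothetical setup function: returns a jitted all-pairs energy U(positions, charges)."""

    def kernel(r):
        def g(s):  # smoothed kernel g_s(r) = gamma(r / s) / s
            return gamma(r / s) / s

        total = 1.0 / r - g(a)  # k_0 (short range)
        for level in range(1, num_levels):
            total = total + g(2 ** (level - 1) * a) - g(2**level * a)  # k_l
        return total + g(2 ** (num_levels - 1) * a)  # k_L (smoothest)

    @jax.jit
    def energy(positions, charges):
        i, j = np.triu_indices(positions.shape[0], k=1)  # each pair i < j once
        r = jnp.linalg.norm(positions[i] - positions[j], axis=-1)
        return jnp.sum(charges[i] * charges[j] * kernel(r))

    return energy


# Usage: forces are the negative gradient of the energy with respect to positions.
positions = jnp.array([[0.0, 0.0, 0.0], [0.5, 0.2, 0.1], [3.0, -1.0, 0.5]])
charges = jnp.array([1.0, -1.0, 0.5])
energy = make_energy_fn(a=1.0, num_levels=3)
forces = -jax.grad(energy)(positions, charges)
```

The setup/evaluation split keeps static choices (cutoff, number of levels) out of the traced function, so `jax.jit` compiles one specialized kernel per configuration.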

Paper Structure

This paper contains 19 sections, 23 equations, 11 figures, and 1 algorithm.

Figures (11)

  • Figure 1: Basic principles of the multilevel summation method. Top row: splitting of a long-range potential (the Coulomb potential $1/r$ in this example) into a sum of partial kernels with increasing cutoffs that also become progressively smoother. Bottom row: schematic evaluation in both cases. Whereas exact evaluation of the original long-range potential (left of the approximately-equal sign) requires evaluating all pairwise interactions, the MSM (right of the approximately-equal sign) evaluates only the shortest-range kernel $k_0(r)$ directly and approximates the others from grids with increasing spacing. Levels increase from left to right; red and blue points symbolize particles of opposite charge.
  • Figure 2: Abstract schematic of algorithm steps. The separation into short-range (evaluated directly) and long-range parts (approximated using grids) is highlighted, along with the corresponding modules of msmJAX. Rectangular boxes on the bottom left and right represent inputs and outputs, respectively, while round nodes in the upper part represent intermediate results defined on grids at different levels. The tapering of the "ladder" towards the top symbolizes the coarsening of subsequent grids. Arrows indicate the application of linear operators, and where two arrows end at the same node, their results are added together.
  • Figure 3: Package structure of msmJAX. The larger boxes with rounded edges represent either top-level module files (.py extension) or subpackages (no extension). Modules contained in these subpackages are in turn represented by the smaller rectangular boxes of the same color below them. The arrangement of components from bottom to top reflects a conceptual hierarchy of increasingly higher-level functionalities.
  • Figure 4: Simplified schematic of how the setup functions make_compute_u_zero and make_compute_u_oneplus provided in msmjax.core construct evaluation functions for $U^0$ and $U^{1+}$. The color-filled elements are the inputs to the setup functions; the modularity of the design lies in the choice of these inputs. Ellipses represent data and rectangles represent functions, with their inputs and outputs represented by incoming and outgoing arrows. In particular, the outermost enclosing rectangles correspond to the end-to-end evaluation functions that are constructed. Blocks that slot into other blocks like jigsaw-puzzle pieces indicate that one function internally serves as an input to the construction of another, more complex one.
  • Figure 5: B-spline basis functions at two successive grid levels $l$ (bottom) and $l+1$ (top), for different interpolation orders $p$.
  • ...and 6 more figures
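Figure 5 shows B-spline basis functions on two successive grid levels. A standard property that multilevel schemes exploit is that, when the coarse spacing is exactly twice the fine one, every coarse-level B-spline is a fixed linear combination of fine-level ones (the two-scale relation), with weights $2^{1-p}\binom{p}{j+p/2}$ for $j = -p/2, \dots, p/2$. Charges can therefore be coarsened grid-to-grid, presumably the kind of linear operator drawn as arrows between levels in Figure 2. The one-dimensional NumPy sketch below assumes centered cardinal B-splines of even order $p$ (polynomial degree $p-1$); the paper's definition of $p$ and msmJAX's actual operators may differ, and all function names are hypothetical. It checks that spreading charges directly onto the coarse grid gives the same result as spreading them onto the fine grid and then restricting.

```python
from math import comb

import numpy as np


def cardinal_bspline(x, p):
    """Cardinal B-spline M_p of order p (degree p - 1) on [0, p], via the standard recurrence."""
    x = np.asarray(x, dtype=float)
    if p == 1:
        return np.where((x >= 0.0) & (x < 1.0), 1.0, 0.0)
    return (x * cardinal_bspline(x, p - 1) + (p - x) * cardinal_bspline(x - 1.0, p - 1)) / (p - 1)


def centered_bspline(x, p):
    """B-spline centered on a grid node; even p keeps all knots on grid nodes."""
    return cardinal_bspline(np.asarray(x, dtype=float) + p / 2, p)


def two_scale_coefficients(p):
    """Weights c_j with B(x / 2) = sum_j c_j B(x - j), j = -p/2, ..., p/2 (even p only)."""
    if p % 2:
        raise ValueError("this sketch assumes an even order p")
    half = p // 2
    return {j: comb(p, j + half) / 2 ** (p - 1) for j in range(-half, half + 1)}


def anterpolate(x, q, spacing, nodes, p):
    """Spread point charges onto grid nodes: Q[n] = sum_i q_i B(x_i / spacing - n)."""
    return centered_bspline(x[None, :] / spacing - nodes[:, None], p) @ q


def restrict(q_fine, fine_nodes, coarse_nodes, p):
    """Coarsen grid charges grid-to-grid: Q_c[m] = sum_j c_j Q_f[2 m + j].

    Assumes fine_nodes is a contiguous integer range wide enough to cover 2 m + j.
    """
    offset = fine_nodes[0]
    q_coarse = np.zeros(len(coarse_nodes))
    for j, c in two_scale_coefficients(p).items():
        q_coarse += c * q_fine[2 * coarse_nodes + j - offset]
    return q_coarse


# Usage: spreading directly onto the coarse grid (spacing 2h) matches
# spreading onto the fine grid (spacing h) followed by restriction.
rng = np.random.default_rng(0)
x = rng.uniform(2.0, 8.0, size=50)
q = rng.choice([-1.0, 1.0], size=50)
fine_nodes, coarse_nodes = np.arange(-10, 41), np.arange(-2, 13)
q_fine = anterpolate(x, q, 0.5, fine_nodes, p=4)
q_coarse = anterpolate(x, q, 1.0, coarse_nodes, p=4)
assert np.allclose(restrict(q_fine, fine_nodes, coarse_nodes, p=4), q_coarse)
```

In three dimensions the same weights would apply along each axis of a tensor-product grid, so restriction costs a fixed number of operations per grid point regardless of the particle count.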