jaxFMM: An Adaptive, GPU-Parallel Implementation of the Fast Multipole Method in JAX
Robert Kraft, Florian Bruckner, Dieter Suess, Claas Abert
TL;DR
The paper addresses the fast computation of N-body potentials for the Laplace kernel, including highly non-uniform charge distributions. It presents jaxFMM, an extremely concise adaptive FMM implemented in JAX that combines non-uniform hierarchy generation with GPU parallelism obtained through just-in-time compilation. Key contributions include a compact implementation (about 600 lines of code), support for automatic differentiation (autodiff), and robust performance across both uniform and non-uniform distributions, demonstrated through benchmarks and code examples. This work enables large-scale micromagnetic stray-field evaluations and opens avenues for differentiable, inverse-design, and machine-learning–related applications. Planned improvements include on-the-fly computation of rotation matrices and multi-GPU support.
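The summary above does not spell out the refinement rule, so the sketch below assumes one plausible non-uniform scheme: every box is halved at the median coordinate along x, then y, then z, so each of its 8 children holds the same number of points regardless of how clustered the charges are. This is an illustration under that assumption, not jaxFMM's code; the function name `median_split_hierarchy` is hypothetical.

```python
import jax
import jax.numpy as jnp


def median_split_hierarchy(points, levels):
    """Hypothetical sketch of a non-uniform (equal-population) box hierarchy.

    At every level each box is halved at the median of x, then y, then z,
    giving 8 children per box that each hold the same number of points.
    Assumes points.shape == (n, 3) with n divisible by 8**levels.
    Returns leaf index sets of shape (8**levels, n // 8**levels).
    """
    n = points.shape[0]
    if n % (8 ** levels) != 0:
        raise ValueError("n must be divisible by 8**levels")
    idx = jnp.arange(n).reshape(1, n)  # root box holds every point
    for _ in range(levels):
        for axis in range(3):
            coords = points[idx, axis]           # (n_boxes, m) coordinates per box
            order = jnp.argsort(coords, axis=1)  # sort points within each box
            idx = jnp.take_along_axis(idx, order, axis=1)
            n_boxes, m = idx.shape
            # lower half of each sorted row -> child 2k, upper half -> child 2k+1
            idx = idx.reshape(2 * n_boxes, m // 2)
    return idx


key = jax.random.PRNGKey(0)
pts = jax.random.normal(key, (512, 3))  # clustered (Gaussian) point cloud
leaves = median_split_hierarchy(pts, levels=2)
print(leaves.shape)  # (64, 8)
```

A side effect of this equal-population split is that every level has a fixed, known array shape, which suits JAX's just-in-time compilation, since `jax.jit` traces functions with static array shapes.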
Abstract
We introduce jaxFMM, an open-source, adaptive, highly parallel point-charge Fast Multipole Method implementation for the Laplace kernel written in JAX. It is based on a non-uniform refinement strategy, which results in extremely concise and simple code. Benchmarks show that the algorithm performs well even for highly non-uniform charge distributions. In micromagnetics, jaxFMM already massively speeds up stray-field computations, and JAX features such as autodiff will make it easy to tackle novel applications, for example inverse-design problems and machine-learning tasks, in the future.
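To make concrete what an FMM for the Laplace kernel approximates, and what autodiff adds, here is a minimal O(N²) direct-summation reference written in JAX. It computes the potential φᵢ = Σ_{j≠i} qⱼ / |xᵢ − xⱼ| with the physical prefactor (e.g. 1/(4πε₀)) omitted; an FMM replaces this quadratic sum with an approximately linear-cost hierarchical evaluation. The names `direct_potential` and `energy` are hypothetical and this is not the jaxFMM API. Because everything is expressed in `jax.numpy`, `jax.grad` of the electrostatic energy returns its gradient with respect to the charge positions, which is the kind of derivative the abstract has in mind for inverse-design problems.

```python
import jax
import jax.numpy as jnp


def direct_potential(x, q):
    """O(N^2) Laplace potential phi_i = sum_{j != i} q_j / |x_i - x_j| (prefactor omitted)."""
    n = x.shape[0]
    diff = x[:, None, :] - x[None, :, :]  # (N, N, 3) pairwise offsets
    r2 = jnp.sum(diff ** 2, axis=-1)      # squared distances
    self_pair = jnp.eye(n, dtype=bool)
    # Replace the zero self-distance before sqrt so values and gradients stay finite.
    r = jnp.sqrt(jnp.where(self_pair, 1.0, r2))
    # Exclude self-interaction by zeroing the diagonal contributions.
    return jnp.sum(jnp.where(self_pair, 0.0, q[None, :] / r), axis=1)


def energy(x, q):
    """Electrostatic energy E = 1/2 * sum_i q_i phi_i."""
    return 0.5 * jnp.dot(q, direct_potential(x, q))


x = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
q = jnp.array([1.0, 1.0])
print(direct_potential(x, q))  # [1. 1.]
grad_x = jax.grad(energy)(x, q)  # dE/dx, approx [[1, 0, 0], [-1, 0, 0]]
```

For two unit charges one unit apart, E = 1/|x₁ − x₂|, so the gradient with respect to x₁ is (1, 0, 0) and the corresponding force −∂E/∂x₁ pushes the charges apart, as expected for like charges.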
