
jaxFMM: An Adaptive, GPU-Parallel Implementation of the Fast Multipole Method in JAX

Robert Kraft, Florian Bruckner, Dieter Suess, Claas Abert

TL;DR

The paper addresses the fast computation of N-body potentials for the Laplace kernel, including highly non-uniform charge distributions. It presents jaxFMM, an extremely concise adaptive FMM implemented in JAX that leverages non-uniform hierarchy generation and GPU parallelism through just-in-time compilation. Key contributions include a compact implementation of about 600 lines, support for autodiff, and robust performance across uniform and non-uniform distributions, demonstrated via benchmarks and code examples. This work enables large-scale micromagnetic stray-field evaluations and opens avenues for differentiable applications such as inverse design and machine learning; planned improvements include on-the-fly rotation matrices and multi-GPU support.
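The quantity an FMM accelerates is the pairwise Laplace sum. As a point of reference, here is a minimal direct-summation sketch in plain Python, chosen only for illustration (jaxFMM itself is written in JAX). It assumes the kernel 1/r with the 1/(4π) prefactor omitted and excludes self-interaction. Its O(N²) cost is what a fast multipole method is designed to avoid.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def direct_potential(points: Sequence[Point], charges: Sequence[float]) -> List[float]:
    """O(N^2) reference: phi_i = sum_{j != i} q_j / |x_i - x_j|.

    Assumption: unnormalized 1/r kernel (no 1/(4*pi) prefactor).
    """
    n = len(points)
    phi = [0.0] * n
    for i in range(n):
        xi, yi, zi = points[i]
        acc = 0.0
        for j in range(n):
            if j == i:
                continue  # skip the singular self-interaction
            dx = xi - points[j][0]
            dy = yi - points[j][1]
            dz = zi - points[j][2]
            acc += charges[j] / math.sqrt(dx * dx + dy * dy + dz * dz)
        phi[i] = acc
    return phi
```

For example, two unit charges a distance 2 apart each see a potential of 0.5 under this convention.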

Abstract

We introduce jaxFMM, an open-source, adaptive, highly parallel point-charge Fast Multipole Method implementation for the Laplace kernel written in JAX. It is based on a non-uniform refinement strategy, which results in extremely concise and simple code. Benchmarks show that the algorithm performs well even for highly non-uniform charge distributions. In micromagnetics, jaxFMM already speeds up stray-field computations massively, and with JAX features such as autodiff, novel applications like inverse-design problems and machine-learning tasks can be tackled with ease in the future.
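A non-uniform refinement can be as simple as repeatedly bisecting each box at the median of its points. The sketch below assumes that scheme, with one coordinate axis per split so that s splits per box yield 2^s children (compare the s = 3 hierarchy in Figure 2). The function names are hypothetical, and this section does not state whether jaxFMM chooses its split planes this way.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def split_box(idx: List[int], points: Sequence[Point], s: int, axis0: int = 0) -> List[List[int]]:
    """Hypothetical: split one box into 2**s children by s successive median bisections.

    Split k acts along axis (axis0 + k) % 3. Each bisection sorts the point indices
    by that coordinate and cuts at the middle, so sibling counts differ by at most one.
    """
    boxes = [idx]
    for k in range(s):
        axis = (axis0 + k) % 3
        new_boxes = []
        for box in boxes:
            ordered = sorted(box, key=lambda i: points[i][axis])
            half = len(ordered) // 2
            new_boxes.append(ordered[:half])
            new_boxes.append(ordered[half:])
        boxes = new_boxes
    return boxes


def build_levels(points: Sequence[Point], s: int, n_levels: int) -> List[List[List[int]]]:
    """Hypothetical helper: point-index lists for every box on levels 0..n_levels-1."""
    levels = [[list(range(len(points)))]]  # level 0: one root box holding all points
    for _ in range(1, n_levels):
        children: List[List[int]] = []
        for box in levels[-1]:
            children.extend(split_box(box, points, s))
        levels.append(children)
    return levels
```

Because sibling boxes differ in point count by at most one, however clustered the input, each level could be stored in fixed-shape arrays, a property that suits JIT compilation. This is a plausible reason such a scheme stays concise, not a description of jaxFMM's internals.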


Paper Structure

This paper contains 19 sections, 26 equations, 13 figures, and 1 table.

Figures (13)

  • Figure 1: Schematic 2D-FMM setup: combining multipoles (blue) into the parent via M2M, adding them to locals (orange) via M2L, and distributing locals to children via L2L.
  • Figure 2: From left to right: Double helix point distribution, jaxFMM hierarchy levels $0-3$ ($s=3$ splits per box).
  • Figure 3: Same-level (top) vs cross-level (bottom) interaction lists for a given box (red): strong coupling (orange), weak coupling at decreasing target level $l_\text{trg}$ (white to dark blue).
  • Figure 4: Level 4 of the uniform cube hierarchy (top) and source points (bottom) for $N=2^{20}$.
  • Figure 5: Uniform cube performance on CPU (dashed lines) and GPU (solid lines).
  • ...and 8 more figures
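The operators named in Figure 1 can be illustrated at the lowest nontrivial order. The following toy sketch is hypothetical and deliberately first order: a multipole is a monopole M0 plus a dipole D about a box center, and a local expansion is a value plus a gradient. It does not reproduce jaxFMM's actual expansions or expansion order. At this order M2M is exact. M2L and the final evaluation truncate a Taylor series, so far-field results carry a relative error that shrinks roughly like (box size / distance)².

```python
import math
from typing import Sequence, Tuple

Vec = Tuple[float, float, float]


def sub(a: Vec, b: Vec) -> Vec:
    return (a[0] - b[0], a[1] - b[1], a[2] - b[2])


def dot(a: Vec, b: Vec) -> float:
    return a[0] * b[0] + a[1] * b[1] + a[2] * b[2]


def p2m(points: Sequence[Vec], charges: Sequence[float], c: Vec) -> Tuple[float, Vec]:
    """Particles to multipole: monopole M0 = sum q_j and dipole D = sum q_j (x_j - c)."""
    m0 = sum(charges)
    d = [0.0, 0.0, 0.0]
    for p, q in zip(points, charges):
        r = sub(p, c)
        for k in range(3):
            d[k] += q * r[k]
    return m0, (d[0], d[1], d[2])


def m2m(m0: float, d: Vec, c_child: Vec, c_parent: Vec) -> Tuple[float, Vec]:
    """Shift a child multipole to the parent center: D' = D + M0 (c_child - c_parent)."""
    s = sub(c_child, c_parent)
    return m0, (d[0] + m0 * s[0], d[1] + m0 * s[1], d[2] + m0 * s[2])


def m2l(m0: float, d: Vec, c_src: Vec, c_trg: Vec) -> Tuple[float, Vec]:
    """Multipole to local: value and gradient of M0/r + D.R/r^3 at the target center."""
    r = sub(c_trg, c_src)
    rn = math.sqrt(dot(r, r))
    dr = dot(d, r)
    l0 = m0 / rn + dr / rn**3
    g = tuple(-m0 * r[k] / rn**3 + d[k] / rn**3 - 3.0 * dr * r[k] / rn**5 for k in range(3))
    return l0, g


def l2l(l0: float, g: Vec, c_parent: Vec, c_child: Vec) -> Tuple[float, Vec]:
    """Re-center a first-order local expansion on a child box center."""
    return l0 + dot(g, sub(c_child, c_parent)), g


def eval_local(l0: float, g: Vec, c: Vec, x: Vec) -> float:
    """Local to particle: evaluate L0 + G.(x - c)."""
    return l0 + dot(g, sub(x, c))
```

For a few charges near the origin and a target about 10 units away, passing the expansions through p2m, m2m, m2l, l2l and eval_local agrees with direct summation to well under one percent.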