JAX-in-Cell: A Differentiable Particle-in-Cell Code for Plasma Physics Applications

Longyu Ma, Rogerio Jorge, Hongke Lu, Aaron Tran, Christopher Woolford

TL;DR

JAX-in-Cell presents a differentiable, fully electromagnetic 1D3V particle-in-cell (PIC) code implemented in the JAX framework. It unifies an explicit Boris push and an implicit Crank-Nicolson scheme in a single Python-based workflow that runs on CPUs, GPUs, and TPUs. It solves the Vlasov-Maxwell system on a staggered Yee lattice, enforces charge conservation via divergence cleaning, and provides end-to-end gradients through automatic differentiation for optimization tasks. The authors validate the code against Landau damping, the two-stream instability, the Weibel instability, and bump-on-tail dynamics, demonstrating accurate growth rates and energy conservation, and showcase autodiff-enabled optimization of beam parameters in six iterations. The work offers a practical, open-source platform that bridges educational scripts and production PIC codes, enabling gradient-based plasma optimization and AI integration on hardware accelerators.
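The end-to-end gradient workflow can be illustrated with a minimal JAX sketch. It is not JAX-in-Cell's API or the paper's beam optimization: the functions `simulate` and `loss`, the prescribed field $E(x) = E_0 \sin(kx)$, the target position, and all parameter values are hypothetical. A beam is pushed through the field with a kick-drift loop inside `jax.lax.scan`, `jax.grad` differentiates the beam's mean final position with respect to its initial drift velocity, and plain gradient descent tunes that velocity.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for accurate gradients

# Hypothetical toy parameters in normalized units (not taken from JAX-in-Cell).
QM = -1.0            # charge-to-mass ratio of the beam particles
E0, K = 0.01, 0.5    # prescribed electrostatic field E(x) = E0 * sin(K * x)
DT, NSTEPS = 0.1, 200


def efield(x):
    return E0 * jnp.sin(K * x)


def simulate(v0, x0):
    """Push a beam with drift velocity v0 through the fixed field; return final positions."""
    def step(state, _):
        x, v = state
        v = v + QM * efield(x) * DT  # kick: update velocity from the field at the current position
        x = x + v * DT               # drift: update position with the new velocity
        return (x, v), None

    v_init = v0 * jnp.ones_like(x0)
    (x_final, _), _ = jax.lax.scan(step, (x0, v_init), None, length=NSTEPS)
    return x_final


def loss(v0, x0, target):
    """Squared distance between the beam's mean final position and a target position."""
    return (jnp.mean(simulate(v0, x0)) - target) ** 2


grad_loss = jax.jit(jax.grad(loss))  # derivative with respect to the first argument, v0

x0 = jnp.linspace(0.0, 1.0, 64)  # initial particle positions
target = 15.0
v0 = 0.5
initial_loss = loss(v0, x0, target)
for _ in range(40):
    v0 = v0 - 1e-3 * grad_loss(v0, x0, target)  # plain gradient descent on the drift velocity
final_loss = loss(v0, x0, target)
print(float(initial_loss), float(final_loss), float(v0))
```

Because the whole time loop is an ordinary traced JAX function, the same pattern applies in principle to any differentiable diagnostic; in a self-consistent PIC code the fields would be recomputed from the particles at every step rather than prescribed as here.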

Abstract

JAX-in-Cell is a fully electromagnetic, multispecies, and relativistic 1D3V Particle-in-Cell (PIC) framework implemented entirely in JAX. It provides a modern, Python-based alternative to traditional PIC frameworks and leverages just-in-time (JIT) compilation and automatic vectorization to achieve the performance of traditional compiled codes on CPUs, GPUs, and TPUs. The resulting framework bridges the gap between educational scripts and production codes, providing a testbed for differentiable physics and AI integration that enables end-to-end gradient-based optimization. The code solves the Vlasov-Maxwell system on a staggered Yee lattice with periodic, reflective, or absorbing boundary conditions, and offers both an explicit Boris solver and an implicit Crank-Nicolson method, solved via Picard iteration, that ensures energy conservation. Here, we detail the numerical methods employed, validate the code against standard benchmarks, and showcase its automatic-differentiation capabilities.
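The explicit and implicit particle updates named above can be compared in a minimal, non-relativistic JAX sketch for a single particle in fixed, uniform fields $\mathbf{E}$ and $\mathbf{B}$ (normalized units, charge-to-mass ratio `qm`). The function names `boris_velocity`, `crank_nicolson_velocity`, and `gyrate`, and the fixed Picard iteration count, are assumptions for illustration; JAX-in-Cell's relativistic pusher and its coupling to the field solver are not reproduced here.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def boris_velocity(v, E, B, qm, dt):
    """Explicit Boris update (non-relativistic): half electric kick, magnetic rotation, half kick."""
    v_minus = v + 0.5 * qm * dt * E             # first half acceleration by E
    t = 0.5 * qm * dt * B                       # rotation vector
    s = 2.0 * t / (1.0 + jnp.dot(t, t))
    v_prime = v_minus + jnp.cross(v_minus, t)
    v_plus = v_minus + jnp.cross(v_prime, s)    # rotation about B, preserves |v_minus|
    return v_plus + 0.5 * qm * dt * E           # second half acceleration by E


def crank_nicolson_velocity(v, E, B, qm, dt, n_picard=30):
    """Implicit Crank-Nicolson update v_new = v + qm*dt*(E + v_bar x B), v_bar = (v + v_new)/2,
    solved by a fixed number of Picard (fixed-point) iterations starting from v_new = v."""
    def picard(_, v_new):
        v_bar = 0.5 * (v + v_new)
        return v + qm * dt * (E + jnp.cross(v_bar, B))

    # The fixed-point map is a contraction when |qm * dt * B| / 2 < 1.
    return jax.lax.fori_loop(0, n_picard, picard, v)


def gyrate(update, v, E, B, qm=1.0, dt=0.1, nsteps=1000):
    """Apply a velocity update nsteps times in fixed fields; return the final velocity."""
    def step(v, _):
        return update(v, E, B, qm, dt), None
    v_final, _ = jax.lax.scan(step, v, None, length=nsteps)
    return v_final


v0 = jnp.array([1.0, 0.0, 0.0])
B = jnp.array([0.0, 0.0, 1.0])
E_zero = jnp.zeros(3)
for name, update in [("Boris", boris_velocity), ("Crank-Nicolson", crank_nicolson_velocity)]:
    v_final = gyrate(update, v0, E_zero, B)
    # In a pure magnetic field both schemes should preserve the speed (kinetic energy).
    print(name, float(jnp.linalg.norm(v_final)))
```

With the fields held fixed over a step, the two updates coincide algebraically, since the Boris rotation solves $\mathbf{v}^+ - \mathbf{v}^- = (\mathbf{v}^+ + \mathbf{v}^-)\times\mathbf{t}$ exactly. In an implicit PIC scheme the difference is that the fields are typically advanced self-consistently with the particles inside the Picard loop, which is what makes the discrete total energy conserved; this toy keeps $\mathbf{E}$ and $\mathbf{B}$ prescribed.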

Paper Structure

This paper contains 4 sections, 6 equations, 6 figures.

Figures (6)

  • Figure 1: Time-stepping algorithms in JAX-in-Cell. Left: explicit Boris time-stepper and a Finite-Difference Time-Domain (FDTD) method using a staggered Yee grid for the electromagnetic fields. Right: implicit Crank-Nicolson time stepper using a Picard iteration for the electromagnetic system.
  • Figure 2: Electric field energy evolution for Landau damping and the two-stream instability. (a) Landau damping with the analytical damping rate $\gamma = 0.153\,\omega_{pe}$. (b) Two-stream instability showing the fitted exponential growth rate. (c, d) Relative total energy deviation $|E_{\text{total}} - E_{\text{total}}(0)| / E_{\text{total}}(0)$, demonstrating energy conservation.
  • Figure 3: Weibel instability. (a) Evolution of the magnetic field energy and the relative energy error of the simulation during the Weibel instability. (b) Spatial profile of the magnetic field $B_y$.
  • Figure 4: Simulation of the bump-on-tail instability. The number of pseudo-particles in the bulk and beam populations is equal, with a beam-to-bulk weight ratio of $3\times10^{-2}$. (a) Time evolution of the electric field energy and relative energy error. (b) Snapshot of phase space at $t = 80\,\omega_{pe}^{-1}$.
  • Figure 5: (a) Comparison of total runtime between CPU and GPU. (b) Influence of the number of pseudo-particles on the two-stream instability results; growth rates are computed from exponential fits.
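The growth rates in Figures 2 and 5 come from exponential fits, and Figure 2 (c, d) plots a relative total energy deviation. A minimal JAX sketch of both diagnostics follows; the function names `fit_growth_rate` and `relative_energy_error`, the fitting window, and the synthetic data are assumptions for illustration, not the paper's post-processing.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def fit_growth_rate(t, field_energy, t_start, t_end):
    """Fit ln(W) = 2*gamma*t + c over the window [t_start, t_end] and return gamma.

    The field energy W scales as |E|^2 ~ exp(2*gamma*t) in the linear phase,
    so the fitted slope of ln(W) is twice the growth rate of the field amplitude.
    """
    mask = (t >= t_start) & (t <= t_end)  # select the linear-phase window (eager boolean indexing)
    slope, _ = jnp.polyfit(t[mask], jnp.log(field_energy[mask]), 1)
    return 0.5 * slope


def relative_energy_error(total_energy):
    """|E_total(t) - E_total(0)| / E_total(0) for a time series of total energy."""
    return jnp.abs(total_energy - total_energy[0]) / total_energy[0]


# Synthetic data standing in for simulation diagnostics (not results from the paper).
t = jnp.linspace(0.0, 40.0, 401)
gamma_true = 0.35
W = 1e-8 * jnp.exp(2.0 * gamma_true * t)
W = jnp.minimum(W, 1e-2)  # crude saturation to mimic the nonlinear phase
print(float(fit_growth_rate(t, W, 5.0, 15.0)))
```

The fitting window must lie inside the linear phase, before saturation. For Landau damping, where the field energy oscillates while decaying, the same log-linear fit is usually applied to the successive peaks of $W$, giving a negative $\gamma$.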