microJAX: A Differentiable Framework for Microlensing Modeling with GPU-Accelerated Image-Centered Ray Shooting

Shota Miyazaki, Hajime Kawahara

TL;DR

microJAX delivers the first fully differentiable, GPU-accelerated implementation of image-centered ray shooting (ICRS) for microlensing, enabling gradient-based inference over multi-lens configurations with extended sources. It combines the Ehrlich-Aberth root-finder, a static, GPU-friendly image-centered grid construction, and differentiable integration rules to compute accurate magnifications that are differentiable with respect to all model parameters. The approach attains a speed-up of up to a factor of $\sim$5-6 in the small-source, limb-darkened regime on an NVIDIA A100 GPU and integrates naturally with probabilistic programming for Hamiltonian Monte Carlo (HMC) and variational inference. This makes microJAX a foundation for precise, large-scale microlensing inference in upcoming surveys such as the Nancy Grace Roman Space Telescope.
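To make the root-finding step concrete, here is a minimal sketch of an Ehrlich-Aberth iteration in JAX. It is not microJAX's implementation: the function name `ehrlich_aberth_roots`, the fixed iteration count, and the choice of initial guesses (a circle at the Cauchy root bound) are assumptions made for illustration. The fixed-length `fori_loop` mirrors the static, GPU-friendly style the summary describes.

```python
import jax
import jax.numpy as jnp

# Double precision keeps the converged roots accurate (illustrative choice).
jax.config.update("jax_enable_x64", True)


def ehrlich_aberth_roots(coeffs, n_iter=100):
    """Approximate all roots of a polynomial (hypothetical helper).

    coeffs: coefficients ordered from highest degree to constant term.
    """
    coeffs = jnp.asarray(coeffs, dtype=jnp.complex128)
    n = coeffs.shape[0] - 1                        # polynomial degree
    dcoeffs = coeffs[:-1] * jnp.arange(n, 0, -1)   # derivative coefficients

    # Initial guesses on a circle at the Cauchy bound, rotated by a fixed
    # offset so they are not symmetric about the real axis.
    radius = 1.0 + jnp.max(jnp.abs(coeffs[1:] / coeffs[0]))
    angles = 2.0 * jnp.pi * jnp.arange(n) / n + 0.4
    z0 = radius * jnp.exp(1j * angles)

    off_diag = ~jnp.eye(n, dtype=bool)

    def step(_, z):
        # Newton correction p(z_k) / p'(z_k) for every estimate.
        w = jnp.polyval(coeffs, z) / jnp.polyval(dcoeffs, z)
        # Repulsion term sum_{j != k} 1 / (z_k - z_j), diagonal masked out.
        diff = z[:, None] - z[None, :]
        inv = jnp.where(off_diag, 1.0 / jnp.where(off_diag, diff, 1.0), 0.0)
        repulsion = jnp.sum(inv, axis=1)
        return z - w / (1.0 - w * repulsion)

    # Fixed iteration count: no data-dependent control flow.
    return jax.lax.fori_loop(0, n_iter, step, z0)


# Usage: z^3 - 6 z^2 + 11 z - 6 = (z - 1)(z - 2)(z - 3); root order is not guaranteed.
roots = ehrlich_aberth_roots([1.0, -6.0, 11.0, -6.0])
```

In a lensing context the polynomial would be the complex image-position polynomial of the lens equation; the cubic above only stands in for it.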

Abstract

We introduce microJAX, the first fully differentiable implementation of the image-centered ray-shooting (ICRS) algorithm for gravitational microlensing. Built on JAX and its XLA just-in-time compiler, microJAX exploits GPU parallelism while providing exact gradients through automatic differentiation. The current release supports binary- and triple-lens geometries, including limb-darkened extended-source effects, and delivers magnifications that remain differentiable with respect to all model parameters. Benchmarks show that microJAX matches the accuracy of established packages and attains a speed-up of up to a factor of $\sim$5-6 in the small-source, limb-darkened regime on an NVIDIA A100 GPU. Because the model is fully differentiable, it integrates seamlessly with probabilistic programming frameworks, enabling scalable Hamiltonian Monte Carlo and variational inference workflows. Although the present work focuses on standard microlensing magnification models, the modular architecture is designed to support upcoming implementations of higher-order microlensing effects while remaining compatible with external likelihood frameworks that incorporate advanced noise models. microJAX thus provides a robust foundation for precise inference in the large-scale surveys anticipated in the coming decade, including the Nancy Grace Roman Space Telescope, where scalable, physically self-consistent inference will be essential for maximizing scientific return.
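As a small illustration of magnifications that are differentiable with respect to all model parameters, the sketch below applies JAX automatic differentiation to the standard point-source, single-lens magnification $A(u) = (u^2+2)/(u\sqrt{u^2+4})$. This is a deliberately simplified stand-in, not the microJAX API: the function names and the parameter ordering $(t_0, t_{\rm E}, u_0)$ are assumptions for this example, and the extended-source multi-lens case treated in the paper is far more involved.

```python
import jax
import jax.numpy as jnp


def point_source_magnification(params, t):
    """Single-lens, point-source magnification at time t (illustrative model)."""
    t0, tE, u0 = params
    tau = (t - t0) / tE
    u = jnp.sqrt(u0**2 + tau**2)      # lens-source separation in Einstein radii
    return (u**2 + 2.0) / (u * jnp.sqrt(u**2 + 4.0))


def light_curve_gradients(params, times):
    """Gradient of the magnification w.r.t. (t0, tE, u0) at each epoch."""
    grad_single = jax.grad(point_source_magnification)   # d A / d params
    return jax.vmap(grad_single, in_axes=(None, 0))(params, times)


# Usage: params = (t0, tE, u0); gradients have shape (n_epochs, 3).
params = jnp.array([0.0, 10.0, 1.0])
times = jnp.linspace(-20.0, 20.0, 5)
grads = light_curve_gradients(params, times)
```

Gradients of this kind are what a gradient-based sampler such as HMC consumes when the magnification enters a log-likelihood.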

Paper Structure

This paper contains 13 sections, 13 equations, 6 figures.

Figures (6)

  • Figure 1: An example of a triple-lens configuration with parameters $(q, s, q_3, r_3, \psi, \rho) = (0.5, 1.1, 0.91, 1.46, 0.1)$ and source center $\bm{w}_{\rm center} = (0, 0.1)$. The red and green curves represent the caustics and critical curves, respectively, and the black dots indicate the lens positions. The source limb is sampled with $N_{\rm limb} = 500$ points (blue), which are mapped to the image plane, producing six distinct images (purple). For each image, a polar-coordinate region is defined based on the image-limb points, within which a uniform $50 \times 50$ grid is placed for integration. Grid points falling inside the source are marked in orange, while those outside are shown in gray.
  • Figure 2: The top panel shows the triple-lens magnification curve for a uniform-brightness source, with the inset displaying the corresponding caustic structure (red solid line) and source trajectory (blue solid line), where the microlensing parameters are $(t_0, t_{\rm E}, u_0, q, s, \alpha, q_3, r_3, \psi)=(0, 10\;{\rm day}, 0, 0.5, 1.1, 60^\circ, 0.03, 1.24, 76.0^\circ)$. The subsequent panels show the gradients of the magnification with respect to each microlensing parameter. The resolution of the inverse-ray grid is $(N_r, N_\theta)=(500, 500)$.
  • Figure 3: Geometrical configuration of the triple-lens system used in magnification and gradient calculations. The primary and secondary lenses ($M_1$, $M_2$) are separated by $s \equiv 2a$ along the horizontal axis, with the origin $\mathcal{O}$ defined as their barycenter. The source (blue line) moves with impact parameter $u_0$ and angle $\alpha$ relative to the binary axis. The third lens $M_3$ is placed at polar coordinates $(r_3, \psi)$ relative to the midpoint between $M_1$ and $M_2$.
  • Figure 4: Relative accuracy of microJAX magnification estimates compared to VBBinaryLensing, as a function of source size ($\rho=0.1$ to $10^{-4}$) and angular resolution ($N_\theta = 500, 1000, 8000$) with fixed $N_r=1000$. Each panel shows the relative error of the microJAX magnification with respect to the VBBinaryLensing reference values, for 1000 randomly sampled source positions near the central caustic of a binary lens with $q=0.1$ and $s=1$. A uniform-brightness source is assumed, and VBBinaryLensing is run at a relative accuracy of $10^{-5}$. The lower-right panel illustrates the spatial distribution of test points near the caustic.
  • Figure 5: An example of a lensing geometry, similar to Figure 1, in which the error in magnification increases, with parameters $(q, s, \rho) = (0.1, 1.1, 5\times10^{-4})$ and $\bm{w}_{\rm center}=(-0.0765, 0)$. Each inset panel shows a zoomed-in view of the corresponding position. In this case, the angular resolution of the grid is too coarse relative to the source size to determine the image boundary accurately, as illustrated in the lower-right panel.
  • ...and 1 more figure
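The grid-counting idea in the Figure 1 caption (place a polar grid in the image plane, map each grid point back through the lens equation, and keep the points that land inside the source) can be sketched for the simplest case of a single point-mass lens, $w = z - 1/\bar{z}$. This is a rough illustration under stated assumptions, not the ICRS scheme of the paper: the grid here covers one fixed annulus around the lens rather than a region centered on each image, the hard inside/outside test is not differentiable, and the function name and default grid bounds are hypothetical.

```python
import jax.numpy as jnp


def uniform_source_magnification(w_center, rho, n_r=1000, n_theta=1000,
                                 r_min=0.5, r_max=2.0):
    """Inverse-ray estimate of the magnification of a uniform disk source.

    Single point-mass lens at the origin, lengths in Einstein radii.
    The annulus [r_min, r_max] is assumed to contain every image.
    """
    dr = (r_max - r_min) / n_r
    dtheta = 2.0 * jnp.pi / n_theta
    # Midpoint polar grid in the image plane.
    r = r_min + (jnp.arange(n_r) + 0.5) * dr
    theta = (jnp.arange(n_theta) + 0.5) * dtheta
    rr, tt = jnp.meshgrid(r, theta, indexing="ij")
    z = rr * jnp.exp(1j * tt)

    # Map each image-plane point to the source plane (lens equation).
    w = z - 1.0 / jnp.conj(z)

    # Keep grid points whose rays land inside the source disk.
    inside = jnp.abs(w - w_center) <= rho

    # Polar area element is r dr dtheta; magnification = image area / source area.
    image_area = jnp.sum(jnp.where(inside, rr, 0.0)) * dr * dtheta
    return image_area / (jnp.pi * rho**2)


# Usage: source of radius 0.1 centred 0.5 Einstein radii from the lens.
A = uniform_source_magnification(0.5 + 0.0j, 0.1)
```

For a source this small relative to its distance from the lens, the estimate should lie close to the point-source value $(u^2+2)/(u\sqrt{u^2+4})$ at $u = 0.5$. The Figure 5 caption describes the failure mode of such counting schemes: when the grid is too coarse compared with the source size, the image boundary is poorly resolved.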