
PyBird-JAX: Accelerated inference in large-scale structure with model-independent emulation of one-loop galaxy power spectra

Alexander Reeves, Pierre Zhang, Henry Zheng

TL;DR

PyBird-JAX is a differentiable, JAX-based implementation of PyBird that uses internal neural network emulators to accelerate computationally costly operations for rapid large-scale structure (LSS) analysis, opening the door to accelerated cosmological inference with minimal accuracy loss and no pretraining.

Abstract

We present $\texttt{PyBird-JAX}$, a differentiable, $\texttt{JAX}$-based implementation of $\texttt{PyBird}$, using internal neural network emulators to accelerate computationally costly operations for rapid large-scale structure (LSS) analysis. $\texttt{PyBird-JAX}$ computes one-loop EFTofLSS predictions for redshift-space galaxy power spectrum multipoles in 1.2 ms on a CPU and 0.2 ms on a GPU, achieving 3-4 orders of magnitude speed-up over $\texttt{PyBird}$. The emulators take a compact spline-based representation of the input linear power spectrum $P(k)$ as feature vectors, making the approach applicable to a wide range of cosmological models. We rigorously validate its accuracy against large-volume simulations and on BOSS data, including cosmologies not explicitly represented in the training set. Leveraging automatic differentiation, $\texttt{PyBird-JAX}$ supports Fisher forecasting, Taylor expansion of model predictions, gradient-based searches, and vectorised ensemble sampling. Interfaced with a variety of samplers and Boltzmann solvers, $\texttt{PyBird-JAX}$ provides a high-performance, end-to-end inference pipeline. Combined with a symbolic-$P(k)$ generator, a typical Stage-4 LSS MCMC converges in minutes on a GPU. Our results demonstrate that $\texttt{PyBird-JAX}$ delivers the precision and speed required for upcoming LSS surveys, opening the door to accelerated cosmological inference with minimal accuracy loss and no pretraining. In a companion paper [1], we put $\texttt{PyBird-JAX}$ to use in achieving LSS marginalised constraints free from volume projection effects through non-flat measures.
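The abstract lists Fisher forecasting, Taylor expansion of model predictions and vectorised sampling among the features enabled by automatic differentiation. Below is a minimal sketch of how the first two follow from a single Jacobian. It uses `jax.jacfwd` on a hypothetical toy power-spectrum model (`toy_power`), which is neither the PyBird EFTofLSS one-loop prediction nor the PyBird-JAX API, together with an assumed diagonal covariance of 1% per $k$-bin. For a Gaussian likelihood the Fisher matrix is $F = J^{\rm T} C^{-1} J$. The same $J$ gives the linear expansion $P(\theta) \approx P(\theta_0) + J(\theta - \theta_0)$, which `jax.vmap` then evaluates over a small ensemble of parameter points.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for stable matrix algebra

# Hypothetical k-grid in h/Mpc; not the binning used in the paper.
K = jnp.linspace(0.01, 0.2, 40)

def toy_power(theta):
    """Toy stand-in for a differentiable P(k) predictor (NOT the EFTofLSS one-loop model).

    theta = (ln_amp, tilt, k_damp): P(k) = exp(ln_amp) (k/0.1)^tilt exp(-(k/k_damp)^2).
    """
    ln_amp, tilt, k_damp = theta
    return jnp.exp(ln_amp) * (K / 0.1) ** tilt * jnp.exp(-((K / k_damp) ** 2))

def fisher_matrix(model, theta0, cov):
    """Gaussian-likelihood Fisher matrix F = J^T C^{-1} J, with J from forward-mode AD."""
    jac = jax.jacfwd(model)(theta0)        # shape (n_k, n_params)
    cinv_jac = jnp.linalg.solve(cov, jac)  # C^{-1} J without explicitly inverting C
    return jac.T @ cinv_jac

def linear_taylor(model, theta0):
    """Return a first-order Taylor expansion of `model` around theta0."""
    p0 = model(theta0)
    jac = jax.jacfwd(model)(theta0)
    return lambda theta: p0 + jac @ (theta - theta0)

theta0 = jnp.array([jnp.log(1.0e4), -1.0, 0.3])
sigma = 0.01 * toy_power(theta0)  # hypothetical 1% Gaussian error per k-bin
cov = jnp.diag(sigma**2)

F = fisher_matrix(toy_power, theta0, cov)
marginal_errors = jnp.sqrt(jnp.diag(jnp.linalg.inv(F)))

# Vectorised evaluation of an ensemble of nearby parameter points with jax.vmap
approx = linear_taylor(toy_power, theta0)
ensemble = theta0 + jnp.array([[0.005, 0.005, 0.002],
                               [-0.005, 0.0, -0.002]])
exact = jax.vmap(toy_power)(ensemble)
linear = jax.vmap(approx)(ensemble)

print("marginalised 1-sigma errors:", marginal_errors)
print("max relative Taylor error:", float(jnp.max(jnp.abs(linear / exact - 1.0))))
```

Forward-mode `jax.jacfwd` suits this case because the model has few parameters and many output bins. With many more parameters than outputs, reverse-mode `jax.jacrev` would usually be cheaper.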

Paper Structure

This paper contains 37 sections, 19 equations, 9 figures, 4 tables.

Figures (9)

  • Figure 1: Spline decomposition accuracy. Distribution of errors in the galaxy power spectrum multipoles computed with PyBird, comparing results obtained using either the input linear matter power spectrum or its spline-reconstructed counterpart, across the CosmoRef testing bank described in table \ref{tab:validation_range}. Errors are shown relative to representative uncertainties expected for Stage-4 LSS surveys.
  • Figure 2: Representative coverage over emulator input parameter space. Inflated Gaussian copula distribution over the emulator training input space, shown for selected knots (all $k$ values in $\,h\, {\rm Mpc}^{-1}\,$), the maximal power amplitude $A = P_{\mathrm{max}}$ (in $(\, {\rm Mpc} \,h^{-1}\,)^3$), and the growth rate $f$. The original reference CosmoRef bank and samples from the BOSS CMASS $\Lambda$CDM analysis are also shown for comparison.
  • Figure 3: Emulator accuracy. Left panel: 68% and 95% quantiles of the differences in the galaxy power spectrum multipoles across scales, computed with PyBird. The comparison is between the NN-based emulator predictions and the full calculations, evaluated over the independent validation set described in sec. \ref{sec:performance}. Errors are shown relative to representative uncertainties expected for Stage-4 LSS surveys, as described in the text. Right panel: cumulative histogram of the maximum absolute differences across the full range of scales, with the two vertical dashed lines indicating the 68% and 95% quantiles of the validation samples.
  • Figure 4: Comparison of PyBird vs. PyBird-Emu on PT simulation data. 1D and 2D marginal posterior distributions of inferred $\Lambda$CDM parameters from the PT challenge simulations, with $\omega_{\rm b}$ and $n_s$ fixed. For all configurations, all parameters are recovered within $\sim 1\sigma$ of the truth, shown as dashed lines. The posteriors from PyBird (blue contours) and PyBird-Emu (yellow contours) agree at the sub-percent level ($<0.3\%$ of the parameter values on 1D marginals), whether fit in multipoles $P_\ell$ up to $k^\ell_{\rm max} = 0.14 \,h\, {\rm Mpc}^{-1}\,$ (left panel) or in wedges $\slashed{P} + w_\ell$ with $\slashed{P}$ analysed up to $k^{\slashed{P}}_{\rm max} = 0.3 \,h\, {\rm Mpc}^{-1}\,$ (right panel). This validates the raw accuracy of PyBird-Emu in recovering cosmological parameters to high precision.
  • Figure 5: Comparison (I) of PyBird vs. PyBird-Emu on BOSS data. 1D and 2D marginal posterior distributions of inferred parameters in $\Lambda$CDM (left panel) or $w_0w_a$CDM (right panel) from BOSS data, with a BBN prior on $\omega_{\rm b}$ (not shown for clarity). The $w_0w_a$CDM fit also includes low-redshift supernova data from Pantheon+. Given the large parameter uncertainties and the non-negligible shifts in the means with respect to the central values used for the emulator validation shown in table \ref{tab:validation_range}, these analysis setups are stringent parameter-coverage tests for PyBird-Emu. The posteriors from PyBird (blue contours) and PyBird-Emu (yellow contours) are practically indistinguishable in both scenarios, validating the ability of PyBird-Emu to recover cosmological parameters across wide parameter ranges.
  • ...and 4 more figures
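Figure 1 and the abstract refer to a compact spline-based representation of the input linear $P(k)$. The knot placement and exact feature definition used in PyBird-JAX are not given here, so what follows is only a sketch under stated assumptions. It stores a hypothetical amplitude $A$, taken as the largest $P$ at the knots and loosely echoing the $A = P_{\rm max}$ axis of Figure 2, plus $\ln(P/A)$ at 30 log-spaced knots. It then rebuilds $P(k)$ with a natural cubic spline in $\ln k$ and measures the relative reconstruction error on a fine grid. The toy spectrum (`toy_linear_pk`) is smooth and has no BAO wiggles. A realistic linear spectrum would need denser knots to resolve them.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # double precision for the spline solve

def toy_linear_pk(k):
    """Smooth toy linear P(k) in (Mpc/h)^3, turnover near k = 0.014 h/Mpc, no BAO wiggles."""
    x = k / 0.02
    return 2.0e4 * x / (1.0 + x**2) ** 1.5

def natural_cubic_spline(x, y):
    """Return a callable natural cubic spline through knots (x, y), x strictly increasing."""
    n = x.shape[0]
    h = jnp.diff(x)
    slopes = jnp.diff(y) / h
    # Linear system for the second derivatives m at the knots; natural ends m[0] = m[-1] = 0.
    i = jnp.arange(1, n - 1)
    a = jnp.zeros((n, n)).at[0, 0].set(1.0).at[n - 1, n - 1].set(1.0)
    a = a.at[i, i - 1].set(h[:-1]).at[i, i].set(2.0 * (h[:-1] + h[1:])).at[i, i + 1].set(h[1:])
    rhs = jnp.zeros(n).at[1:-1].set(6.0 * jnp.diff(slopes))
    m = jnp.linalg.solve(a, rhs)

    def evaluate(xq):
        j = jnp.clip(jnp.searchsorted(x, xq) - 1, 0, n - 2)  # interval index of each query
        hj, lo, hi = h[j], xq - x[j], x[j + 1] - xq
        cubic = (m[j] * hi**3 + m[j + 1] * lo**3) / (6.0 * hj)
        linear = (y[j] / hj - m[j] * hj / 6.0) * hi + (y[j + 1] / hj - m[j + 1] * hj / 6.0) * lo
        return cubic + linear

    return evaluate

def encode(pk_fn, k_knots):
    """Hypothetical features: amplitude A (largest P at the knots) and ln(P/A) at the knots."""
    p = pk_fn(k_knots)
    amp = jnp.max(p)
    return amp, jnp.log(p / amp)

def decode(amp, features, k_knots):
    """Rebuild P(k) = A exp(spline(ln k)) from the hypothetical features."""
    spline = natural_cubic_spline(jnp.log(k_knots), features)
    return lambda k: amp * jnp.exp(spline(jnp.log(k)))

k_knots = jnp.geomspace(1e-3, 0.5, 30)  # hypothetical knot placement in h/Mpc
amp, feats = encode(toy_linear_pk, k_knots)
pk_rec = decode(amp, feats, k_knots)

k_fine = jnp.geomspace(1e-3, 0.5, 2000)
rel_err = jnp.abs(pk_rec(k_fine) / toy_linear_pk(k_fine) - 1.0)
print("max relative reconstruction error:", float(jnp.max(rel_err)))
```

Splining $\ln P$ against $\ln k$, rather than $P$ against $k$, keeps the interpolated function close to piecewise-polynomial over many decades in scale. This is why a few dozen knots suffice for a smooth spectrum in this toy setting.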