Differentiable Cosmological Hydrodynamics for Field-Level Inference and High Dimensional Parameter Constraints
Benjamin Horowitz, Zarija Lukic
TL;DR
diffhydro is a fully differentiable cosmological hydrodynamics pipeline that couples a total variation diminishing (TVD) Euler solver with a particle-mesh (PM) gravity solver and makes stochastic subgrid processes differentiable via the Gumbel-Softmax relaxation, enabling gradient-based and Bayesian inference in high-dimensional parameter spaces. The approach is implemented in JAX, supports GPU acceleration, and is validated with Sedov blast-wave tests and field-level inference on synthetic data. Key contributions are differentiating through stochastic discrete events, jointly inferring cosmological and subgrid parameters from baryon statistics, and reconstructing initial conditions from noisy data. As a proof of principle, the work demonstrates the feasibility of end-to-end differentiable forward modeling of large-scale structure and outlines paths toward greater scalability and richer physics on future hardware.
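The Gumbel-Softmax relaxation replaces a hard categorical draw with a temperature-controlled softmax over Gumbel-perturbed logits, so gradients can flow from a sampled discrete outcome back to the parameters that set its probabilities. Below is a minimal JAX sketch of the relaxed (non-straight-through) variant, applied to a hypothetical two-outcome subgrid event in each cell. The functions `gumbel_softmax` and `expected_event_rate`, the event model, and the temperature value are illustrative assumptions, not diffhydro's API or its actual subgrid model.

```python
import jax
import jax.numpy as jnp


def gumbel_softmax(key, logits, temperature):
    """Relaxed (continuous) sample of a categorical variable.

    Adds Gumbel(0, 1) noise to the logits and applies a tempered softmax;
    as temperature -> 0 each row approaches a one-hot sample.
    """
    gumbel_noise = jax.random.gumbel(key, logits.shape)
    return jax.nn.softmax((logits + gumbel_noise) / temperature, axis=-1)


def expected_event_rate(log_rate, key, n_cells=4096, temperature=0.5):
    """Hypothetical toy subgrid model: each cell either fires an event
    (category 0) or not (category 1), with p(event) = sigmoid(log_rate)."""
    # Two-class log-probabilities, identical for every cell.
    logits = jnp.stack([jax.nn.log_sigmoid(log_rate),
                        jax.nn.log_sigmoid(-log_rate)])
    logits = jnp.broadcast_to(logits, (n_cells, 2))
    samples = gumbel_softmax(key, logits, temperature)
    # Fraction of cells that (softly) fired an event.
    return samples[:, 0].mean()


key = jax.random.PRNGKey(0)
rate = expected_event_rate(0.3, key)
# Gradient of a stochastic, discrete-valued quantity w.r.t. the model parameter.
d_rate = jax.grad(expected_event_rate)(0.3, key)
print(float(rate), float(d_rate))
```

Lowering `temperature` makes the samples closer to one-hot at the cost of higher-variance gradients; a straight-through estimator, which uses hard samples in the forward pass and the relaxed gradient in the backward pass, is a common alternative.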
Abstract
Hydrodynamical simulations are the most accurate way to model structure formation in the universe, but they often involve a large number of astrophysical parameters modeling subgrid physics, in addition to cosmological parameters. This results in a high-dimensional parameter space that is difficult to constrain jointly with traditional statistical methods because of prohibitive computational costs. To address this, we present a fully differentiable approach to cosmological hydrodynamical simulations and a proof-of-concept implementation, diffhydro. By back-propagating through an upwind finite-volume scheme for the Euler equations, solved jointly with a dark matter particle-mesh method for the Poisson equation, we can efficiently evaluate derivatives of the output baryonic fields with respect to the input density field and model parameters. Importantly, we demonstrate how to differentiate through stochastically sampled discrete random variables, which frequently appear in subgrid models. We use this framework to rapidly sample subgrid physics and cosmological parameters, and to perform field-level inference of initial conditions using high-dimensional optimization techniques. Our code is implemented in JAX (Python), allowing easy code development and GPU acceleration.
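The core idea of back-propagating through a finite-volume solver to recover initial conditions can be illustrated at toy scale. The sketch below is not diffhydro's scheme: it swaps in first-order upwind advection of a single 1D density field at constant velocity (no pressure, no TVD limiter, no gravity), and the names `upwind_step`, `forward_model`, and `loss`, along with the grid size, step count, and learning rate, are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


def upwind_step(rho, velocity, dx, dt):
    """One first-order upwind finite-volume step for the 1D advection equation
    d(rho)/dt + v * d(rho)/dx = 0 on a periodic grid, assuming constant v > 0."""
    # For v > 0 the upwind cell of face i-1/2 is cell i-1.
    flux_left = velocity * jnp.roll(rho, 1)   # flux through face i-1/2
    flux_right = velocity * rho               # flux through face i+1/2
    return rho - dt / dx * (flux_right - flux_left)


def forward_model(rho0, n_steps=50, velocity=1.0, dx=1.0, dt=0.5):
    """Evolve an initial density field for n_steps (Courant number v*dt/dx = 0.5)."""
    def body(rho, _):
        return upwind_step(rho, velocity, dx, dt), None

    rho_final, _ = jax.lax.scan(body, rho0, None, length=n_steps)
    return rho_final


def loss(rho0, observed):
    """Mean-squared misfit between the evolved field and noisy observations."""
    return jnp.mean((forward_model(rho0) - observed) ** 2)


# Synthetic "truth": a single sine mode on top of unit mean density.
n = 64
x = jnp.arange(n)
true_ic = 1.0 + 0.5 * jnp.sin(2.0 * jnp.pi * x / n)
noise = 0.01 * jax.random.normal(jax.random.PRNGKey(0), (n,))
observed = forward_model(true_ic) + noise

# Plain gradient descent on the initial field, starting from a uniform guess;
# jax.grad back-propagates through all n_steps of the solver.
grad_fn = jax.jit(jax.grad(loss))
rho0 = jnp.ones(n)
initial_loss = loss(rho0, observed)
for _ in range(200):
    rho0 = rho0 - 5.0 * grad_fn(rho0, observed)
final_loss = loss(rho0, observed)
print(float(initial_loss), float(final_loss))
```

Because each update is linear and mass-conserving, the misfit is quadratic in `rho0`: gradient descent recovers the large-scale mode quickly, while small-scale modes damped by the upwind scheme's numerical diffusion remain weakly constrained. A full hydrodynamics code would follow the same pattern, with the forward model replaced by the coupled Euler and particle-mesh gravity update.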
