Simulation-based inference with the Python Package sbijax

Simon Dirmeier, Simone Ulzega, Antonietta Mira, Carlo Albert

Abstract

Neural simulation-based inference (SBI) describes an emerging family of methods for Bayesian inference with intractable likelihood functions that use neural networks as surrogate models. Here we introduce sbijax, a Python package that implements a wide variety of state-of-the-art methods in neural simulation-based inference using a user-friendly programming interface. sbijax offers high-level functionality to quickly construct SBI estimators and to compute and visualize posterior distributions with only a few lines of code. In addition, the package provides functionality for conventional approximate Bayesian computation, for computing model diagnostics, and for automatically estimating summary statistics. By virtue of being entirely written in JAX, sbijax is extremely computationally efficient, allowing rapid training of neural networks and executing code automatically in parallel on both CPU and GPU.

Paper Structure (33 sections, 37 equations, 7 figures, 3 tables, 2 algorithms)

Figures (7)

  • Figure 1: Marginal posterior density and trace plots. Each variable is visualized separately. The components of multivariate parameters, such as the mean, are shown in different colors within a panel. The titles of the figures follow the variable names of the generative model (see section \ref{sec:sbijax-model_definition}).
  • Figure 2: Split-$\hat{R}$, rank and effective sample size plots.
  • Figure 3: Posterior pair plots and marginal distributions. For this benchmark model (SLCP), SNLE achieves the best approximation to the true posterior (compared with the posterior distribution inferred using the slice sampler). SMC-ABC+NASS and FMPE show worse performance.
  • Figure 4: MCMC model diagnostics for the SLCP model. We show three common MCMC model diagnostics for which sbijax offers visualization functionality. The left column shows posterior traces, i.e., the values of $\theta$ for each iteration and for each chain (different colors). The middle column shows rank statistics for each parameter and chain (different colors). The right column shows the bulk and tail effective sample sizes.
  • Figure 5: Training and validation loss for the bivariate Gaussian example. The training of the neural network converged in this example and was stopped early after roughly $130$ episodes, because there were only insignificant improvements on the validation set.
  • ...and 2 more figures
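
The split-$\hat{R}$ diagnostic named in the caption of Figure 2 can be illustrated with a minimal NumPy sketch. It implements the classic, non-rank-normalized split-$\hat{R}$ for a single scalar parameter and is not the sbijax implementation, which may use a different variant. The function name `split_rhat` and the synthetic chains are assumptions made for illustration only.

```python
import numpy as np


def split_rhat(draws):
    """Classic split-R-hat for one scalar parameter (hypothetical helper).

    draws: array of shape (n_chains, n_draws) holding MCMC samples.
    """
    draws = np.asarray(draws, dtype=float)
    n_chains, n_draws = draws.shape
    half = n_draws // 2
    # Split every chain into its first and second half; if n_draws is odd,
    # the middle draw is dropped so that both halves have equal length.
    halves = np.concatenate(
        [draws[:, :half], draws[:, n_draws - half:]], axis=0
    )
    chain_means = halves.mean(axis=1)
    chain_vars = halves.var(axis=1, ddof=1)
    within = chain_vars.mean()                 # W: mean within-chain variance
    between = half * chain_means.var(ddof=1)   # B: between-chain variance
    # Pooled estimate of the marginal posterior variance.
    var_plus = (half - 1) / half * within + between / half
    return float(np.sqrt(var_plus / within))


# Synthetic chains (assumed data): four chains sampling the same distribution,
# and four chains that are each shifted to a different location.
rng = np.random.default_rng(0)
mixed = rng.normal(size=(4, 1000))
stuck = mixed + np.arange(4)[:, None]

rhat_mixed = split_rhat(mixed)
rhat_stuck = split_rhat(stuck)
print(f"well-mixed chains: {rhat_mixed:.3f}")
print(f"shifted chains:    {rhat_stuck:.3f}")
```

Values close to 1 indicate that the half-chains agree with one another, while clearly larger values, as for the shifted chains here, indicate that the chains have not converged to a common distribution.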