Fast and Robust Simulation-Based Inference With Optimization Monte Carlo

Vasilis Gkolemis, Christos Diou, Michael Gutmann

TL;DR

R2OMC addresses the difficulty of likelihood-free Bayesian inference for complex differentiable simulators by reframing stochastic simulation as deterministic optimization and guiding posterior sampling with gradient information. It extends the ROMC framework with (i) a gradient-based mechanism to identify and filter distractor outputs, (ii) per-data-point optimization to generate informative proposal regions, and (iii) a gradient-enabled adaptive importance sampling scheme that combines multiple iid observations. The method is implemented in JAX to exploit vectorization and automatic differentiation, yielding substantial runtime reductions while maintaining or improving posterior accuracy across high-dimensional, distractor-rich, and multimodal problems. Empirical results across MoG, SBI benchmarks, and image-based tasks show R2OMC often matching or surpassing state-of-the-art neural SBI methods at a fraction of the computational cost, highlighting its practical impact for fast, robust SBI in challenging settings.

Abstract

Bayesian parameter inference for complex stochastic simulators is challenging due to intractable likelihood functions. Existing simulation-based inference methods often require a large number of simulations and become costly to use in high-dimensional parameter spaces or in problems with partially uninformative outputs. We propose a new method for differentiable simulators that delivers accurate posterior inference with substantially reduced runtimes. Building on the Optimization Monte Carlo framework, our approach reformulates stochastic simulation as a set of deterministic optimization problems. Gradient-based methods are then applied to efficiently navigate toward high-density posterior regions and avoid wasteful simulations in low-probability areas. A JAX-based implementation further enhances performance through vectorization of key method components. Extensive experiments, including high-dimensional parameter spaces, uninformative outputs, multiple observations, and multimodal posteriors, show that our method consistently matches, and often exceeds, the accuracy of state-of-the-art approaches, while reducing the runtime by a substantial margin.
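
The reformulation described above (fixing the simulator's randomness so that each simulation becomes a deterministic function of the parameters, then minimizing the distance to the observation with gradients) can be illustrated with a minimal JAX sketch. This is not the authors' implementation: the toy `simulator`, the squared-distance objective, the plain gradient-descent loop, and all names and hyperparameters (`solve_one`, `lr`, `steps`, `n_problems`) are assumptions chosen only to show the idea and how `jax.vmap` vectorizes it.

```python
import jax
import jax.numpy as jnp


def simulator(theta, u):
    # Hypothetical toy simulator: once the noise draw u is fixed,
    # the output is a deterministic, differentiable function of theta.
    return theta + 0.5 * u


def distance(theta, u, y_obs):
    # Squared Euclidean distance between simulated and observed data.
    return jnp.sum((simulator(theta, u) - y_obs) ** 2)


def solve_one(u, theta0, y_obs, lr=0.1, steps=200):
    # Plain gradient descent on theta -> distance(theta, u, y_obs).
    grad_fn = jax.grad(distance)  # gradient w.r.t. theta (first argument)

    def step(theta, _):
        return theta - lr * grad_fn(theta, u, y_obs), None

    theta_star, _ = jax.lax.scan(step, theta0, None, length=steps)
    return theta_star, distance(theta_star, u, y_obs)


key_u, key_t = jax.random.split(jax.random.PRNGKey(0))
n_problems, dim = 64, 2
y_obs = jnp.array([1.0, -1.0])

# One fixed noise draw and one starting point per optimization problem.
u_all = jax.random.normal(key_u, (n_problems, dim))
theta_init = jax.random.normal(key_t, (n_problems, dim))

# Solve all optimization problems in parallel with vmap, compiled with jit.
solve_all = jax.jit(jax.vmap(solve_one, in_axes=(0, 0, None)))
theta_stars, final_dists = solve_all(u_all, theta_init, y_obs)
```

Each row of `theta_stars` is the end point of one deterministic optimization problem. In this toy setup the minimizer is known in closed form (`y_obs - 0.5 * u`), which makes the sketch easy to check; a real simulator would generally need a more careful optimizer and stopping rule.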

Paper Structure

This paper contains 46 sections, 19 equations, 15 figures, 3 tables, 2 algorithms.

Figures (15)

  • Figure 1: Posterior inference on a two-Gaussian mixture simulator. Rows correspond to simple ($D=D_y=2$), distractor ($D=2, D_y=18$), and high-dimensional ($D=D_y=10$) settings. Columns show the reference posterior and samples obtained from R2OMC (proposed) and three established neural-based methods (Neural Posterior Estimator \cite{greenberg2019automatic}, BayesFlow \cite{radev2020bayesflow}, and Flow Matching Posterior Estimator \cite{wildberger2023flow}), each evaluated at low and high runtime. Unlike the neural methods, which require high runtime in the distractor and high-dimensional settings, R2OMC produces accurate posterior samples with low runtime. See Appendix \ref{app-sec:intro-example} for details on the experimental setup.
  • Figure 2: R2OMC overview: Distractors are automatically filtered out and proposal distributions $q_i^n$ are created for each observation (here two): the blue-dotted region indicates where the proposal distributions for the first observation are nonzero, and the green-dotted region indicates the same for the second observation. Only samples that fall within the intersection of both regions are given a positive weight and become posterior samples.
  • Figure 3: MoG benchmark: Success Frontier plots showing the minimum runtime required to achieve a mean C2ST score $\leq 0.75$ across varying $D$. Corresponding Budget vs. Dimension plots are provided in Appendix \ref{app-sec:experiments}.
  • Figure 4: SLCP pairwise posteriors for multiple observations. The four left panels show proposal samples per observation within $\epsilon$-distance, while the rightmost panel shows the subset of samples selected as closest across all observations using the R2OMC weighting scheme (Eq. \ref{eq:rromc_weight}).
  • Figure 5: SLCP (left) and SLCP with distractors (right): C2ST score vs. runtime.
  • ...and 10 more figures
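
The intersection rule described in the Figure 2 caption (a sample receives a positive weight only if it lies in the proposal region of every observation) can be illustrated with a small JAX sketch. It is a simplification under stated assumptions: the axis-aligned boxes, the names `in_region`, `centers`, and `half_widths`, and the 0/1 indicator weights are all hypothetical, and the paper's actual weighting scheme (Eq. \ref{eq:rromc_weight}) is not reproduced here.

```python
import jax
import jax.numpy as jnp


def in_region(theta, center, half_width):
    # True where theta lies inside an axis-aligned box
    # (a hypothetical stand-in for one observation's proposal region).
    return jnp.all(jnp.abs(theta - center) <= half_width, axis=-1)


# Two hypothetical proposal regions, one per observation.
centers = jnp.array([[0.0, 0.0], [0.5, 0.5]])
half_widths = jnp.array([[1.0, 1.0], [1.0, 1.0]])

# Candidate samples drawn uniformly from the first region.
key = jax.random.PRNGKey(0)
samples = jax.random.uniform(key, (1000, 2), minval=-1.0, maxval=1.0)

# membership[n, s] is True when sample s lies in the region of observation n.
membership = jax.vmap(in_region, in_axes=(None, 0, 0))(samples, centers, half_widths)

# Indicator weight: positive only inside the intersection of all regions.
weights = jnp.all(membership, axis=0).astype(jnp.float32)
posterior_samples = samples[weights > 0]
```

Samples outside the overlap of the two boxes get weight zero and are discarded. A full importance-sampling scheme would typically also account for prior and proposal densities, which this sketch deliberately omits.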