Fast and Robust Simulation-Based Inference With Optimization Monte Carlo
Vasilis Gkolemis, Christos Diou, Michael Gutmann
TL;DR
R2OMC addresses the difficulty of likelihood-free Bayesian inference for complex differentiable simulators by reframing stochastic simulation as deterministic optimization and using gradient information to guide posterior sampling. It extends the Robust Optimization Monte Carlo (ROMC) framework with (i) a gradient-based mechanism that identifies and filters out distractor outputs, (ii) per-data-point optimization that generates informative proposal regions, and (iii) a gradient-enabled adaptive importance sampling scheme that combines multiple i.i.d. observations. The method is implemented in JAX to exploit vectorization and automatic differentiation, which substantially reduces runtime while maintaining or improving posterior accuracy on high-dimensional, distractor-rich, and multimodal problems. Across mixture-of-Gaussians (MoG) problems, standard simulation-based inference (SBI) benchmarks, and image-based tasks, R2OMC often matches or surpasses state-of-the-art neural SBI methods at a fraction of the computational cost, making it a practical option for fast, robust SBI in challenging settings.
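For orientation, the ROMC construction that R2OMC builds on can be summarized as follows. This is a sketch in our own notation, assuming a simulator $y = g(\theta, u)$ whose randomness is carried entirely by a nuisance variable $u$; the symbols $g$, $u_i$, $d_i$, $C_i^{\epsilon}$ and $\epsilon$ are illustrative and may differ from the paper's. Fixing each draw $u_i$ turns simulation into a deterministic function of $\theta$, which is what makes gradient-based optimization applicable.

```latex
\begin{aligned}
y &= g(\theta, u), \qquad u_1, \dots, u_n \overset{\text{iid}}{\sim} p(u), \\
d_i(\theta) &= \lVert g(\theta, u_i) - y_0 \rVert, \qquad
\theta_i^\ast = \arg\min_{\theta} d_i(\theta), \\
C_i^{\epsilon} &= \{\theta : d_i(\theta) \le \epsilon\}, \\
p_{d,\epsilon}(\theta \mid y_0) &\propto p(\theta)\,
  \Pr\big(\lVert g(\theta, u) - y_0 \rVert \le \epsilon\big)
  \approx p(\theta)\,\frac{1}{n}\sum_{i=1}^{n} \mathbb{1}_{C_i^{\epsilon}}(\theta).
\end{aligned}
```

Each optimum $\theta_i^\ast$ with $d_i(\theta_i^\ast) \le \epsilon$ marks a region $C_i^{\epsilon}$ that contributes posterior mass; problems whose optimum stays above $\epsilon$ contribute nothing and can be discarded.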
Abstract
Bayesian parameter inference for complex stochastic simulators is challenging because their likelihood functions are intractable. Existing simulation-based inference methods often require a large number of simulations and become costly in high-dimensional parameter spaces or in problems with partially uninformative outputs. We propose a new method for differentiable simulators that delivers accurate posterior inference with substantially reduced runtimes. Building on the Optimization Monte Carlo framework, our approach reformulates stochastic simulation as a set of deterministic optimization problems. Gradient-based methods then navigate efficiently toward high-density posterior regions and avoid wasteful simulations in low-probability areas. A JAX-based implementation further improves performance by vectorizing key components of the method. Extensive experiments covering high-dimensional parameter spaces, uninformative outputs, multiple observations, and multimodal posteriors show that our method consistently matches, and often exceeds, the accuracy of state-of-the-art approaches while reducing runtime by a substantial margin.
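The core idea, turning stochastic simulation into deterministic optimization problems solved with gradients and vectorized in JAX, can be illustrated with a minimal sketch. It assumes a hypothetical toy simulator `simulator(theta, u) = theta + 0.5*sin(theta) + u` and uses plain gradient descent as a stand-in optimizer; the function names, step size, and tolerance `eps` are illustrative choices, not the authors' R2OMC implementation (which additionally filters distractor outputs and performs adaptive importance sampling). Only `jax.grad`, `jax.vmap`, `jax.lax.scan`, and `jax.jit` are real JAX APIs.

```python
import jax
import jax.numpy as jnp


def simulator(theta, u):
    """Hypothetical toy differentiable simulator: y = theta + 0.5*sin(theta) + u.

    All stochasticity enters through the nuisance draw u, so for a fixed u
    the map theta -> y is deterministic and differentiable.
    """
    return theta + 0.5 * jnp.sin(theta) + u


def sq_distance(theta, u, y_obs):
    # Squared discrepancy between the seeded simulation and the observation;
    # squared so the gradient stays finite at the optimum.
    return jnp.sum((simulator(theta, u) - y_obs) ** 2)


def optimize_one(u, y_obs, theta0, lr=0.2, n_steps=200):
    """Plain gradient descent on the deterministic problem for one fixed seed u."""
    grad_fn = jax.grad(sq_distance)  # gradient w.r.t. theta (argument 0)

    def step(theta, _):
        return theta - lr * grad_fn(theta, u, y_obs), None

    theta_star, _ = jax.lax.scan(step, theta0, None, length=n_steps)
    # Report d_i(theta_i*) as a Euclidean norm, matching the OMC distance.
    return theta_star, jnp.sqrt(sq_distance(theta_star, u, y_obs))


# Vectorize over the seeds u_i: all n optimizations run as one batched, jitted call.
optimize_all = jax.jit(jax.vmap(optimize_one, in_axes=(0, None, None)))

key = jax.random.PRNGKey(0)
n_problems = 8
us = jax.random.normal(key, (n_problems, 2))  # one nuisance draw u_i per problem
y_obs = jnp.array([1.0, -0.5])                # observed data y_0
theta0 = jnp.zeros(2)                         # shared initial point

thetas, dists = optimize_all(us, y_obs, theta0)
eps = 1e-2
accepted = dists <= eps  # keep problems whose optimum lands inside the eps-region
```

Because the toy simulator is monotone in each component, every seeded problem has a unique optimum and all eight are accepted. In an OMC-style method, each accepted `thetas[i]` would then anchor a local proposal region for sampling; that step is omitted here.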
