
Scalable Simulation-Based Model Inference with Test-Time Complexity Control

Manuel Gloeckler, J. P. Manzano-Patrón, Stamatios N. Sotiropoulos, Cornelius Schröder, Jakob H. Macke

Abstract

Simulation plays a central role in scientific discovery. In many applications, the bottleneck is no longer running a simulator; it is choosing among large families of plausible simulators, each corresponding to a different forward model or hypothesis consistent with the observations. Over large model families, classical Bayesian workflows for model selection are impractical. Furthermore, amortized model selection methods typically hard-code a fixed model prior or complexity penalty at training time, requiring users to commit to a particular parsimony assumption before seeing the data. We introduce PRISM, a simulation-based encoder-decoder that infers a joint posterior over both discrete model structures and their associated continuous parameters, while enabling test-time control of model complexity via a tunable model prior on which the network is conditioned. We show that PRISM scales to families with combinatorially many (up to billions of) model instantiations on a synthetic symbolic regression task. As a scientific application, we evaluate PRISM on biophysical modeling of diffusion MRI data, demonstrating model selection across several multi-compartment models on both synthetic and in vivo neuroimaging data.
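The abstract does not state the functional form of the tunable prior $p(\mathcal{M}\mid\lambda)$. As a minimal sketch, assuming purely for illustration that a model is a binary mask over $K$ candidate components and that each component is switched on independently with probability $\lambda$, the Python/NumPy code below shows how a single scalar can shift prior mass between simple and complex models, and why such families grow combinatorially: $K = 30$ components already give $2^{30} \approx 1.07 \times 10^9$ models. The function names are hypothetical and not part of PRISM.

```python
import numpy as np


def family_size(num_components: int) -> int:
    """Number of distinct models when each of K components is either on or off."""
    return 2 ** num_components


def sample_model_masks(lam: float, num_components: int, num_samples: int,
                       rng: np.random.Generator) -> np.ndarray:
    """Draw binary model masks M ~ p(M | lambda).

    Assumed (hypothetical) prior: every component is included independently
    with probability lam, so a larger lam favours more complex models.
    """
    if not 0.0 <= lam <= 1.0:
        raise ValueError("lam must lie in [0, 1]")
    return rng.random((num_samples, num_components)) < lam


def log_prior(mask: np.ndarray, lam: float) -> float:
    """log p(M | lambda) for one mask under the same independent-inclusion prior."""
    if not 0.0 < lam < 1.0:
        raise ValueError("lam must lie strictly inside (0, 1)")
    k_on = int(mask.sum())
    k_off = mask.size - k_on
    return k_on * np.log(lam) + k_off * np.log1p(-lam)


rng = np.random.default_rng(0)
K = 30
print(family_size(K))  # 1073741824, i.e. about a billion model instantiations
for lam in (0.1, 0.5):
    masks = sample_model_masks(lam, K, 10_000, rng)
    # the median number of active components grows with lam
    print(lam, np.median(masks.sum(axis=1)))
```

Conditioning the network on $\lambda$, rather than fixing one value at training time, is what allows the same trained network to answer queries under different parsimony assumptions; the independent-inclusion prior used here is only one possible choice.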

Paper Structure (46 sections, 28 equations, 20 figures, 5 tables)

Figures (20)

  • Figure 1: PRISM overview. (a) During training, we sample from a model family $\mathcal{M}$ with a hierarchical prior $p(\mathcal{M}\mid\lambda)$, where $\lambda$ controls a chosen notion of model complexity, and jointly approximate the model posterior and the conditional parameter posterior, i.e. $p(\mathcal{M},\theta \mid \bm{x}, \lambda)$, by optimizing the loss $\mathcal{L}=\mathcal{L}(\mathcal{M}, p(\mathcal{M} \mid \bm{x}, \lambda)) + \mathcal{L}(\theta_{\mathcal{M}}, p(\theta \mid \mathcal{M}_\lambda, \bm{x}))$ for simulated observations $\bm{x}\sim p_{\mathcal{M}}(\bm{x} \mid \theta_{\mathcal{M}})$. (b) At inference time, we set $\lambda$ to tune parsimony, select or explore models in the combinatorial space, infer model parameters, and deploy the resulting posteriors in downstream analyses.
  • Figure 2: Illustration on the symbolic regression task. Left: Ground-truth function and noisy observations (black, $x<10$), compared with posterior predictive samples for two values of the complexity parameter $\lambda$ (95% credible interval and one noiseless sample each, chosen as the simplest posterior draw). Right: Median number of model components as a function of $\lambda$. Below: Equations of the true function (black, 2 components) and two sampled equations for low and high complexity priors (2 and 9 components).
  • Figure 3: Architecture overview. PRISM uses two parallel transformers to infer (i) the model posterior from tokenized model masks and (ii) the parameter posterior via a diffusion $v$-prediction network ($t$ denotes diffusion time). Observations $\bm{x}$ enter through cross-attention; $(\lambda, t)$ are injected via adaptive layer normalization (skip connections omitted).
  • Figure 4: PRISM on the symbolic regression task. (a) Model posterior across $\lambda$ and parameter posterior for the example in Fig. 2 (for $\lambda = 0.1$). (b) Comparison to [schroder2024simultaneous] for a fixed prior. (c) Scaling to large model spaces for a fixed computational training budget; the red line marks the regime beyond which not all models can be sampled during training.
  • Figure 5: dMRI models.
  • ...and 15 more figures
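Figure 1 writes the training objective as the sum of a model-posterior term and a parameter-posterior term, and Figure 3 states that the parameter branch is a diffusion $v$-prediction network. Below is a minimal NumPy sketch of such a two-term loss, under assumptions that are ours, not the paper's: the model term is a binary cross-entropy over a factorized mask (PRISM itself uses a transformer over tokenized model masks), each component carries one scalar parameter, the noise schedule is the variance-preserving cosine schedule $\alpha_t = \cos(\pi t/2)$, $\sigma_t = \sin(\pi t/2)$, and the parameter term only counts components that are active in the sampled model. The network outputs are stand-in random arrays, and all function names are hypothetical.

```python
import numpy as np


def cosine_schedule(t: np.ndarray):
    """Variance-preserving schedule: alpha_t**2 + sigma_t**2 == 1."""
    return np.cos(0.5 * np.pi * t), np.sin(0.5 * np.pi * t)


def bce_with_logits(logits: np.ndarray, targets: np.ndarray) -> float:
    """Numerically stable binary cross-entropy, averaged over all entries."""
    return float(np.mean(np.maximum(logits, 0.0) - logits * targets
                         + np.log1p(np.exp(-np.abs(logits)))))


def noisy_params(theta0, eps, t):
    """Diffused parameters theta_t = alpha_t * theta0 + sigma_t * eps."""
    alpha, sigma = cosine_schedule(t)
    return alpha[:, None] * theta0 + sigma[:, None] * eps


def v_target(theta0, eps, t):
    """v-prediction target v = alpha_t * eps - sigma_t * theta0."""
    alpha, sigma = cosine_schedule(t)
    return alpha[:, None] * eps - sigma[:, None] * theta0


def joint_loss(mask_logits, mask_true, v_pred, theta0, eps, t):
    """Model-structure term plus parameter term (hypothetical form of Fig. 1's loss).

    Shapes: mask and parameter arrays are (batch, K); t is (batch,).
    The parameter term conditions on the true sampled mask (teacher forcing).
    """
    active = mask_true.astype(float)
    model_term = bce_with_logits(mask_logits, active)
    # squared v-error, counted only for components present in the sampled model
    sq_err = (v_pred - v_target(theta0, eps, t)) ** 2 * active
    param_term = float(sq_err.sum() / max(active.sum(), 1.0))
    return model_term + param_term, model_term, param_term


rng = np.random.default_rng(0)
B, K = 8, 5
mask_true = rng.random((B, K)) < 0.4
theta0 = rng.normal(size=(B, K))
eps = rng.normal(size=(B, K))
t = rng.random(B)
theta_t = noisy_params(theta0, eps, t)  # what the parameter network would receive
# stand-ins for network outputs; a real network would compute them from (theta_t, t, x, lambda)
mask_logits = rng.normal(size=(B, K))
v_pred = rng.normal(size=(B, K))
print(joint_loss(mask_logits, mask_true, v_pred, theta0, eps, t))
```

Because $\alpha_t^2 + \sigma_t^2 = 1$ under this schedule, a predicted $v$ maps back to a clean parameter estimate via $\hat\theta_0 = \alpha_t \theta_t - \sigma_t \hat v$.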
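Figure 3 states that $(\lambda, t)$ enter the transformers through adaptive layer normalization. The sketch below assumes a common DiT-style form, in which a conditioning vector $c$ produces a per-feature scale and shift applied after a parameter-free layer norm, $\mathrm{adaLN}(h \mid c) = \mathrm{LN}(h)\,(1 + \gamma(c)) + \beta(c)$, with $\gamma$ and $\beta$ linear in $c$ and $c$ a sinusoidal embedding of the two scalars. The embedding, the weight shapes, and all function names are illustrative assumptions rather than PRISM's actual architecture.

```python
import numpy as np


def layer_norm(h: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Parameter-free layer norm over the feature (last) axis."""
    mu = h.mean(axis=-1, keepdims=True)
    var = h.var(axis=-1, keepdims=True)
    return (h - mu) / np.sqrt(var + eps)


def embed_condition(lam: np.ndarray, t: np.ndarray, num_freqs: int = 8) -> np.ndarray:
    """Hypothetical sinusoidal embedding of the scalars (lambda, t); returns (B, 4 * num_freqs)."""
    freqs = 2.0 ** np.arange(num_freqs)                    # (F,)
    scalars = np.stack([lam, t], axis=-1)                  # (B, 2)
    angles = scalars[..., None] * freqs * np.pi            # (B, 2, F)
    feats = np.concatenate([np.sin(angles), np.cos(angles)], axis=-1)  # (B, 2, 2F)
    return feats.reshape(feats.shape[0], -1)


def ada_layer_norm(h, cond, w_scale, b_scale, w_shift, b_shift):
    """Modulate normalized tokens h (B, N, D) with scale/shift predicted from cond (B, C)."""
    gamma = cond @ w_scale + b_scale   # (B, D) per-example scale offset
    beta = cond @ w_shift + b_shift    # (B, D) per-example shift
    return layer_norm(h) * (1.0 + gamma[:, None, :]) + beta[:, None, :]


rng = np.random.default_rng(0)
B, N, D = 4, 10, 16
h = rng.normal(size=(B, N, D))            # token features
lam = np.array([0.1, 0.3, 0.5, 0.9])      # complexity setting per example
t = rng.random(B)                          # diffusion time per example
cond = embed_condition(lam, t)
C = cond.shape[1]
w_scale = 0.02 * rng.normal(size=(C, D))
w_shift = 0.02 * rng.normal(size=(C, D))
out = ada_layer_norm(h, cond, w_scale, np.zeros(D), w_shift, np.zeros(D))
print(out.shape)  # (4, 10, 16)
```

Applying such a modulation in every block gives each layer direct access to $\lambda$ and $t$; the exact parameterization PRISM uses is not given in this overview.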