Probabilistic Programming with Programmable Variational Inference

McCoy R. Becker, Alexander K. Lew, Xiaoyan Wang, Matin Ghavami, Mathieu Huot, Martin C. Rinard, Vikash K. Mansinghka

TL;DR

This paper introduces a modular framework for variational inference (VI) in probabilistic programming: models, variational families, and objectives are all written as programs, and compositional program transformations turn them into unbiased gradient estimators. It provides a formal denotational and logical-relations foundation for correctness, and implements genjax.vi, which performs comparably to established PPLs while supporting a broader, programmable space of VI objectives and gradient estimators. By isolating tracing, density evaluation, and differentiation as separate concerns, the approach automates gradient estimators that existing PPLs do not, reduces duplicated engineering effort, and makes gradient correctness easier to reason about. The result is a scalable, extensible system that accommodates advanced VI objectives (e.g., IWELBO, IWHVI) and expressive variational families, with empirical evidence of competitive speed and improved expressivity on deep generative modeling tasks.
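For reference, the two simplest objectives named above have standard definitions. In the block below, the notation is an assumption for illustration: $p(x, z)$ is the model's joint density over observed data $x$ and latent variables $z$, and $q_\phi$ is the variational family with parameters $\phi$.

```latex
\mathrm{ELBO}(\phi) = \mathbb{E}_{z \sim q_\phi}\big[\log p(x, z) - \log q_\phi(z)\big] \le \log p(x)

\mathrm{IWELBO}_n(\phi) = \mathbb{E}_{z_1, \dots, z_n \overset{\text{iid}}{\sim} q_\phi}\Big[\log \frac{1}{n} \sum_{i=1}^{n} \frac{p(x, z_i)}{q_\phi(z_i)}\Big]
```

Setting $n = 1$ recovers the ELBO, and Burda et al. (2016) show that $\mathrm{IWELBO}_n$ is non-decreasing in $n$ and bounded above by $\log p(x)$.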

Abstract

Compared to the wide array of advanced Monte Carlo methods supported by modern probabilistic programming languages (PPLs), PPL support for variational inference (VI) is less developed: users are typically limited to a predefined selection of variational objectives and gradient estimators, which are implemented monolithically (and without formal correctness arguments) in PPL backends. In this paper, we propose a more modular approach to supporting variational inference in PPLs, based on compositional program transformation. In our approach, variational objectives are expressed as programs that may employ first-class constructs for computing densities of and expected values under user-defined models and variational families. We then transform these programs systematically into unbiased gradient estimators for optimizing the objectives they define. Our design enables modular reasoning about many interacting concerns, including automatic differentiation, density accumulation, tracing, and the application of unbiased gradient estimation strategies. Additionally, relative to existing support for VI in PPLs, our design increases expressiveness along three axes: (1) it supports an open-ended set of user-defined variational objectives, rather than a fixed menu of options; (2) it supports a combinatorial space of gradient estimation strategies, many not automated by today's PPLs; and (3) it supports a broader class of models and variational families, because it supports constructs for approximate marginalization and normalization (previously introduced only for Monte Carlo inference). We implement our approach in an extension to the Gen probabilistic programming system (genjax.vi, implemented in JAX), and evaluate on several deep generative modeling tasks, showing minimal performance overhead vs. hand-coded implementations and performance competitive with well-established open-source PPLs.
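As a rough, hand-written analogue of what the transformation automates, the sketch below writes an ELBO objective as an ordinary program and differentiates it with the reparameterization trick to get an unbiased gradient estimate. It is plain JAX, not the genjax.vi API; the toy Gaussian model and the names `log_joint`, `log_q`, `elbo_estimate`, and `train` are hypothetical.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm

# Hypothetical toy model (an assumption, not the paper's example):
#   z ~ Normal(0, 1),  x | z ~ Normal(z, 1),  with x observed.
def log_joint(z, x):
    return norm.logpdf(z, 0.0, 1.0) + norm.logpdf(x, z, 1.0)

# Hypothetical variational family: q(z; mu, log_sigma) = Normal(mu, exp(log_sigma)).
def log_q(z, params):
    mu, log_sigma = params
    return norm.logpdf(z, mu, jnp.exp(log_sigma))

def elbo_estimate(params, key, x, num_samples=64):
    """Monte Carlo ELBO estimate using the reparameterization trick."""
    mu, log_sigma = params
    eps = jax.random.normal(key, (num_samples,))
    z = mu + jnp.exp(log_sigma) * eps  # z is a differentiable function of params
    return jnp.mean(log_joint(z, x) - log_q(z, params))

# Differentiating the reparameterized estimate yields an unbiased ELBO gradient estimate.
elbo_grad = jax.jit(jax.grad(elbo_estimate))

def train(x, steps=2000, lr=0.01, seed=0):
    params = (jnp.array(0.0), jnp.array(0.0))
    key = jax.random.PRNGKey(seed)
    for _ in range(steps):
        key, sub = jax.random.split(key)
        g = elbo_grad(params, sub, x)
        # Stochastic gradient *ascent* on the ELBO.
        params = tuple(p + lr * gp for p, gp in zip(params, g))
    return params

mu, log_sigma = train(x=2.0)
```

Because this toy model is conjugate, the exact posterior is Normal(x/2, variance 1/2), so the trained `mu` and `exp(log_sigma)` can be checked against roughly 1.0 and 0.707. In the paper's design, the density and sampling code written by hand here would instead be compiled automatically from generative programs.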

Paper Structure (57 sections, 7 theorems, 16 equations, 20 figures, 3 tables)

Key Result

lemma 1

Let $\Gamma \vdash t : \tau$ be an open term of $\lambda_{\text{Gen}}$. Then $\xi\{\Gamma\} \vdash \xi\{t\} : \xi\{\tau\}$ is a well-typed open term of $\lambda_{\text{ADEV}}$, and $\forall (\gamma, \gamma') \in R^\xi_\Gamma, (\llbracket t\rrbracket(\gamma), \llbracket \xi\{t\}\rrbracket(\gamma' \ldots$

Figures (20)

  • Figure 1: We compose multiple program transformations to automate the construction of unbiased gradient estimators for variational inference. The user begins by writing programs in the generative language (gray) to encode a model and variational family. These programs are compiled into procedures for density evaluation and simulation in the differentiable language (yellow). These automated procedures can be used to concisely define a variational objective. We can then apply the ADEV differentiation algorithm to automatically construct a gradient estimator that produces unbiased estimates of the variational objective's gradient. Solid outlines indicate user-written programs, whereas dashed outlines indicate automatically constructed programs.
  • Figure 2: An illustration of our modular approach to automating variational inference, on a toy example. (Top) Users write generative code to define a model and a variational family. Automated program transformations, formalized and proven correct in the section on generative-language transformations, compile differentiable code for evaluating densities and simulating traces. (Middle) Users write a program in the differentiable language to define a variational objective, in this case the evidence lower bound (ELBO). This code may invoke compiled simulators and density evaluators for generative programs. The ADEV transformation automates an unbiased gradient estimator for the objective. (Bottom) The gradients are used for optimization, training the variational family to approximate the posterior.
  • Figure 2: Time (in seconds) to train the Attend-Infer-Repeat (AIR) model (Eslami et al., 2016) for one epoch (batch size 64) with different objectives and estimators. All discrete variables use the same estimation strategy. IWELBO runs use $n = 2$ particles.
  • Figure 3: With programmable VI, users can define their own variational objectives and use new modeling-language features to write more expressive models and variational families. Here, we apply importance-weighted VI (Burda et al., 2016) and hierarchical VI (Ranganath et al., 2016; Sobolev and Vetrov, 2019) to the toy problem from Fig. 2.
  • Figure 4: Grammars, selected typing rules, and selected denotations for our core languages $\lambda_{\text{Gen}}$ and $\lambda_{\text{ADEV}}$.
  • ...and 15 more figures
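The AIR timing caption in the list above notes that all discrete variables in a run share one estimation strategy. As a generic illustration of what a "strategy" can mean for a discrete latent, the sketch below compares an exact gradient obtained by enumerating both outcomes of a Bernoulli variable with a score-function (REINFORCE) estimate. It is plain JAX, not the genjax.vi API; the objective `f`, the Bernoulli family, and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical objective: E_{b ~ Bernoulli(sigmoid(theta))}[f(b)].
def f(b):
    return (b - 0.45) ** 2

def log_q(b, theta):
    # log Bernoulli(b; sigmoid(theta)), written with log_sigmoid for stability.
    return b * jax.nn.log_sigmoid(theta) + (1.0 - b) * jax.nn.log_sigmoid(-theta)

def enum_grad(theta):
    """Exact gradient: sum over both values of b, then differentiate."""
    def objective(t):
        p = jax.nn.sigmoid(t)
        return p * f(1.0) + (1.0 - p) * f(0.0)
    return jax.grad(objective)(theta)

def reinforce_grad(theta, key, num_samples=100_000):
    """Score-function estimate of E[f(b) * d/dtheta log q(b; theta)]."""
    p = jax.nn.sigmoid(theta)
    b = jax.random.bernoulli(key, p, (num_samples,)).astype(jnp.float32)
    def surrogate(t):
        # stop_gradient: only the log-density term is differentiated.
        return jnp.mean(jax.lax.stop_gradient(f(b)) * log_q(b, t))
    return jax.grad(surrogate)(theta)

theta = jnp.array(0.3)
exact = enum_grad(theta)  # equals sigmoid'(theta) * (f(1) - f(0))
estimate = reinforce_grad(theta, jax.random.PRNGKey(0))
```

Enumeration is exact, but its cost grows with the number of values each discrete variable can take; REINFORCE needs only samples and log-density gradients, at the cost of higher variance.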

Theorems & Definitions (8)

  • lemma 1
  • theorem 1
  • lemma 2
  • theorem 2
  • definition 1: locally dominated
  • theorem 3
  • theorem 4
  • lemma 3