Composable Effects for Flexible and Accelerated Probabilistic Programming in NumPyro

Du Phan, Neeraj Pradhan, Martin Jankowiak

TL;DR

The paper tackles the need for fast, flexible probabilistic programming by uniting Pyro-like modeling with JAX-based acceleration. It introduces NumPyro, a NumPy-based PPL that uses effect handlers to compose with JAX transformations, enabling end-to-end JIT compilation and vectorization. A key contribution is an iterative NUTS implementation that is fully JIT-compiled and memory-efficient, along with vmap-enabled batching for inference tasks. Empirical results show substantial speedups over Pyro and Stan across small and large datasets, underscoring the practical impact for scalable probabilistic inference.

Abstract

NumPyro is a lightweight library that provides an alternate NumPy backend to the Pyro probabilistic programming language with the same modeling interface, language primitives and effect handling abstractions. Effect handlers allow Pyro's modeling API to be extended to NumPyro despite its being built atop a fundamentally different JAX-based functional backend. In this work, we demonstrate the power of composing Pyro's effect handlers with the program transformations that enable hardware acceleration, automatic differentiation, and vectorization in JAX. In particular, NumPyro provides an iterative formulation of the No-U-Turn Sampler (NUTS) that can be end-to-end JIT compiled, yielding an implementation that is much faster than existing alternatives in both the small and large dataset regimes.
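To make the composition of JAX transformations mentioned above concrete, here is a minimal sketch in plain JAX. It is not NumPyro's API or the authors' implementation. The function `log_joint`, the standard normal prior, and the toy data are all assumptions chosen for illustration. The sketch only shows how `jax.grad`, `jax.jit`, and `jax.vmap` can be stacked on one log-density function for a small logistic regression.

```python
import jax
import jax.numpy as jnp


def log_joint(coefs, features, labels):
    """Unnormalized log density of a toy logistic regression (hypothetical model).

    Assumes a standard normal prior on `coefs` and a Bernoulli likelihood
    parameterized by logits; additive constants are dropped.
    """
    log_prior = -0.5 * jnp.sum(coefs ** 2)
    logits = features @ coefs
    # Bernoulli log-likelihood in logit form: y * logit - log(1 + exp(logit))
    log_lik = jnp.sum(labels * logits - jnp.logaddexp(0.0, logits))
    return log_prior + log_lik


# Automatic differentiation composed with JIT compilation:
# gradient of the log density with respect to `coefs`.
grad_fn = jax.jit(jax.grad(log_joint))

# Vectorization composed with JIT compilation: evaluate the log density
# for a batch of coefficient vectors (for example, one per chain) while
# sharing the same data across the batch.
batched_log_joint = jax.jit(jax.vmap(log_joint, in_axes=(0, None, None)))

# Toy data, made up for this sketch.
features = jnp.array([[1.0, 0.5], [0.2, -1.0], [-0.3, 0.8]])
labels = jnp.array([1.0, 0.0, 1.0])
coefs_batch = jnp.zeros((4, 2))  # four hypothetical chains, all starting at zero

per_chain_log_density = batched_log_joint(coefs_batch, features, labels)
gradient_at_zero = grad_fn(coefs_batch[0], features, labels)
```

Because every transformation returns an ordinary function, they can be nested in any order. That property is what lets a gradient-based sampler be compiled as a whole rather than step by step.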

Paper Structure

This paper contains 20 sections, 4 figures, 1 table, 2 algorithms.

Figures (4)

  • Figure 1: A simple logistic regression model. The modeling language is the same as in Pyro.
  • Figure 2: Empirical evaluation of time taken by NumPyro's Iterative NUTS with respect to other frameworks.
  • Figure 3: A graphical representation of how binary trees are constructed in IterativeBuildTree. The orange node is the leaf generated at the current step. Blue nodes are the leaves stored in memory for the purpose of checking the U-Turn condition. White nodes are past leaves that have been removed from memory. Dashed white nodes have not been generated yet. Thick black lines link the left and right leaves of subtrees where we need to check the U-Turn condition.
  • Figure 4: BuildTree