Divide-and-Conquer Predictive Coding: a structured Bayesian inference algorithm

Eli Sennesh, Hao Wu, Tommaso Salvatori

Abstract

Unexpected stimuli induce "error" or "surprise" signals in the brain. The theory of predictive coding promises to explain these observations in terms of Bayesian inference by suggesting that the cortex implements variational inference in a probabilistic graphical model. However, when applied to machine learning tasks, this family of algorithms has yet to perform on par with other variational approaches in high-dimensional, structured inference problems. To address this, we introduce a novel predictive coding algorithm for structured generative models, which we call divide-and-conquer predictive coding (DCPC). DCPC differs from other formulations of predictive coding in that it respects the correlation structure of the generative model and provably performs maximum-likelihood updates of model parameters, all without sacrificing biological plausibility. Empirically, DCPC achieves better numerical performance than competing algorithms and provides accurate inference in a number of problems not previously addressed with predictive coding. We provide an open implementation of DCPC in Pyro on GitHub.

Paper Structure

This paper contains 18 sections, 8 theorems, 58 equations, 5 figures, 4 tables, 1 algorithm.

Key Result

Theorem 1

Each DCPC coordinate update (Equation eq:coord_weight) for a latent $z \in \mathbf{z}$ samples from $z$'s complete conditional (the normalization of Equation eq:complete_conditional). Formally, for every measurable $h: \mathcal{Z} \rightarrow \mathbb{R}$, resampled expectations with respect to the D
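
As a rough illustration of what "sampling from a complete conditional" by weighting and resampling particles can look like, the sketch below applies generic sampling-importance-resampling to a toy model with $z \sim \mathcal{N}(0, 1)$ and $x \mid z \sim \mathcal{N}(z, 1)$, whose normalized complete conditional for $z$ is $\mathcal{N}(x/2, 1/2)$. The toy model, the broad Gaussian proposal, and all function names are assumptions made for illustration; this is not the coordinate update of Equation eq:coord_weight, nor the authors' Pyro implementation.

```python
import numpy as np


def log_complete_conditional(z, x):
    """Unnormalized log density of z given x in the assumed toy model:
    z ~ N(0, 1), x | z ~ N(z, 1). Additive constants are dropped."""
    return -0.5 * z**2 - 0.5 * (x - z) ** 2


def resample_coordinate(x, num_particles, rng, proposal_scale=2.0):
    """Generic sampling-importance-resampling for one latent coordinate.

    Particles are proposed from a broad Gaussian N(0, proposal_scale^2)
    (an illustrative choice), weighted by target / proposal, and then
    resampled in proportion to the self-normalized weights.
    """
    z = rng.normal(0.0, proposal_scale, size=num_particles)
    # Log proposal density up to a constant; constants cancel on normalizing.
    log_q = -0.5 * (z / proposal_scale) ** 2
    log_w = log_complete_conditional(z, x) - log_q
    # Subtract the max before exponentiating for numerical stability.
    w = np.exp(log_w - log_w.max())
    w /= w.sum()
    idx = rng.choice(num_particles, size=num_particles, p=w)
    return z[idx]


rng = np.random.default_rng(0)
x = 1.0
particles = resample_coordinate(x, 50_000, rng)
posterior_mean = particles.mean()  # should be close to x / 2
posterior_var = particles.var()    # should be close to 1 / 2
```

With enough particles, averages of any test function over the resampled particles approach its expectation under the normalized complete conditional, which is the kind of statement the theorem makes about resampled expectations.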

Figures (5)

  • Figure 1: Left: Classical PC learns a mean-field approximate posterior with prediction error layers. Right: Divide-and-conquer PC approximates the joint posterior with bottom-up and recurrent errors. Where classical predictive coding has layers communicate through shared error units, divide-and-conquer predictive coding separates recurrent from "bottom-up" error pathways to target complete conditional distributions rather than posterior marginal distributions.
  • Figure 2: Hierarchical graphical model for deep latent Gaussian models (DLGMs).
  • Figure 3: Top: images from validation sets of MNIST (left), EMNIST (middle), and Fashion MNIST (right). Bottom: reconstructions by deep latent Gaussian models trained with DCPC for MNIST (left), EMNIST (middle), and Fashion MNIST (right), averaging over $K=4$ particles. DCPC achieves quality reconstructions by inference over $\mathbf{z}$ without training an inference network.
  • Figure 4: Left: reconstructions from the CelebA validation set. Right: samples from the generative model. DCPC achieves quality reconstructions by inference over $\mathbf{z}$ with $K=16$ particles and no inference network, while the learned generative model captures variation in the data.
  • Figure 5: Divide-and-conquer predictive coding provides an algorithmic interpretation for some of the connections mapped in the canonical neocortical microcircuit (Bastos et al., 2012; Bastos et al., 2020; Campagnola et al., 2022): prediction errors (red) arrive through ascending pathways into the central laminar layer 4, which transmits them up to layers 2/3 (green). These layers combine the incoming errors with a present posterior estimate (green L5$\rightarrow$L2/3 connection) to generate prediction errors for the next cortical area. Eventually the updated predictions flow back down the cortical hierarchy (blue).

Theorems & Definitions (20)

  • Definition 1: Predictive Coding Algorithm
  • Theorem 1: DCPC coordinate updates sample from the true complete conditionals
  • proof
  • Theorem 2: DCPC parameter learning requires only local gradients in a factorized generative model
  • proof
  • Definition 2: Bayesian Fisher estimator (Titsias, 2023)
  • Definition 3: Predictive coding Fisher preconditioner (Titsias, 2023)
  • Definition 4: Strict proper weighting for a density
  • Proposition 1: The free energy upper-bounds the surprisal
  • proof
  • ...and 10 more