
Brain-like Variational Inference

Hadi Vafaii, Dekel Galor, Jacob L. Yates

TL;DR

This work presents FOND, a principled framework that derives brain-like inference dynamics by performing online natural-gradient descent on variational free energy, unifying neural and machine learning perspectives on inference. It applies FOND to derive iterative VAEs, including the iP-VAE, a spiking model that uses Poisson latents and membrane-potential dynamics to perform online inference with lateral competition. Empirically, iterative VAEs demonstrate stronger reconstruction-sparsity trade-offs, learn cortex-like features, and generalize robustly to out-of-distribution data, while remaining scalable to high-dimensional color images. The results highlight both theoretical coherence with the free energy principle and practical advantages in efficiency and generalization, with promising hardware implications for neuromorphic deployment.
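To make "variational free energy with Poisson latents" concrete, here is a minimal sketch of $\mathcal{F} = -\text{ELBO}$ for a linear decoder. It assumes the following, none of which is taken from the paper's code: spike counts $z_k \sim \text{Poisson}(r_k)$ with $r = \exp(u)$, where $u$ plays the role of a membrane potential; a unit-variance Gaussian likelihood $x \sim \mathcal{N}(\Phi z, I)$; and a Poisson prior with rates $\exp(u_0)$. The helper names `poisson_kl` and `free_energy` are hypothetical. Because $\mathrm{Var}[z_k] = r_k$, the expected reconstruction error has a closed form. The usage lines check that closed form against Monte Carlo sampling of integer spike counts.

```python
import numpy as np

def poisson_kl(rate, prior_rate):
    """Elementwise KL( Poisson(rate) || Poisson(prior_rate) )."""
    return rate * np.log(rate / prior_rate) - rate + prior_rate

def free_energy(u, x, Phi, u0, beta=1.0):
    """F = -ELBO for a linear decoder with Poisson latents (illustrative sketch).

    Assumptions (hypothetical, not the paper's exact objective):
      * z_k ~ Poisson(r_k), r = exp(u)  (u acts as a membrane potential),
      * x ~ N(Phi @ z, I), additive constants dropped,
      * prior z_k ~ Poisson(exp(u0_k)), KL weighted by beta.
    """
    r = np.exp(u)
    # E||x - Phi z||^2 = ||x - Phi r||^2 + sum_k r_k ||Phi[:, k]||^2, since Var[z_k] = r_k
    expected_sq_err = np.sum((x - Phi @ r) ** 2) + np.sum(r * np.sum(Phi**2, axis=0))
    return 0.5 * expected_sq_err + beta * np.sum(poisson_kl(r, np.exp(u0)))

# Usage: compare the closed-form reconstruction term with a Monte Carlo estimate.
rng = np.random.default_rng(0)
D, K = 8, 4
Phi = rng.normal(size=(D, K))
x = rng.normal(size=D)
u = 0.3 * rng.normal(size=K)
u0 = np.zeros(K)

closed_form = free_energy(u, x, Phi, u0, beta=0.0)  # reconstruction term only
z = rng.poisson(np.exp(u), size=(200_000, K))       # integer spike counts
monte_carlo = 0.5 * np.mean(np.sum((x - z @ Phi.T) ** 2, axis=1))
print(closed_form, monte_carlo)  # the two values should agree closely
```

Integrating out the Poisson noise analytically keeps $\mathcal{F}$ a smooth function of $u$, which is what makes the gradient-based dynamics below well defined.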

Abstract

Inference in both brains and machines can be formalized by optimizing a shared objective: maximizing the evidence lower bound (ELBO) in machine learning, or minimizing variational free energy (F) in neuroscience (ELBO = -F). While this equivalence suggests a unifying framework, it leaves open how inference is implemented in neural systems. Here, we introduce FOND (Free energy Online Natural-gradient Dynamics), a framework that derives neural inference dynamics from three principles: (1) natural gradients on F, (2) online belief updating, and (3) iterative refinement. We apply FOND to derive iP-VAE (iterative Poisson variational autoencoder), a recurrent spiking neural network that performs variational inference through membrane potential dynamics, replacing amortized encoders with iterative inference updates. Theoretically, iP-VAE yields several desirable features, such as emergent normalization via lateral competition and hardware-efficient integer spike-count representations. Empirically, iP-VAE outperforms both standard VAEs and Gaussian-based predictive coding models in sparsity, reconstruction, and biological plausibility, and scales to complex color image datasets such as CelebA. iP-VAE also exhibits strong generalization to out-of-distribution inputs, exceeding hybrid iterative-amortized VAEs. These results demonstrate how deriving inference algorithms from first principles can yield concrete architectures that are simultaneously biologically plausible and empirically effective.
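The three FOND principles (natural gradients on F, online updating, iterative refinement) can be illustrated with a deterministic, rate-based simplification. This is a hypothetical sketch, not the paper's discrete-time iP-VAE update, which uses sampled spikes. It assumes a linear decoder $\Phi$, $r = \exp(u)$, $x \sim \mathcal{N}(\Phi z, I)$ with $z_k \sim \text{Poisson}(r_k)$, and a Poisson prior with rates $\exp(u_0)$. Under these assumptions, the Fisher information of a Poisson latent in its log-rate $u$ is $\mathrm{diag}(r)$, so the natural gradient of F with respect to $u$ reduces to $\partial F / \partial r$:

$u \leftarrow u + \eta \left[ \Phi^\top x - \Phi^\top \Phi\, r - \tfrac{1}{2}\,\mathrm{diag}(\Phi^\top \Phi) - \beta (u - u_0) \right]$

The $-\Phi^\top \Phi\, r$ term couples neurons with overlapping features (lateral competition), and $\Phi^\top x$ re-enters at every step (input injection). The function name `infer` and its parameters are illustrative.

```python
import numpy as np

def infer(x, Phi, u0, beta=1.0, lr=0.05, n_steps=300):
    """Rate-based natural-gradient descent on F over membrane potentials u.

    Hypothetical sketch assuming r = exp(u), x ~ N(Phi @ z, I), z_k ~ Poisson(r_k),
    and prior Poisson(exp(u0)). With Fisher information diag(r) in the log-rate
    parameterization, the natural gradient w.r.t. u equals dF/dr.
    """
    drive = Phi.T @ x                   # feedforward drive, re-injected every step
    lateral = Phi.T @ Phi               # lateral interactions built from decoder weights only
    self_cost = 0.5 * np.diag(lateral)  # from the Poisson variance term in E||x - Phi z||^2
    u = u0.copy()
    grad_norms = []
    for _ in range(n_steps):
        r = np.exp(u)
        # dF/dr: reconstruction error (input minus lateral competition), variance cost, KL leak toward u0
        nat_grad = -(drive - lateral @ r) + self_cost + beta * (u - u0)
        grad_norms.append(float(np.linalg.norm(nat_grad)))
        u = u - lr * nat_grad           # identical update (shared weights) at every step
    return u, grad_norms

# Usage on a toy problem with a random dictionary.
rng = np.random.default_rng(1)
D, K = 16, 8
Phi = rng.normal(size=(D, K)) / 4.0
x = Phi @ rng.uniform(0.5, 2.0, size=K) + 0.05 * rng.normal(size=D)
u, grad_norms = infer(x, Phi, u0=np.zeros(K))
spikes = rng.poisson(np.exp(u))         # integer spike counts read out from the final rates
```

Because only $\Phi$ (plus the prior $u_0$) appears in the loop, unrolling it in time gives a weight-shared residual stack, which matches the decoder-only parameterization described for the iP-VAE encoder.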

Paper Structure

This paper contains 143 sections, 63 equations, 18 figures, 5 tables, and 1 algorithm.

Figures (18)

  • Figure 1: Inferential and dynamical accounts of perception are unified under variational inference. (a) Perception is framed as a dynamical process of convergence to attractors in a neural state space, where membrane potentials evolve and generate spikes along the way. (b) Our prescriptive approach derives neural dynamics by minimizing free energy via natural gradient descent, yielding a spiking network with lateral competition. The resulting architectures are principled and empirically effective. Code, data, and model checkpoints are available here: https://github.com/hadivafaii/IterativeVAE
  • Figure 2: A wide range of models across machine learning and theoretical neuroscience can be unified under free energy ($\mathcal{F}$) minimization. Different distributional and optimization choices result in different models (see the appendix on unifying machine learning and neuroscience models). Motivated by this unification potential, we introduce FOND, a framework for deriving brain-like inference algorithms from first principles (see the FOND section). We apply FOND to derive a family of iterative VAE architectures, including the spiking iP-VAE (see the theory section).
  • Figure 3: All iterative VAEs continue to converge when run for more inference steps than were used during training ($T_\text{train} = 16$). iP-VAE outperforms iG$_{\text{relu}}$-VAE in sparsity, while iG-VAE achieves superior reconstruction but with dense representations. iP-VAE maintains reasonable reconstruction performance despite using constrained representations (sparse, integer-valued spike counts). The sparsification dynamics of both iP-VAE and iG$_{\text{relu}}$-VAE resemble those observed in mouse visual cortex [moosavi2024temporal] (see the appendix on cortex-like dynamics). All models are trained on $16 \times 16$ natural image patches, and the traces are averages over the entire test set. The right panel displays representative dictionary elements ($\Phi$); see the appendix figure showing all dictionary elements for the complete set of $K=512$ features.
  • Figure 4: Reconstruction-sparsity trade-off across model families. (a) Performance landscape showing reconstruction quality ($R^2$) versus sparsity (proportion of zeros). Symbols indicate model variants (triangles: $T_\text{train}=8$; crosses: $T_\text{train}=16$; circles: $T_\text{train}=32$; empty squares: amortized VAEs; plus signs: LCA), and colors denote model architectures. The gold star marks the theoretical optimum. (b) Overall performance, measured as Euclidean distance from the optimum. Iterative models (right side) consistently outperform their amortized counterparts (left side), with iP-VAE and LCA achieving the best overall performance. See also the related rate-distortion figures.
  • Figure 5: Comparison between standard deep residual neural networks (ResNet [he2016resnet]) and our inference algorithm. (a) Standard ResNets process information through a series of deterministic layers, each with independent parameters $\phi_t$, where $t$ is the layer index. (b) The iP-VAE encoding algorithm (the discrete-time update equation), unrolled in time, resembles a ResNet with shared parameters but incorporates adaptive, stochastic, and spiking dynamics. This recurrent architecture is conceptually similar to Looped Transformers [giannou23looped; yang2024looped; fan2025looped], which use "input injection" (conditioning on the input $\bm{x}$ at each step) as an effective heuristic to improve stability and performance. In our framework, this mechanism is not a heuristic choice but emerges directly from the first principles of online variational inference. Finally, unlike conventional networks [jastrzebski2018ResIterative], the iP-VAE uses only decoder parameters $\theta$, enabling online Bayesian inference through iterative sampling and recurrent updates conditioned on both the current state $\bm{u}_t$ and the input $\bm{x}$. This figure illustrates the linear decoder implementation used in the main paper; for the general nonlinear decoder case, see the full unrolled-encoder figure.
  • ...and 13 more figures
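The summary metrics in the Figure 3 and Figure 4 captions (reconstruction $R^2$, sparsity as the proportion of zeros, and Euclidean distance from an optimum) can be sketched as follows. Two choices here are assumptions, not details confirmed by the captions: $R^2$ is pooled over all samples and pixels using the global mean, and the optimum (gold star) is placed at sparsity $= 1$, $R^2 = 1$. The function names are hypothetical.

```python
import numpy as np

def r_squared(x, x_hat):
    """Fraction of variance explained, pooled over all samples and pixels (assumed definition)."""
    ss_res = np.sum((x - x_hat) ** 2)
    ss_tot = np.sum((x - x.mean()) ** 2)
    return float(1.0 - ss_res / ss_tot)

def sparsity(z):
    """Proportion of latent entries that are exactly zero (e.g., zero spike counts)."""
    return float(np.mean(z == 0))

def distance_to_optimum(r2, sparse, optimum=(1.0, 1.0)):
    """Euclidean distance from (sparsity, R^2) to an assumed optimum; lower is better."""
    return float(np.hypot(optimum[0] - sparse, optimum[1] - r2))

# Usage with made-up numbers (not results from the paper):
x = np.array([[1.0, 2.0], [3.0, 4.0]])
spikes = np.array([0, 0, 1, 3])
score = distance_to_optimum(r_squared(x, x), sparsity(spikes))  # perfect R^2, half the units silent
```

A single distance collapses the two-dimensional trade-off into one number, so a model can rank well either by reconstructing accurately or by being very sparse; panel (a) of Figure 4 keeps both axes visible for that reason.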