JPC: Flexible Inference for Predictive Coding Networks in JAX

Francesco Innocenti, Paul Kinghorn, Will Yun-Farmbrough, Miguel De Llanza Varona, Ryan Singh, Christopher L. Buckley

TL;DR

JPC addresses the lack of open-source tooling for predictive coding networks (PCNs) with a compact, JAX-based library that uses ODE solvers to integrate the gradient-flow inference dynamics. It supports discriminative, generative, and hybrid PCN training through a simple API with advanced customization options. Experiments show that a second-order solver (Heun) reduces runtime relative to Euler while maintaining comparable accuracy across multiple depths and datasets. The library also provides theoretical tools for diagnosing whether more inference would help a PCN, including a closed-form equilibrium energy for deep linear PCNs. Overall, JPC lowers barriers to PCN research and provides a runnable reference for solver-based PC inference.

Abstract

We introduce JPC, a JAX library for training neural networks with Predictive Coding (PC). JPC provides a simple, fast and flexible interface to train a variety of PC networks (PCNs), including discriminative, generative and hybrid models. Unlike existing libraries, JPC leverages ordinary differential equation solvers to integrate the gradient flow inference dynamics of PCNs. We find that a second-order solver achieves significantly faster runtimes compared to standard Euler integration, with comparable performance on a range of tasks and network depths. JPC also provides some theoretical tools that can be used to study PCNs. We hope that JPC will facilitate future research on PC. The code is available at https://github.com/thebuckleylab/jpc.
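To make the solver comparison concrete, the following is a minimal sketch of gradient-flow inference on a PC energy, integrated with hand-written Euler and Heun steps in plain JAX. It is not JPC's API: the function names (`energy`, `vector_field`, `euler_step`, `heun_step`, `infer`), the layer sizes, the feedforward initialisation and the choice of energy (a sum of squared layer-wise prediction errors with Tanh activations) are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def energy(hidden, weights, x, y):
    """Assumed PC energy: half the summed squared prediction errors per layer."""
    acts = [x, *hidden, y]  # input and output activities are clamped
    total = 0.0
    for W, pre, post in zip(weights, acts[:-1], acts[1:]):
        err = post - W @ jnp.tanh(pre)
        total = total + 0.5 * jnp.sum(err ** 2)
    return total


def vector_field(hidden, weights, x, y):
    """Gradient flow on the hidden activities: dz/dt = -dF/dz."""
    grads = jax.grad(energy)(hidden, weights, x, y)
    return [-g for g in grads]


def euler_step(hidden, dt, *args):
    """First-order step: one vector-field evaluation."""
    k1 = vector_field(hidden, *args)
    return [z + dt * k for z, k in zip(hidden, k1)]


def heun_step(hidden, dt, *args):
    """Second-order step: Euler predictor, then average the two slopes."""
    k1 = vector_field(hidden, *args)
    predicted = [z + dt * k for z, k in zip(hidden, k1)]
    k2 = vector_field(predicted, *args)
    return [z + 0.5 * dt * (a + b) for z, a, b in zip(hidden, k1, k2)]


def infer(step, hidden, dt, n_steps, *args):
    """Integrate the inference dynamics for n_steps fixed-size steps."""
    for _ in range(n_steps):
        hidden = step(hidden, dt, *args)
    return hidden


# Hypothetical toy problem: 4 inputs, two hidden layers of 8 units, 3 outputs.
sizes = [4, 8, 8, 3]
keys = jax.random.split(jax.random.PRNGKey(0), len(sizes) + 1)
weights = [
    jax.random.normal(k, (n_out, n_in)) / jnp.sqrt(n_in)
    for k, n_in, n_out in zip(keys, sizes[:-1], sizes[1:])
]
x = jax.random.normal(keys[-2], (sizes[0],))
y = jax.random.normal(keys[-1], (sizes[-1],))

# Initialise the hidden activities with a feedforward pass.
hidden0, z = [], x
for W in weights[:-1]:
    z = W @ jnp.tanh(z)
    hidden0.append(z)

e_init = float(energy(hidden0, weights, x, y))
euler_hidden = infer(euler_step, hidden0, 0.05, 50, weights, x, y)
heun_hidden = infer(heun_step, hidden0, 0.05, 50, weights, x, y)
e_euler = float(energy(euler_hidden, weights, x, y))
e_heun = float(energy(heun_hidden, weights, x, y))
print(e_init, e_euler, e_heun)
```

Heun evaluates the vector field twice per step, so any runtime advantage has to come from reaching comparable accuracy with a cheaper overall integration. This sketch only shows the two update rules side by side and makes no claim about their relative speed.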

Paper Structure

This paper contains 10 sections, 4 equations, 7 figures.

Figures (7)

  • Figure 1: A second-order Runge–Kutta method (Heun) solves PC inference faster than standard Euler on a range of datasets and networks. We plot the wall-clock time of Euler and Heun at each training step of one epoch for networks with $H \in \{3, 5, 10\}$ hidden layers trained on standard image classification datasets. The runs with the highest mean test accuracy across hyperparameters were selected (see Figures 3–6). The first training iteration, in which "just-in-time" (jit) compilation occurs, is excluded from the timings. All networks had 300 hidden units per layer and Tanh activations, and were trained with learning rate $10^{-3}$ and batch size $64$. Shaded regions indicate $\pm 1$ standard deviation across 3 random weight initialisations.
  • Figure 2: The theoretical PC energy for deep linear networks (Eq. 4) can help predict whether more inference could lead to better performance. We compare the theoretical energy with the numerical energy for different upper limits $t$ of inference integration, as well as test accuracies, for a 10-hidden-layer, 300-width linear network trained to classify MNIST with learning rate $10^{-3}$ and batch size $64$. Results were consistent across different random initialisations.
  • Figure 3: Test accuracies for Figure 1. These accuracies were selected from Figures 4–6 based on the lowest upper integration limit $T$ at which the maximum mean accuracy was achieved. Note that the experiments were not optimised for accuracy, since we were specifically interested in the runtime of different ODE solvers at comparable performance. We refer to Pinchetti et al. (2024) for a comprehensive performance benchmarking of PCNs.
  • Figure 4: Maximum mean test accuracy on MNIST achieved with Euler and Heun as a function of step size $dt$ and upper integration limit $T$. For the results in Figure 1 with $H=3$, we selected runs with $T=20$, using $dt=0.5$ for Euler and $dt=0.05$ for Heun. For $H=5$, we selected $T=50$, using $dt=0.5$ for Euler and $dt=0.05$ for Heun. Finally, for $H=10$, $T=200$ and $dt=0.05$ were chosen for both solvers.
  • Figure 5: Same results as Figure 4 for Fashion-MNIST. For the results in Figure 1 with $H=3$, we selected runs with $T=20$, using $dt=0.5$ for Euler and $dt=0.1$ for Heun. For the other network depths, the same hyperparameters were chosen for both solvers: $T=200$ and $dt=0.5$ for $H=5$, and $T=200$ and $dt=0.05$ for $H=10$.
  • ...and 2 more figures
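
The Figure 1 caption notes that the first training iteration, where jit compilation occurs, is excluded from the timings. One common way to measure per-step wall-clock time in JAX under that convention is sketched below. The `train_step` function is a hypothetical stand-in rather than a JPC training step, and `time_steps` is an illustrative helper, not the authors' benchmarking code.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def train_step(w, x):
    # Hypothetical stand-in for one training step: a single gradient
    # update on a toy loss with learning rate 1e-3.
    loss = lambda w_: jnp.sum(jnp.tanh(x @ w_) ** 2)
    return w - 1e-3 * jax.grad(loss)(w)


def time_steps(step_fn, w, x, n_steps):
    """Return per-step wall-clock times, dropping the first (compiling) call."""
    times = []
    for _ in range(n_steps):
        start = time.perf_counter()
        w = step_fn(w, x)
        # JAX dispatches work asynchronously, so wait for the result
        # before stopping the clock.
        w.block_until_ready()
        times.append(time.perf_counter() - start)
    return times[1:]


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
w = jax.random.normal(key_w, (300, 300))
x = jax.random.normal(key_x, (64, 300))  # batch size 64, width 300
step_times = time_steps(train_step, w, x, n_steps=10)
print(len(step_times), sum(step_times) / len(step_times))
```

Without the `block_until_ready()` call, the timer may stop before the computation has finished, which would understate the cost of each step.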