JPC: Flexible Inference for Predictive Coding Networks in JAX
Francesco Innocenti, Paul Kinghorn, Will Yun-Farmbrough, Miguel De Llanza Varona, Ryan Singh, Christopher L. Buckley
TL;DR
JPC addresses the lack of open-source tooling for predictive coding networks (PCNs) with a compact JAX library that uses ordinary differential equation (ODE) solvers to integrate the gradient-flow inference dynamics of PCNs. It supports training of discriminative, generative, and hybrid PCNs through a simple API, with further options for advanced customization. Experiments show that a second-order solver (Heun) reduces runtime relative to standard Euler integration while maintaining comparable accuracy across network depths and datasets. The library also provides theoretical tools to diagnose whether inference in PCNs has converged sufficiently, including a closed-form expression for the equilibrium energy of deep linear PCNs. Overall, JPC lowers the barrier to PCN research and provides a runnable reference implementation of solver-based predictive coding (PC) inference.
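The closed-form equilibrium energy for deep linear PCNs can be illustrated with a short sketch. It assumes an unbatched, bias-free linear PCN with energy F = sum_l 0.5 * ||z_l - W_l z_{l-1}||^2, where the input z_0 = x and the output z_L = y are clamped. Under this assumption, minimising F over the hidden activities is equivalent to marginalising a linear-Gaussian chain. The minimum is therefore F* = 0.5 * r^T S^{-1} r, with r = y - W_{L:1} x, S = I + sum_{l=2}^{L} W_{L:l} W_{L:l}^T and W_{L:l} = W_L ... W_l. The helper names below are hypothetical and are not JPC's API. JPC's own energy convention (for example, averaging over a batch) may differ. The sketch checks the formula against plain gradient-descent inference.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper names for illustration; not JPC's API.
# Double precision so the numerical and closed-form energies compare tightly.
jax.config.update("jax_enable_x64", True)


def pc_energy(hidden, weights, x, y):
    """Assumed linear PC energy: sum_l 0.5 * ||z_l - W_l z_{l-1}||^2.

    Input z_0 = x and output z_L = y are clamped; `hidden` holds z_1..z_{L-1}.
    """
    zs = [x] + list(hidden) + [y]
    return sum(0.5 * jnp.sum((zs[l + 1] - W @ zs[l]) ** 2)
               for l, W in enumerate(weights))


def linear_equilibrium_energy(weights, x, y):
    """Closed-form minimum over hidden activities: 0.5 * r^T S^{-1} r.

    r = y - W_{L:1} x and S = I + sum_{l=2}^{L} W_{L:l} W_{L:l}^T,
    where W_{L:l} = W_L @ ... @ W_l.
    """
    d_out = weights[-1].shape[0]
    S = jnp.eye(d_out)
    prod = jnp.eye(d_out)
    for W in reversed(weights[1:]):  # visits W_L, W_{L-1}, ..., W_2
        prod = prod @ W              # running product W_{L:l}
        S = S + prod @ prod.T
    r = y - (prod @ weights[0]) @ x  # prod @ W_1 = W_{L:1}
    return 0.5 * r @ jnp.linalg.solve(S, r)


@jax.jit
def inference_step(hidden, weights, x, y, dt):
    """One Euler step of the gradient flow dz/dt = -dF/dz on the hidden layers."""
    grads = jax.grad(pc_energy)(hidden, weights, x, y)
    return [h - dt * g for h, g in zip(hidden, grads)]


def infer(weights, x, y, dt=0.1, steps=3000):
    """Run inference from zero-initialised hidden activities."""
    hidden = [jnp.zeros(W.shape[0]) for W in weights[:-1]]
    for _ in range(steps):
        hidden = inference_step(hidden, weights, x, y, dt)
    return hidden


# Toy network and data (sizes and weight scale are arbitrary).
dims = [3, 4, 4, 2]  # input, two hidden layers, output
keys = jax.random.split(jax.random.PRNGKey(0), len(dims) + 1)
weights = [0.2 * jax.random.normal(k, (d_out, d_in))
           for k, d_in, d_out in zip(keys, dims[:-1], dims[1:])]
x = jax.random.normal(keys[-2], (dims[0],))
y = jax.random.normal(keys[-1], (dims[-1],))

hidden = infer(weights, x, y)
print("energy after inference:", float(pc_energy(hidden, weights, x, y)))
print("closed-form minimum:   ", float(linear_equilibrium_energy(weights, x, y)))
```

Because the linear energy is a strictly convex quadratic in the hidden activities, the two printed values should agree once inference has converged. A gap between them indicates that inference was stopped too early.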
Abstract
We introduce JPC, a JAX library for training neural networks with predictive coding (PC). JPC provides a simple, fast and flexible interface to train a variety of PC networks (PCNs), including discriminative, generative and hybrid models. Unlike existing libraries, JPC leverages ordinary differential equation solvers to integrate the gradient flow inference dynamics of PCNs. We find that a second-order solver achieves significantly faster runtimes than standard Euler integration, with comparable performance across a range of tasks and network depths. JPC also provides some theoretical tools that can be used to study PCNs. We hope that JPC will facilitate future research on PC. The code is available at https://github.com/thebuckleylab/jpc.
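The sketch below makes the solver comparison more concrete. It is hand-rolled JAX, not JPC's API. It treats inference as the ODE dz/dt = -dF/dz over the hidden activities and integrates it with forward Euler and with Heun's second-order predictor-corrector method. Several choices are assumptions made only for illustration: the energy form (tanh on hidden layers, no biases, a single sample), the zero initialisation and all helper names. The sketch measures accuracy at a fixed step size against a fine-step reference. It does not reproduce the runtime measurements reported in the paper.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper names for illustration; not JPC's API.
# Double precision so floating-point round-off does not blur the comparison.
jax.config.update("jax_enable_x64", True)


def energy(hidden, weights, x, y):
    """Assumed PC energy: sum_l 0.5 * ||z_l - W_l f(z_{l-1})||^2 with f = tanh.

    Input z_0 = x (used without tanh) and output z_L = y are clamped.
    """
    zs = [x] + list(hidden) + [y]
    total = 0.0
    for l, W in enumerate(weights):
        pre = zs[l] if l == 0 else jnp.tanh(zs[l])
        total = total + 0.5 * jnp.sum((zs[l + 1] - W @ pre) ** 2)
    return total


def vector_field(hidden, weights, x, y):
    """Inference gradient flow: dz/dt = -dF/dz for every hidden layer."""
    grads = jax.grad(energy)(hidden, weights, x, y)
    return [-g for g in grads]


def euler_step(hidden, dt, weights, x, y):
    """First-order step: one vector-field evaluation."""
    k1 = vector_field(hidden, weights, x, y)
    return [h + dt * a for h, a in zip(hidden, k1)]


def heun_step(hidden, dt, weights, x, y):
    """Second-order step: Euler predictor, then average the two slopes."""
    k1 = vector_field(hidden, weights, x, y)
    pred = [h + dt * a for h, a in zip(hidden, k1)]
    k2 = vector_field(pred, weights, x, y)
    return [h + 0.5 * dt * (a + b) for h, a, b in zip(hidden, k1, k2)]


def integrate(step_fn, hidden, t1, dt, weights, x, y):
    """Fixed-step integration of the inference ODE from t = 0 to t = t1."""
    step = jax.jit(step_fn)
    for _ in range(int(round(t1 / dt))):
        hidden = step(hidden, dt, weights, x, y)
    return hidden


def max_abs_diff(a, b):
    return max(float(jnp.max(jnp.abs(u - v))) for u, v in zip(a, b))


# Toy network and data (sizes and weight scale are arbitrary).
dims = [3, 5, 5, 2]
keys = jax.random.split(jax.random.PRNGKey(0), len(dims) + 1)
weights = [0.2 * jax.random.normal(k, (d_out, d_in))
           for k, d_in, d_out in zip(keys, dims[:-1], dims[1:])]
x = jax.random.normal(keys[-2], (dims[0],))
y = jax.random.normal(keys[-1], (dims[-1],))
hidden0 = [jnp.zeros(d) for d in dims[1:-1]]

t1 = 1.0
reference = integrate(heun_step, hidden0, t1, 1e-3, weights, x, y)
euler = integrate(euler_step, hidden0, t1, 0.1, weights, x, y)
heun = integrate(heun_step, hidden0, t1, 0.1, weights, x, y)
print("Euler error:", max_abs_diff(euler, reference))
print("Heun error: ", max_abs_diff(heun, reference))
```

Heun evaluates the gradient twice per step, whereas Euler evaluates it once. A second-order solver therefore only saves wall-clock time when its higher accuracy permits substantially fewer or larger steps.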
