A differentiable brain simulator bridging brain simulation and brain-inspired computing

Chaoming Wang, Tianqiu Zhang, Sichao He, Hongyaoxing Gu, Shangyang Li, Si Wu

TL;DR

BrainPy introduces a differentiable brain simulator built on JAX/XLA to bridge brain simulation and brain-inspired computing (BIC), addressing the lack of differentiability in traditional simulators and the limited biophysical realism in DL-based BIC libraries. It achieves this through dedicated sparse and event-driven operators, a novel synapse-projection abstraction, a multi-scale modular interface, and object-oriented JIT compilation, enabling scalable, differentiable brain dynamics within the AI ecosystem. Key contributions include AlignPre/AlignPost projections for memory-efficient synaptic computations, JIT connectivity for large-scale networks, and seamless integration with JAX ML tooling to train biologically plausible spiking models. The work demonstrates substantial efficiency and scalability gains, supports differentiable training of spiking networks, and offers a practical platform for interdisciplinary research at the intersection of brain simulation and brain-inspired computing.
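The memory saving credited to AlignPost projections can be illustrated with a linear exponential synapse. The sketch below is plain JAX, not BrainPy's API. The names `per_synapse_step`, `aligned_post_step`, `run`, `TAU`, and `DT` are hypothetical. It assumes that, because exponential decay is linear, all per-synapse conductances onto one postsynaptic neuron can be summed into a single state variable. Under that assumption, the stored state shrinks from n_pre × n_post values to n_post values without changing the output.

```python
import jax
import jax.numpy as jnp

TAU = 5.0                   # assumed synaptic time constant (ms)
DT = 0.1                    # assumed integration step (ms)
DECAY = jnp.exp(-DT / TAU)  # exact one-step decay of an exponential synapse


def per_synapse_step(g_syn, spikes, weights):
    """Naive layout: one conductance per synapse, shape (n_pre, n_post)."""
    g_syn = g_syn * DECAY + spikes[:, None] * weights
    return g_syn, g_syn.sum(axis=0)  # total conductance onto each post neuron


def aligned_post_step(g_post, spikes, weights):
    """Merged layout: one conductance per postsynaptic neuron, shape (n_post,).

    Valid only because the decay is linear: decaying the sum equals
    summing the individually decayed synapses.
    """
    g_post = g_post * DECAY + spikes @ weights
    return g_post, g_post


def run(step, g0, spike_train, weights):
    """Scan a step function over time; returns the (n_steps, n_post) trace."""
    def body(g, s):
        return step(g, s, weights)
    _, trace = jax.lax.scan(body, g0, spike_train)
    return trace


key_w, key_s = jax.random.split(jax.random.PRNGKey(0))
n_pre, n_post, n_steps = 50, 20, 100
weights = jax.random.uniform(key_w, (n_pre, n_post))
spike_train = jax.random.bernoulli(key_s, 0.05, (n_steps, n_pre)).astype(jnp.float32)

naive = run(per_synapse_step, jnp.zeros((n_pre, n_post)), spike_train, weights)
merged = run(aligned_post_step, jnp.zeros(n_post), spike_train, weights)
```

Both variants should give the same conductance trace, up to floating-point rounding. Only the merged variant keeps a state whose size does not depend on the number of synapses. The shortcut relies on linearity. Nonlinear per-synapse dynamics would not merge this way.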

Abstract

Brain simulation builds dynamical models to mimic the structure and functions of the brain, while brain-inspired computing (BIC) develops intelligent systems by learning from the structure and functions of the brain. The two fields are intertwined and should share a common programming framework to facilitate each other's development. However, none of the existing software in the fields can achieve this goal, because traditional brain simulators lack differentiability for training, while existing deep learning (DL) frameworks fail to capture the biophysical realism and complexity of brain dynamics. In this paper, we introduce BrainPy, a differentiable brain simulator developed using JAX and XLA, with the aim of bridging the gap between brain simulation and BIC. BrainPy expands upon the functionalities of JAX, a powerful AI framework, by introducing complete capabilities for flexible, efficient, and scalable brain simulation. It offers a range of sparse and event-driven operators for efficient and scalable brain simulation, an abstraction for managing the intricacies of synaptic computations, a modular and flexible interface for constructing multi-scale brain models, and an object-oriented just-in-time compilation approach to handle the memory-intensive nature of brain dynamics. We showcase the efficiency and scalability of BrainPy on benchmark tasks, highlight its differentiable simulation for biologically plausible spiking models, and discuss its potential to support research at the intersection of brain simulation and BIC.
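The abstract contrasts differentiable simulation with traditional simulators that lack gradients for training. One common way to get usable gradients through spiking neurons is a surrogate gradient, sketched here in plain JAX. The source does not say whether BrainPy uses this exact formulation. The forward pass keeps the Heaviside spike, and `jax.custom_jvp` substitutes a smooth fast-sigmoid derivative. The names `spike_fn`, `simulate`, and `loss`, the slope of 10, and all neuron parameters are illustrative assumptions.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def spike_fn(v):
    """Forward pass: Heaviside step, 1.0 where v > 0."""
    return (v > 0.0).astype(v.dtype)


@spike_fn.defjvp
def spike_fn_jvp(primals, tangents):
    (v,), (dv,) = primals, tangents
    # Assumed surrogate derivative (fast sigmoid with slope 10) replaces
    # the Heaviside derivative, which is zero almost everywhere.
    surrogate = 1.0 / (1.0 + 10.0 * jnp.abs(v)) ** 2
    return spike_fn(v), surrogate * dv


def simulate(w_in, inputs, tau=10.0, dt=1.0, v_th=1.0):
    """LIF neurons driven by inputs @ w_in; returns spike counts per neuron."""
    alpha = jnp.exp(-dt / tau)

    def step(v, x):
        v = alpha * v + x @ w_in      # leaky integration of input current
        s = spike_fn(v - v_th)        # spike when v crosses threshold
        v = v * (1.0 - s)             # reset to 0 after a spike
        return v, s

    _, spikes = jax.lax.scan(step, jnp.zeros(w_in.shape[1]), inputs)
    return spikes.sum(axis=0)


def loss(w_in, inputs, target_counts):
    """Mean squared error between simulated and target spike counts."""
    return jnp.mean((simulate(w_in, inputs) - target_counts) ** 2)


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
inputs = jax.random.bernoulli(key_x, 0.2, (100, 30)).astype(jnp.float32)
w_in = 0.1 * jax.random.normal(key_w, (30, 5))
target = jnp.full(5, 10.0)
grads = jax.grad(loss)(w_in, inputs, target)
```

The spikes themselves stay binary. Only the backward pass sees the smooth surrogate, so standard JAX tools such as `jax.grad` and optimizers can train the input weights directly.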

Paper Structure

This paper contains 44 sections, 23 equations, 16 figures, and 2 tables.

Figures (16)

  • Figure 1: Overview of the BrainPy architecture.
  • Figure 2: Synaptic projections in BrainPy. (A) The and projections offer a decoupled interface for managing dynamics and the communication between dynamics. (B) The synaptic communication allows for diverse implementations, including the utilization of DL models.
  • Figure 3: Multi-scale model building interface of BrainPy. Here is referred to package.
  • Figure 4: Event-driven operators in BrainPy enable efficient simulation of brain models. (A, B) Speed comparison of different operators performing the matrix-vector multiplication for synaptic computation on CPU (A) and GPU (B) devices. (C, D) Speed comparison of different brain simulators when simulating the COBA-LIF (C) and COBA-HH (D) networks.
  • Figure 5: JIT connectivity operators enable large-scale brain dynamics modeling. (A, B) Memory usage (A) and speed (B) of BrainPy's JIT connectivity operator compared with JAX's sparse and dense operators for matrix multiplication. (C) Scaling up the COBA-LIF network with the JIT connectivity operator. (D, E) Empirical relationship between classification performance and reservoir size on the KTH (D) and MNIST (E) datasets. (F) Inference speed of the reservoir model using the dense, sparse, and JIT connectivity operators.
  • ...and 11 more figures
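Figures 4 and 5 compare event-driven and JIT connectivity operators with dense and sparse ones. The following minimal sketch combines both ideas in plain JAX under stated assumptions. Connectivity is random with a homogeneous weight. Each presynaptic neuron's outgoing row is regenerated from `jax.random.fold_in(key, i)` instead of being stored. Rows are generated only for neurons that spiked. `jitconn_event_matvec` and `materialize` are hypothetical helpers, not BrainPy functions. The paper describes BrainPy's operators as dedicated operators, so this Python-level loop only illustrates the idea.

```python
import jax
import jax.numpy as jnp


def jitconn_event_matvec(spikes, key, prob, weight, n_post):
    """Compute spikes @ W without ever storing W.

    Row i of the hypothetical connectivity matrix W has entries `weight`
    with probability `prob` and 0 otherwise; it is regenerated from
    fold_in(key, i) whenever neuron i spikes, and skipped otherwise.
    """
    def body(i, acc):
        def add_row(acc):
            row = jax.random.bernoulli(jax.random.fold_in(key, i), prob, (n_post,))
            return acc + weight * row
        # Event-driven: only spiking presynaptic neurons do any work.
        return jax.lax.cond(spikes[i], add_row, lambda acc: acc, acc)

    return jax.lax.fori_loop(0, spikes.shape[0], body, jnp.zeros(n_post))


def materialize(key, prob, weight, n_pre, n_post):
    """Reference only: build the same W explicitly to check the result."""
    rows = [jax.random.bernoulli(jax.random.fold_in(key, i), prob, (n_post,))
            for i in range(n_pre)]
    return weight * jnp.stack(rows).astype(jnp.float32)


key = jax.random.PRNGKey(42)
n_pre, n_post = 64, 32
spikes = jax.random.bernoulli(jax.random.PRNGKey(7), 0.3, (n_pre,))
y = jitconn_event_matvec(spikes, key, 0.1, 0.5, n_post)
W = materialize(key, 0.1, 0.5, n_pre, n_post)
y_ref = spikes.astype(jnp.float32) @ W
```

The loop never holds more than one row of W at a time, so connectivity memory is O(n_post) rather than O(n_pre × n_post). The trade-off is that the random connections are recomputed on every call.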