
Toward stochastic neural computing

Yang Qi, Zhichao Zhu, Yiming Wei, Lu Cao, Zhigang Wang, Jie Zhang, Wenlian Lu, Jianfeng Feng

TL;DR

The paper introduces stochastic neural computing (SNC), which represents high-dimensional spiking activity by its second-order statistics and derives a moment embedding that yields the moment neural network (MNN). The MNN serves as a differentiable surrogate for training spiking neural networks and can be used to reconstruct the original SNN without extra parameters, enabling end-to-end probabilistic inference that accounts for correlated variability. On MNIST and on neuromorphic hardware (Loihi), SNC achieves competitive accuracy while reducing inference time and energy by exploiting uncertainty and variability, and it extends to other datasets such as Fashion-MNIST and CIFAR-10. The framework bridges SNNs and ANNs through a principled second-order representation and could guide the design of future neuromorphic architectures and of uncertainty-aware intelligent systems.

Abstract

The highly irregular spiking activity of cortical neurons and behavioral variability suggest that the brain could operate in a fundamentally probabilistic way. Mimicking how the brain implements and learns probabilistic computation could be key to developing machine intelligence that can think more like humans. In this work, we propose a theory of stochastic neural computing (SNC) in which streams of noisy inputs are transformed and processed through populations of nonlinearly coupled spiking neurons. To account for the propagation of correlated neural variability, we derive from first principles a moment embedding for spiking neural networks (SNNs). This leads to a new class of deep learning models called the moment neural network (MNN), which naturally generalizes rate-based neural networks to second order. As the MNN faithfully captures the stationary statistics of spiking neural activity, it can serve as a powerful proxy for training SNNs with zero free parameters. Through joint manipulation of mean firing rates and noise correlations in a task-driven way, the model learns inference tasks while simultaneously minimizing prediction uncertainty, resulting in faster inference. We further demonstrate the application of our method to Intel's Loihi neuromorphic hardware. The proposed theory of SNC may open up new opportunities for developing machine intelligence capable of computing uncertainty and for designing unconventional computing architectures.
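The claim that the MNN generalizes rate-based networks to second order can be made concrete with the first stage of a moment layer, synaptic summation. The following is a minimal sketch, not the authors' code. It assumes that a linear synaptic layer with weight matrix $W$ maps an input mean vector $\mu$ and covariance matrix $C$ to $\bar{\mu} = W\mu$ and $\bar{C} = W C W^{\top}$ (ignoring any bias or background-noise terms), and that independent Poisson inputs have variance equal to their mean. The helper names, rates, and weights are hypothetical.

```python
# Minimal sketch (assumed form, not the paper's code): moments of a linear
# synaptic summation, mean_out = W mean_in and cov_out = W cov_in W^T.

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a):
    """Transpose a matrix stored as a list of rows."""
    return [list(row) for row in zip(*a)]


def synaptic_summation(weights, mean_in, cov_in):
    """Return (mean, covariance) of the summed synaptic input."""
    mean_out = [sum(w * m for w, m in zip(row, mean_in)) for row in weights]
    cov_out = matmul(matmul(weights, cov_in), transpose(weights))
    return mean_out, cov_out


def correlation(cov, i, j):
    """Pearson correlation coefficient between components i and j."""
    return cov[i][j] / (cov[i][i] * cov[j][j]) ** 0.5


# Two independent Poisson inputs (hypothetical rates, spikes/ms);
# for Poisson spike trains the variance equals the mean.
mean_in = [0.05, 0.05]
cov_in = [[0.05, 0.0], [0.0, 0.05]]

# Hypothetical weights with opposite-signed projections from the first input.
weights = [[1.0, 0.5], [-1.0, 0.5]]

mean_out, cov_out = synaptic_summation(weights, mean_in, cov_in)
print(mean_out)                    # approximately [0.075, -0.025]
print(correlation(cov_out, 0, 1))  # approximately -0.6
```

Although the two inputs are independent, the opposite-signed weights from the first input make the two summed currents negatively correlated (correlation $-0.6$), an input-side version of the effect Figure 3d describes as anti-correlated weights producing anti-correlated activity. A rate-based network would propagate only the mean vector.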

Paper Structure (21 sections, 21 equations, 6 figures, 3 tables)

Figures (6)

  • Figure 1: Spike-based stochastic neural computing (SNC). a, A graphical model for SNC, consisting of a generative model describing the external environment, multiple stages of fluctuating neural activity states, and a readout for making inferences. The index $k$ denotes different neural populations in a feedforward network or, alternatively, time in a recurrent network. b, Schematic of a spiking neural network implementing the computational processes outlined in a, with each layer characterized by a joint probability distribution of neural spike counts. c, Propagation of irregular neural spike trains through two feedforward-connected neural populations. The pre-synaptic spike trains first undergo synaptic summation to generate fluctuating synaptic currents, which in turn drive post-synaptic neurons to fire. The probability distribution of spike counts is transformed in a non-trivial way due to the nonlinear coupling of correlated neural variability.
  • Figure 2: Gradient-based learning in spiking neural networks through moment embedding. a, Overall schematic. The spiking neural network model is first mapped to a corresponding moment neural network, which can be trained with backpropagation; the trained weights are then used to recover the original spiking neural network. b, Components of the moment activation function, including the mean firing rate $\mu$, firing variability $\sigma$, and linear response coefficient $\chi$, each a function of the input current mean $\bar{\mu}$ and variability $\bar{\sigma}$. In conventional analog and digital computing, such noise coupling is considered detrimental to the information carried by the signal. In contrast, stochastic computing actively exploits correlated variability as part of the computational process. c, Computational graph of a single feedforward layer of the moment neural network, featuring synaptic summation, moment batch normalization, and moment activation. d, Illustration of a trainable moment neural network with a feedforward architecture consisting of an input layer, an arbitrary number of hidden layers, a readout layer, and a moment loss function.
  • Figure 3: Moment neural network learning a classification task while simultaneously minimizing uncertainty. a, The probability of correct prediction averaged over all samples of the validation set during training, for unlimited and limited readout time $\Delta t$; the latter accounts for trial-to-trial variability. For unlimited readout time ($\Delta t\to\infty$), the accuracy reaches $98.45\%$ by the final epoch, comparable to the performance of rate-based artificial neural networks. b, Diverse firing variability of hidden layer neurons in response to the input image shown in e, exhibiting both mean-dominant (Fano factor close to zero) and fluctuation-dominant (Fano factor close to one, solid line) activity. Insets: probability densities. c, Spike count correlations between hidden layer neurons are weak, and their distribution shows a slower-decaying tail after training. d, Illustration of the non-trivial roles played by the correlated variability of a specific pair of hidden layer neurons. An input image represented by independent Poisson spike trains undergoes synaptic summation with anti-correlated weights, leading to anti-correlated neural activity. The final readout is linearly decoded from the hidden layer spike counts. e, The mean (dot) and covariance (ellipse) of the spike counts of those two neurons over a readout time $\Delta t=100$ ms. In this example, the principal axis of the covariance is orthogonal to the direction of the readout weights (solid line) for the target class, leading to a reduction in the readout variance and simultaneously an increase in the readout mean.
  • Figure 4: Temporal dynamics of stochastic neural computing in a spiking neural network. a, Membrane potential (upper panel) and synaptic current (lower panel) of two typical hidden layer neurons during one trial of stimulus presentation, exhibiting mean-dominant (left panel) and fluctuation-dominant (right panel) activity, respectively. Dashed line indicates the firing threshold. b, Raster plot of typical spiking activity of hidden layer neurons during one trial of stimulus presentation. Solid line indicates stimulus onset at $t=0$ ms; shaded region indicates the readout time window $\Delta t$. The membrane potentials of all neurons are initialized to zero at $\Delta t=0$ ms. c, Two-dimensional projection of the readout trajectory $y$ over time for three trials using the same stimulus. The vertical axis is the readout component corresponding to the correct class. Dot indicates the theoretical limit of the readout mean as $\Delta t \to\infty$; dashed line indicates the decision boundary. d, Probability of correct prediction for a number of input images (left panel) as a function of readout time and population spike count in the hidden layer. Each curve is calculated from $100$ trials with the same stimulus. e, Probability of correct prediction averaged over all images of the validation set converges exponentially with the readout time (left panel) as well as with the population spike count in the hidden layer (right panel). Dashed lines indicate the theoretical limit of $0.985$ predicted by the MNN; solid lines represent exponential fits.
  • Figure S1: Performance of the SNN deployed on a neuromorphic chip. a, Classification accuracy increases with simulation time steps and converges to the theoretical limit predicted by the MNN. Compared with single-precision floating-point simulation on CPU, the simulation on Loihi incurs a small accuracy loss caused by weight quantization. b, Classification accuracy over time for each model. c, Classification accuracy (at 100 ms of simulation time) for varying hidden layer size and the $\Delta t$ used in the training loss. d, Average energy cost per sample (at 100 ms of simulation time) for varying hidden layer size and $\Delta t$. e, Average latency (at 100 ms of simulation time) for varying hidden layer size and $\Delta t$. f, Energy-latency diagram (at 80% accuracy) revealing a trade-off between energy cost and latency. The dot size corresponds to hidden layer size and the color to the value of $\Delta t$ used during training.
  • ...and 1 more figure
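Figure 2b lists the components of the moment activation function. As a hedged illustration of its first component only, the sketch below computes the mean firing rate $\mu$ of a leaky integrate-and-fire (LIF) neuron with the classical Siegert formula. It assumes membrane dynamics $dV/dt = -LV + \bar{\mu} + \bar{\sigma}\,\xi(t)$ with Gaussian white noise $\xi(t)$, reset to $V_{\mathrm{res}}$ after each spike at threshold $V_{\mathrm{th}}$, and a refractory period $T_{\mathrm{ref}}$, giving $\mu = 1/\left(T_{\mathrm{ref}} + \frac{2}{L}\int_{I_{\mathrm{lb}}}^{I_{\mathrm{ub}}} g(x)\,dx\right)$ with $g(x) = e^{x^2}\int_{-\infty}^{x} e^{-u^2}\,du$, $I_{\mathrm{ub}} = (V_{\mathrm{th}}L - \bar{\mu})/(\sqrt{L}\,\bar{\sigma})$, and $I_{\mathrm{lb}} = (V_{\mathrm{res}}L - \bar{\mu})/(\sqrt{L}\,\bar{\sigma})$. The parameter values and function names are illustrative assumptions, and the firing variability $\sigma$ and linear response coefficient $\chi$ are not covered.

```python
import math

# Assumed LIF parameters, chosen for illustration (not taken from the paper).
V_TH = 20.0   # firing threshold (mV)
V_RES = 0.0   # reset potential (mV)
LEAK = 0.05   # leak conductance L (1/ms)
T_REF = 5.0   # refractory period (ms)


def siegert_integrand(x):
    """g(x) = exp(x^2) * integral from -inf to x of exp(-u^2) du."""
    if x < -20.0:
        # Asymptotic series for very negative x avoids overflow in exp(x^2).
        inv_sq = 1.0 / (x * x)
        return -0.5 / x * (1.0 - 0.5 * inv_sq + 0.75 * inv_sq * inv_sq)
    return 0.5 * math.sqrt(math.pi) * math.exp(x * x) * math.erfc(-x)


def mean_firing_rate(mu_bar, sigma_bar, n_steps=2000):
    """Siegert mean firing rate (spikes/ms) of the assumed LIF neuron.

    mu_bar: mean input current; sigma_bar: input noise amplitude (> 0).
    n_steps: number of Simpson intervals (must be even).
    """
    scale = math.sqrt(LEAK) * sigma_bar
    upper = (V_TH * LEAK - mu_bar) / scale
    lower = (V_RES * LEAK - mu_bar) / scale
    if upper > 26.0:
        # exp(upper**2) would overflow; the mean interspike interval is
        # astronomically long, so the rate is negligible.
        return 0.0
    # Composite Simpson's rule for the integral of g over [lower, upper].
    h = (upper - lower) / n_steps
    total = siegert_integrand(lower) + siegert_integrand(upper)
    for k in range(1, n_steps):
        weight = 4.0 if k % 2 == 1 else 2.0
        total += weight * siegert_integrand(lower + k * h)
    integral = total * h / 3.0
    return 1.0 / (T_REF + 2.0 / LEAK * integral)


def deterministic_rate(mu_bar):
    """Noise-free LIF rate, valid for suprathreshold input mu_bar > LEAK * V_TH."""
    travel_time = math.log((mu_bar - LEAK * V_RES) / (mu_bar - LEAK * V_TH)) / LEAK
    return 1.0 / (T_REF + travel_time)


for mu_bar in (0.5, 1.0, 2.0):
    print(mu_bar, mean_firing_rate(mu_bar, 1.0))

# With weak noise the Siegert rate approaches the noise-free rate (about 0.053 spikes/ms).
print(mean_firing_rate(2.0, 0.01), deterministic_rate(2.0))
```

With weak noise the rate approaches the noise-free LIF rate, while with stronger noise a subthreshold input ($\bar{\mu} < L V_{\mathrm{th}}$, e.g. $\bar{\mu} = 0.5$ here) still produces nonzero firing; this is why the mean firing rate depends on both $\bar{\mu}$ and $\bar{\sigma}$. Simpson's rule is used only to keep the example dependency-free.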
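Figures 3a, 3e, and 4d-e relate prediction accuracy to the readout time $\Delta t$ and to the orientation of the spike-count covariance relative to the readout weights. The sketch below illustrates that reasoning under assumptions that the captions do not state: the readout is linear in the hidden-layer spike counts, counts accumulated over $\Delta t$ have mean $\mu\Delta t$ and covariance $C\Delta t$ (a long-window approximation), and the readout difference between the target class and a single competing class is Gaussian. With readout mean difference $d$ and variance $v$ per unit time, the probability of a correct prediction is then $\Phi(d\Delta t/\sqrt{v\Delta t}) = \Phi(d\sqrt{\Delta t/v})$, which approaches 1 as $\Delta t$ grows whenever $d > 0$. The function names and numbers are hypothetical.

```python
import math


def normal_cdf(z):
    """Standard normal cumulative distribution function."""
    return 0.5 * math.erfc(-z / math.sqrt(2.0))


def readout_variance(w, cov):
    """Variance rate of the linear readout w . n when n has covariance rate cov."""
    size = len(w)
    return sum(w[i] * cov[i][j] * w[j] for i in range(size) for j in range(size))


def prob_correct_two_class(mean_diff_rate, var_diff_rate, dt):
    """P(target readout exceeds competing readout) over a window dt.

    Assumes the readout difference is Gaussian with mean mean_diff_rate * dt
    and variance var_diff_rate * dt.
    """
    return normal_cdf(mean_diff_rate * dt / math.sqrt(var_diff_rate * dt))


# Hypothetical example: two hidden neurons with equal variance rates (per ms)
# and a readout weight difference (target minus competing class) along [1, 1].
w_diff = [1.0, 1.0]
independent = [[0.05, 0.0], [0.0, 0.05]]
anti_correlated = [[0.05, -0.03], [-0.03, 0.05]]  # correlation -0.6

var_indep = readout_variance(w_diff, independent)     # approximately 0.10
var_anti = readout_variance(w_diff, anti_correlated)  # approximately 0.04

mean_diff = 0.02  # hypothetical readout mean difference per ms
for dt in (10.0, 100.0, 1000.0):
    # The mean grows with dt but the standard deviation only with sqrt(dt),
    # so accuracy rises with readout time; lower variance helps at every dt.
    print(dt,
          prob_correct_two_class(mean_diff, var_indep, dt),
          prob_correct_two_class(mean_diff, var_anti, dt))
```

At $\Delta t = 100$ ms this toy example gives roughly 0.74 for the independent pair and 0.84 for the anti-correlated pair: the mean separation is identical, but the covariance's principal axis, along $[1, -1]$, is orthogonal to the weight direction, so the readout variance is smaller.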