
Uncertainty Quantification in Working Memory via Moment Neural Networks

Hengyuan Ma, Wenlian Lu, Jianfeng Feng

TL;DR

This study applies moment neural networks (MNNs) to explore the neural mechanism of uncertainty quantification in working memory (WM) and offers insights into how the brain effectively manages uncertainty with exceptional accuracy.

Abstract

Humans possess a finely tuned sense of uncertainty that helps them anticipate potential errors, an ability vital for adaptive behavior and survival. However, the underlying neural mechanisms remain unclear. This study applies moment neural networks (MNNs) to explore the neural mechanism of uncertainty quantification in working memory (WM). The MNN captures the nonlinear coupling of the first two moments (mean firing rate and firing covariance) in spiking neural networks (SNNs) and identifies firing covariance as a key indicator of uncertainty in the encoded information. Trained on a WM task, the model achieves coding precision and uncertainty quantification comparable to human performance. Analysis reveals a link between probabilistic and sampling-based coding schemes for representing uncertainty. Transferring the MNN's weights to an SNN replicates these results. Furthermore, the study provides testable predictions of how noise and heterogeneity enhance WM performance, suggesting that they play a beneficial role rather than being mere biological byproducts. These findings offer insights into how the brain manages uncertainty with exceptional accuracy.
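The "first two moments" of a spiking network are the mean firing rate of each neuron and the firing covariance between each pair of neurons, which are also the quantities one would compare when transferring MNN weights to an SNN. A minimal NumPy sketch of estimating both from a spike raster follows. It assumes the common convention of dividing spike-count covariance by the counting window; the raster is synthetic and every name (`spike_moments`, the rates, the window) is hypothetical.

```python
import numpy as np

def spike_moments(spikes, dt, window):
    """Estimate mean firing rates and firing covariance from a binary raster.

    spikes : (n_neurons, n_bins) array of 0/1 spike indicators
    dt     : bin width in seconds
    window : counting-window length in seconds (a multiple of dt)
    Returns (rates in Hz, covariance in Hz), using Cov(N_i, N_j) / window.
    """
    bins_per_window = int(round(window / dt))
    n_neurons, n_bins = spikes.shape
    n_windows = n_bins // bins_per_window
    trimmed = spikes[:, : n_windows * bins_per_window]
    # Spike counts N_i in consecutive, non-overlapping windows
    counts = trimmed.reshape(n_neurons, n_windows, bins_per_window).sum(axis=2)
    rates = counts.mean(axis=1) / window
    cov = np.cov(counts) / window
    return rates, cov

rng = np.random.default_rng(2)
dt, n_bins = 1e-3, 500_000                    # 500 s of 1 ms bins
shared = rng.random(n_bins) < 0.005           # shared input, about 5 Hz
private = rng.random((3, n_bins)) < 0.010     # private input, about 10 Hz each
spikes = private.copy()
spikes[:2] |= shared                          # neurons 0 and 1 receive the shared input

rates, cov = spike_moments(spikes.astype(float), dt, window=0.1)
print(np.round(rates, 1))
print(np.round(cov, 2))
```

Neurons 0 and 1 should show a positive covariance roughly on the order of the shared input rate, while their covariance with neuron 2 should be close to zero; a purely rate-based description would discard exactly this pairwise structure.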

Paper Structure

This paper contains 14 sections, 2 theorems, 43 equations, 12 figures.

Key Result

Theorem S1

Suppose that the model parameter $\bm{\vartheta}$ minimizes the loss $\mathcal{L}(\bm{\vartheta},C_{\xi})$. Then the model learns the ground-truth output mean and variance simultaneously, for any noise condition with positive-definite ${C}'_{\xi}$ and ${\sigma}'^2_{\eta}$.
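The exact form of $\mathcal{L}(\bm{\vartheta},C_{\xi})$ is not given in this summary. Assuming, for illustration only, a Gaussian negative log-likelihood over a scalar output, the intuition behind Theorem S1 can be checked numerically: the loss is minimized exactly when the predicted mean and variance equal the mean and variance of the targets. The function `gaussian_nll` and all numbers below are hypothetical.

```python
import numpy as np

def gaussian_nll(y, mean, var):
    """Average Gaussian negative log-likelihood of targets y, constant term dropped.
    Illustrative stand-in only; not the paper's loss."""
    return 0.5 * np.mean(np.log(var) + (y - mean) ** 2 / var)

rng = np.random.default_rng(1)
y = rng.normal(loc=0.7, scale=0.3, size=20_000)   # hypothetical noisy targets

# Setting both derivatives to zero gives the sample mean and the (ddof=0) sample variance
m_star, v_star = y.mean(), y.var()

# Brute-force check over a grid of candidate (mean, variance) pairs
means = np.linspace(0.5, 0.9, 41)
variances = np.linspace(0.05, 0.15, 51)
losses = np.array([[gaussian_nll(y, m, v) for v in variances] for m in means])
i, j = np.unravel_index(np.argmin(losses), losses.shape)
print(means[i], m_star)        # grid minimizer vs. closed-form mean
print(variances[j], v_star)    # grid minimizer vs. closed-form variance
```

Because the variance term enters the loss, a model that only fits the mean is penalized; this is the sense in which mean and variance are learned "at the same time".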

Figures (12)

  • Figure 1: Working memory task and its uncertainty quantification (UQ). (a) In the task designed in li2021joint, participants are required to remember the location indicated by the cue. After a delay period, they use saccades to indicate the remembered location on the ring and report their uncertainty with an arc. (b) Four representative cases of UQ results. Effective UQ should accurately reflect the magnitude of the error.
  • Figure 2: Comparison of the spiking neural network (SNN), rate-based neural network, and moment neural network (MNN). (a) (Top) A schematic of an SNN. (Bottom) Spike trains of a neuron population can be summarized by the mean firing rate of each neuron and the firing covariance between each pair of neurons. (b) (Top) The rate-based neural model considers only the mean firing rate and its nonlinear evolution through an activation function. (Bottom) The MNN captures both the mean firing rate and the firing covariance, with their nonlinear coupling during network evolution represented by the moment activations. (c) Training the synaptic connections of an MNN for the working memory task using reservoir computing. (d) The inference procedure: during the cue period, a feature encoded by external input is sent to the trained MNN; during the delay period, the external input is removed; after the delay period, the feature estimate and its uncertainty are decoded from the mean and covariance of the MNN, respectively.
  • Figure 3: The performance of the moment neural network (MNN) trained on the working memory task. (a) The tuning curves of five neurons. (b) The trained weights of the network. (c) Six instances showing the true input variable together with the feature and confidence interval decoded from the network. The arc length of the confidence interval is proportional to the square root of the first eigenvalue of the decoded covariance matrix $\hat{C}_{z}$ (see the Network inference section in Methods). (d) The error distribution across three groups of instances, divided by uncertainty level (measured with uncertainty metric I, defined in Methods): top 25% uncertainty (low confidence), top 25-50% uncertainty (middle confidence), and the remaining instances (high confidence). Corresponding results for uncertainty metrics II-IV are shown in the Supplementary Information. (e) The correlation between uncertainty (calculated using four indicators, I-IV) and the error, under two conditions: one where the correlation between neuron activities is maintained (w/ corr) and one where it is clamped to zero (w/o corr).
  • Figure 4: Mechanism of uncertainty quantification in the moment neural network (MNN) for working memory. (a) Several typical fixed-point patterns of mean firing rate (top) and firing covariance (bottom) produced by the MNN, with the bump width increasing from left to right. (b) The correlation between the four uncertainty indicators (I-IV, defined in Methods) and the bump width, under two conditions: one where the correlation between neuron activities is maintained (w/ corr), and one where it is clamped to zero (w/o corr). (c) A schematic showing how differential covariance affects the neural coding of variables with different values. (d) The random drift of the bump generates differential covariance components in the firing covariance. (e) The correlation between uncertainty and the differential covariance ratio (DCR).
  • Figure 5: The performance of a spiking neural network (SNN) using the weights of the trained moment neural network (MNN). (a) The tuning curves of five neurons and the trained weights of the network. (b) Several typical fixed-point patterns of mean firing rate (top) and firing covariance (bottom) produced by the SNN, with the bump width increasing from left to right. (c) The error distribution across three groups of instances, divided by uncertainty level (measured with uncertainty metric I, defined in Methods): top 25% uncertainty (low confidence), top 25-50% uncertainty (middle confidence), and the remaining instances (high confidence). Corresponding results for uncertainty metrics II-IV are shown in the Supplementary Information. (d) The correlation between uncertainty (calculated using four indicators, I-IV) and the error. (e) The correlation between the four uncertainty indicators (I-IV, defined in Methods) and the bump width. (f) The correlation between uncertainty and the differential covariance ratio (DCR).
  • ...and 7 more figures
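The decoding step of Figure 2(d) and the covariance effects of Figures 3(c), 3(e), and 4(c)-(e) can be illustrated with a small NumPy sketch. It assumes a linear population-vector readout `W_out` (so the decoded covariance is $W_{\text{out}} C W_{\text{out}}^{\top}$), a von Mises-shaped bump, and a firing covariance built from a bump-drift term $\sigma^2 f'(\theta) f'(\theta)^{\top}$ plus independent variability; "first eigenvalue" is read as the largest one. The variance share along $f'(\theta)$ is a hypothetical stand-in for the DCR, whose exact definition is in Methods, and the "w/o corr" control is applied only at the readout here, not during network dynamics. All names and parameters are illustrative.

```python
import numpy as np

N, kappa, theta, sigma = 64, 5.0, 1.0, 0.05          # hypothetical parameters
phi = np.linspace(0, 2 * np.pi, N, endpoint=False)     # preferred angles
W_out = np.stack([np.cos(phi), np.sin(phi)]) / N       # hypothetical linear readout

def bump(t):
    """Von Mises-shaped bump of mean firing rates centred at angle t (arbitrary units)."""
    return np.exp(kappa * (np.cos(t - phi) - 1))

def decode(mu, C):
    """Mean -> decoded angle; covariance -> W_out C W_out^T (exact for a linear readout)."""
    z = W_out @ mu
    C_z = W_out @ C @ W_out.T
    return np.arctan2(z[1], z[0]), C_z

def arc_scale(C_z):
    """Uncertainty taken as sqrt of the largest eigenvalue of the decoded covariance."""
    return np.sqrt(np.linalg.eigvalsh(C_z)[-1])

def clamp_correlations(C):
    """'w/o corr' control: keep each neuron's variance, set all covariances to zero."""
    return np.diag(np.diag(C))

mu = bump(theta)
df = kappa * np.sin(phi - theta) * mu                  # d(mu)/d(theta)

# Drift delta ~ N(0, sigma^2): mu(theta + delta) ~ mu + delta * df, which adds
# sigma^2 * df df^T (differential covariance); 0.01 * I is private variability.
C = sigma**2 * np.outer(df, df) + 0.01 * np.eye(N)

u = df / np.linalg.norm(df)
diff_share = u @ C @ u / np.trace(C)                   # hypothetical DCR-like share

angle, C_z = decode(mu, C)
_, C_z_clamped = decode(mu, clamp_correlations(C))
print(angle, arc_scale(C_z), arc_scale(C_z_clamped), diff_share)
```

In this toy setting the differential component adds up coherently in the readout, so clamping correlations shrinks the decoded uncertainty even though every single-neuron variance is unchanged.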
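The evaluation described in Figures 1(b), 3(d)-(e), and 5(c)-(d) can be sketched as follows. This assumes errors are measured as wrapped angular differences on the ring and uses synthetic, well-calibrated trials rather than the paper's data; `circular_error` and `split_by_uncertainty` are hypothetical names, and the three groups follow the captions (top 25% uncertainty, 25-50%, and the remaining half).

```python
import numpy as np

def circular_error(true_angle, estimate):
    """Signed error on the ring, wrapped to [-pi, pi)."""
    return (estimate - true_angle + np.pi) % (2 * np.pi) - np.pi

def split_by_uncertainty(uncertainty, abs_error):
    """Top 25% uncertainty (low confidence), 25-50% (middle), remaining 50% (high)."""
    q50, q75 = np.quantile(uncertainty, [0.5, 0.75])
    low = abs_error[uncertainty >= q75]
    middle = abs_error[(uncertainty >= q50) & (uncertainty < q75)]
    high = abs_error[uncertainty < q50]
    return low, middle, high

# Synthetic, well-calibrated trials for illustration only (not the paper's data)
rng = np.random.default_rng(0)
n_trials = 10_000
true_angle = rng.uniform(0, 2 * np.pi, n_trials)
uncertainty = rng.uniform(0.05, 0.5, n_trials)          # reported spread (radians)
estimate = true_angle + rng.normal(0.0, uncertainty)    # error scales with uncertainty
abs_error = np.abs(circular_error(true_angle, estimate))

low, middle, high = split_by_uncertainty(uncertainty, abs_error)
r = np.corrcoef(uncertainty, abs_error)[0, 1]           # uncertainty-error correlation
print(high.mean(), middle.mean(), low.mean(), r)
```

Wrapping matters near the seam of the ring: a cue at 350° reported at 10° is a 20° error, not 340°. With calibrated uncertainty, mean absolute error rises from the high- to the low-confidence group and the correlation is positive; uninformative uncertainty would flatten both.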

Theorems & Definitions (4)

  • Theorem S1
  • proof
  • Theorem S2
  • proof