Meta predictive learning model of languages in neural circuits

Chan Li, Junbin Qiu, Haiping Huang

TL;DR

This work develops mean-field meta predictive learning (MPL), a brain-inspired predictive coding framework in recurrent networks where all synaptic weights follow spike-and-slab distributions and only distribution parameters are learned. By minimizing a variational free energy, MPL combines inference, learning, and prediction phases, producing an ensemble of networks whose weights become increasingly deterministic except for the readout layer, which remains more variable. MPL is tested on MNIST with sequential pixels, a toy language task, and the Penn Treebank corpus, revealing a data-load driven phase transition near $\alpha_c \approx 0.02$ and the ability to generate grammatically coherent text after sufficient training. The results suggest a plausible link between brain-like next-token prediction, phase-transition dynamics, and emergent language capabilities, while highlighting gaps relative to transformer architectures and pointing to avenues for integrating attention-like mechanisms in a biologically plausible framework.

Abstract

Large language models based on self-attention mechanisms have achieved astonishing performance not only on natural language itself, but also on a variety of tasks of a different nature. However, the human brain may not process language using the same principle. This has prompted a debate about the connection between brain computation and the artificial self-supervision adopted in large language models. One of the most influential hypotheses in brain computation is the predictive coding framework, which proposes to minimize the prediction error by local learning. However, the role of predictive coding and the associated credit assignment in language processing remains unknown. Here, we propose a mean-field learning model within the predictive coding framework, assuming that the synaptic weight of each connection follows a spike-and-slab distribution, and that only the distribution, rather than the specific weights, is trained. This meta predictive learning is successfully validated on classifying handwritten digits where pixels are input to the network in sequence, as well as on toy and real language corpora. Our model reveals that most of the connections become deterministic after learning, while the output connections retain a higher level of variability. The performance of the resulting network ensemble changes continuously with data load and improves further with more training data, in analogy with the emergent behavior of large language models. Therefore, our model provides a starting point for investigating the connection among brain computation, next-token prediction, and general intelligence.
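To make the idea of training a distribution rather than individual weights concrete, here is a minimal sketch of a single spike-and-slab weight. It assumes the plain form in which the weight is exactly zero with probability π and is otherwise drawn from a Gaussian with mean m and variance Ξ; the class name `SpikeSlab` and its methods are hypothetical, and the paper's own parameterization may include a normalization that is not reproduced here.

```python
import random
from dataclasses import dataclass


@dataclass
class SpikeSlab:
    """One synaptic weight described by a distribution (assumed, unscaled form)."""

    pi: float  # probability of the spike, i.e. the weight is exactly zero
    m: float   # mean of the Gaussian slab
    xi: float  # variance of the Gaussian slab

    def mean(self) -> float:
        # Only the slab contributes to the first moment.
        return (1.0 - self.pi) * self.m

    def variance(self) -> float:
        # Second moment of the mixture minus the squared mean.
        second_moment = (1.0 - self.pi) * (self.xi + self.m ** 2)
        return second_moment - self.mean() ** 2

    def sample(self, rng: random.Random) -> float:
        # Draw one concrete weight, i.e. one member of the network ensemble.
        if rng.random() < self.pi:
            return 0.0
        return rng.gauss(self.m, self.xi ** 0.5)


if __name__ == "__main__":
    rng = random.Random(0)
    weight = SpikeSlab(pi=0.3, m=1.0, xi=0.04)
    print(weight.mean(), weight.variance(), weight.sample(rng))
```

Under this assumed form, a connection is deterministic when its variance vanishes, which happens either for π = 0 with Ξ = 0 (the weight equals m) or for π = 1 (the weight is always zero). Sampling every connection once yields one network from the ensemble.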

Paper Structure (11 sections, 16 equations, 6 figures, 1 table, 2 algorithms)


Figures (6)

  • Figure 1: The performance of meta predictive learning on the 28 by 28 MNIST classification task. (a) Test accuracy as a function of epoch. The network with $N=100$ recurrent neurons, $N_{\rm in}=28$ input units and $N_{\rm out} = 10$ output nodes is trained on the full MNIST dataset with $60$k training images (handwritten digits), and validated on $10$k unseen test images. The curve labeled predictive coding refers to learning directly in the weight space rather than in the distribution space. The number of inference steps is set to $n=100$ for epochs below 40, and to $n=200$ otherwise. The inset shows how $\ln \mathcal{F}$ changes during the first 60 training epochs (this log-energy becomes stable in the late training stage, and is thus not shown). Fluctuations are estimated from five independent runs. (b) The logarithmic average value of $[\bm{\Xi}^\ell, \bm{\pi}^\ell,\mathbf{m}^\ell]$ versus epoch in all layers, where the logarithm is taken with base $e$. Only the first twenty epochs are shown (the result remains stable in the later training stage), and the fluctuation is computed from five independent runs.
  • Figure 2: The properties of meta predictive learning on the simplified language prediction task. The grammatical rule is designed as follows: starting from a random letter ('a' here), only the candidates located two letters or four letters after 'a' can follow the starting letter, with equal probability, and each letter only repeats once in this next-word generation. All letters in the alphabet form a cyclic structure. A sequence length of $T=11$ is considered, and the full dataset contains $26624$ sequences. An RNN with $N=100, N_{\rm in}=26, N_{\rm out} = 26$ is trained, and two network instances are randomly sampled from the (trained or untrained) network ensemble. (a) Starting from the letter 'a', the network generates the next letter, which serves as the input at the next time step, until a sequence of the desired length is generated. (b) The correct letter ratio as a function of data load $\alpha = \frac{M}{N}$, where $M$ is the number of sequences used for training; five independent runs are considered. A chance level of $\frac{1}{13}$ is marked. The inset shows the correct letter ratio in the range $\alpha\in [0.02,0.1]$. (c) The log-energy $\ln \mathcal{F}$ as a function of training epoch; it decreases to near zero. The inset shows how the correct letter ratio changes with the length of the generated sequence after the full dataset is used for training. The error bar is computed from five independent networks.
  • Figure 3: Softmax values of the output units for different data loads $\alpha$. Panels (a,b), (c,d), (e,f) and (g,h) show two typical patterns for each of the data loads $\alpha = 0$, $\alpha = 0.01$, $\alpha = 0.03$, and $\alpha = 0.05$, respectively. Only predictions following the designed language rule are displayed, and a panel label such as "$a\to c$" means that the letter 'a' is the input and the network predicts 'c' as the immediately following letter (corresponding to the largest softmax output). The training conditions are the same as in Fig. 2.
  • Figure 4: Illustration of hyperparameters $[\pi, m, \Xi]$ in meta predictive learning on the simplified language task. The training conditions are the same as in Fig. 2. In (c-d), we show statistical properties of bidirectional connections, and $i<j$ is considered.
  • Figure 5: Training performance of networks with different architectures on the Penn Treebank dataset. In the upper part of the figure, we choose the vanilla RNN [Huang-2022], the SaS RNN (ensemble learning) [Zou-2023], the RNN with standard predictive coding [Exact-2021] and the RNN with meta predictive learning to show how test perplexity decreases with the training epoch. The first two algorithms belong to the backpropagation through time category [Huang-2022]. In the inset, we provide the performance of a transformer model with a single encoder block for comparison (see the appendix for details). We also mark the mean test accuracy of the transformer model at the beginning and at the end of training. In the bottom part of the figure, we select the untrained, trained-for-five-epochs, and fully trained RNN with meta predictive learning to show the performance at different training stages in generating one of the sentences in the test dataset. The correctly predicted tokens from the test sentence are highlighted, while the wrongly predicted tokens are colored gray. The indicated accuracy is the ratio of the number of correctly predicted tokens from the test sentence to the total number of tokens in the sentence. The mean accuracy evaluated from $100$ sentences is about $0\%$, $21.3\%\pm10.5\%$, $23.5\%\pm11.3\%$ at the three shown stages, respectively. Note that all the models share the same training hyperparameters, such as batch size, learning rate, and training optimizer (see the appendix for details).
  • ...and 1 more figures
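
The toy grammar described in the Figure 2 caption can be sketched as follows. This is an illustrative reading of the rule, not the authors' code: each next letter lies two or four positions after the current one on a cyclic 26-letter alphabet, and every combination of jumps is enumerated for a sequence of length $T=11$. The names `build_dataset` and `correct_letter_ratio` are hypothetical, and scoring a sequence by the fraction of rule-obeying transitions is an assumption about how the correct letter ratio is computed. The clause that each letter only repeats once is not enforced here, since enumerating every choice of jumps already reproduces the stated dataset size of $26624 = 26 \times 2^{10}$.

```python
import itertools
import string

ALPHABET = string.ascii_lowercase  # 26 letters arranged on a cycle
STEPS = (2, 4)                     # allowed jumps to the next letter
T = 11                             # sequence length from the Figure 2 caption


def build_dataset(length: int = T) -> list[str]:
    """Enumerate every sequence allowed by the jump rule (assumed reading)."""
    sequences = []
    for start in range(len(ALPHABET)):
        for jumps in itertools.product(STEPS, repeat=length - 1):
            idx = start
            letters = [ALPHABET[idx]]
            for jump in jumps:
                idx = (idx + jump) % len(ALPHABET)  # wrap around the alphabet
                letters.append(ALPHABET[idx])
            sequences.append("".join(letters))
    return sequences


def correct_letter_ratio(sequence: str) -> float:
    """Fraction of consecutive letter pairs that obey the jump rule."""
    correct = 0
    for prev, nxt in zip(sequence, sequence[1:]):
        gap = (ALPHABET.index(nxt) - ALPHABET.index(prev)) % len(ALPHABET)
        correct += gap in STEPS
    return correct / (len(sequence) - 1)


# A uniformly random guess is valid for 2 of the 26 letters, i.e. 1/13.
CHANCE_LEVEL = len(STEPS) / len(ALPHABET)

if __name__ == "__main__":
    dataset = build_dataset()
    print(len(dataset), dataset[0], correct_letter_ratio(dataset[0]))
```

Under this reading, the chance level of $\frac{1}{13}$ marked in panel (b) corresponds to guessing uniformly among 26 letters when exactly two of them are valid continuations.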