A Variational Latent Equilibrium for Learning in Cortex

Simon Brandt, Paul Haider, Walter Senn, Federico Benitez, Mihai A. Petrovici

TL;DR

The theory provides a rigorous framework for spatiotemporal deep learning in the brain, while simultaneously suggesting a blueprint for physical circuits capable of carrying out these computations.

Abstract

Brains remain unrivaled in their ability to recognize and generate complex spatiotemporal patterns. While AI is able to reproduce some of these capabilities, deep learning algorithms remain largely at odds with our current understanding of brain circuitry and dynamics. This is prominently the case for backpropagation through time (BPTT), the go-to algorithm for learning complex temporal dependencies. In this work we propose a general formalism to approximate BPTT in a controlled, biologically plausible manner. Our approach builds on, unifies and extends several previous approaches to local, time-continuous, phase-free spatiotemporal credit assignment based on principles of energy conservation and extremal action. Our starting point is a prospective energy function of neuronal states, from which we calculate real-time error dynamics for time-continuous neuronal networks. In the general case, this provides a simple and straightforward derivation of the adjoint method result for neuronal networks, the time-continuous equivalent to BPTT. With a few modifications, we can turn this into a fully local (in space and time) set of equations for neuron and synapse dynamics. Our theory provides a rigorous framework for spatiotemporal deep learning in the brain, while simultaneously suggesting a blueprint for physical circuits capable of carrying out these computations. These results reframe and extend the recently proposed Generalized Latent Equilibrium (GLE) model.

Paper Structure (13 sections, 40 equations, 3 figures, 2 tables, 1 algorithm)

Figures (3)

  • Figure 1: Learning a simple chain in a student-teacher setup. (a) Signal (bottom-up) and error (top-down) neurons, their incoming and outgoing firing rates and their dynamics. (b) A chain of two neurons as the simplest network structure for studying error backpropagation. (c) The network converges to a minimal loss if the backward weights $\bm B$ are learned (blue) or fixed to $\bm{B}=\bm{W}^T$ (orange). For constant $\bm B$ (purple), the network does not converge. We plot a moving average of the loss to remove high-frequency oscillations. (d-f) The student (opaque) and teacher (transparent) weights. The student weights $\bm W$ converge to the teacher weights if $\bm B$ is learned (d) or $\bm{B}=\bm{W}^T$ (e), and diverge for constant $\bm B$ (f). The learned backward weights $\bm B$ (d, pink) converge to the expected value given the time constant of the neuron and the frequency of the signal. (g) Output of the three scenarios before, during and after learning, compared to the target signal.
  • Figure 2: Learning in a network of neurons. (a) Forward network structure (error neurons omitted for readability). (b) Network loss over time, evaluated on multiple test signals with randomly sampled frequencies and over multiple seeds. Orange: backward weights are the exact transposes of the forward weights. Blue: backward weights are learned using our local update rule (\ref{eqn:gal}). (c) Training input, consisting of multiple overlaid sines of different frequencies, presented to the network at different times. (d) Output of the student network learning $\bm B$ (solid) compared to the teacher network (dashed) before, during and after training.
  • Figure 3: Temporal XOR task. (a) The network consists of two lag lines, which perform the temporal computation and delay the input signal, and one XOR layer, which performs the spatial computation. The target is generated by a teacher network with weights shown along the edges. (b) The training loss converges faster when $\bm B$ is learned than when using $\bm{B}=\bm{W}^T$; it is plotted as a moving average to remove high-frequency oscillations. (c) The two input signals, each with a different frequency, at different times during the simulation. (d) The output of the model with learned $\bm B$ (blue) compared to the target before, during and after training.
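
The captions above mention two recurring ingredients: inputs built from overlaid sines with randomly sampled frequencies, and loss curves shown as a moving average to suppress high-frequency oscillations. The following is a minimal Python sketch of both, not the authors' code. The frequency range, unit amplitudes, window length, the synthetic `loss` trace and the helper names `overlaid_sines` and `moving_average` are all illustrative assumptions that do not come from the paper.

```python
import numpy as np


def overlaid_sines(t, freqs, phases=None):
    """Average of unit-amplitude sines (hypothetical helper, not from the paper)."""
    freqs = np.asarray(freqs, dtype=float)
    if phases is None:
        phases = np.zeros_like(freqs)
    # Broadcast (n_freqs, 1) against (1, n_times) -> (n_freqs, n_times).
    components = np.sin(2.0 * np.pi * freqs[:, None] * t[None, :] + phases[:, None])
    return components.mean(axis=0)


def moving_average(x, window):
    """Boxcar moving average; returns len(x) - window + 1 samples."""
    kernel = np.ones(window) / window
    return np.convolve(x, kernel, mode="valid")


rng = np.random.default_rng(0)
t = np.linspace(0.0, 10.0, 1001)        # time grid with step 0.01 (assumed units)
freqs = rng.uniform(0.5, 2.0, size=3)   # assumed frequency range
signal = overlaid_sines(t, freqs)

# Synthetic stand-in for a loss trace: slow decay plus a fast 5 Hz oscillation.
loss = np.exp(-0.3 * t) + 0.1 * np.sin(2.0 * np.pi * 5.0 * t)

# A window of 20 samples spans exactly one 5 Hz period, so the oscillation
# averages out and only the slow trend remains.
smoothed = moving_average(loss, window=20)
```

Choosing the window as a whole number of oscillation periods makes the boxcar filter cancel that oscillation; for other window lengths some ripple remains in the smoothed curve.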