SpikingBERT: Distilling BERT to Train Spiking Language Models Using Implicit Differentiation

Malyaban Bal, Abhronil Sengupta

TL;DR

This work is the first to demonstrate the performance of an operational spiking LM architecture on multiple tasks in the GLUE benchmark. It leverages the average spiking rate of neurons at equilibrium to train a neuromorphic spiking LM using implicit differentiation, thereby overcoming the non-differentiability problem of spiking neural network (SNN) based algorithms without using any surrogate gradient.

Abstract

Large language models (LLMs), though growing exceedingly powerful, comprise orders of magnitude fewer neurons and synapses than the human brain. However, they require significantly more power/energy to operate. In this work, we propose a novel bio-inspired spiking language model (LM) which aims to reduce the computational cost of conventional LMs by drawing motivation from the synaptic information flow in the brain. We demonstrate a framework that leverages the average spiking rate of neurons at equilibrium to train a neuromorphic spiking LM using implicit differentiation, thereby overcoming the non-differentiability problem of spiking neural network (SNN) based algorithms without using any surrogate gradient. The steady-state convergence of the spiking neurons also allows us to design a spiking attention mechanism, which is critical in developing a scalable spiking LM. Moreover, the convergence of the average spiking rate of neurons at equilibrium is used to develop a novel ANN-SNN knowledge distillation technique in which a pre-trained BERT model serves as the "teacher" to train our "student" spiking architecture. While the primary architecture proposed in this paper is motivated by BERT, the technique can potentially be extended to other kinds of LLMs. Our work is the first to demonstrate the performance of an operational spiking LM architecture on multiple tasks in the GLUE benchmark.
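To make the idea of training through an equilibrium more concrete, the following is a minimal scalar sketch of implicit differentiation at a fixed point. It is not the paper's model: the clipped-linear fixed-point map `a = clip(w * a + x, 0, 1)`, the function names, and the values of `w` and `x` are all assumptions chosen for illustration. The sketch solves for the equilibrium by iteration, then obtains the gradient of the equilibrium with respect to `w` from the implicit function theorem, without backpropagating through the iterations, and compares it to a finite-difference estimate.

```python
def clip01(z):
    """Clip z to the interval [0, 1] (a rate cannot leave this range)."""
    return min(1.0, max(0.0, z))


def solve_fixed_point(w, x, iters=200):
    """Iterate a <- clip01(w * a + x) until it settles (assumes |w| < 1)."""
    a = 0.0
    for _ in range(iters):
        a = clip01(w * a + x)
    return a


def implicit_grad_w(w, x):
    """d a*/d w at the fixed point a* = f(a*, w), via the implicit function theorem:
    d a*/d w = (df/dw) / (1 - df/da), both partials evaluated at a*.
    """
    a_star = solve_fixed_point(w, x)
    z = w * a_star + x
    slope = 1.0 if 0.0 < z < 1.0 else 0.0  # derivative of clip01 at z
    df_da = slope * w
    df_dw = slope * a_star
    return df_dw / (1.0 - df_da)


# Hypothetical values for illustration only.
w, x = 0.5, 0.2
a_star = solve_fixed_point(w, x)   # analytically x / (1 - w) in the unclipped region
grad = implicit_grad_w(w, x)       # analytically x / (1 - w) ** 2

# Central finite difference through the solver, as an independent check.
eps = 1e-6
fd = (solve_fixed_point(w + eps, x) - solve_fixed_point(w - eps, x)) / (2 * eps)
```

The point of the sketch is that the gradient depends only on the equilibrium value, not on the trajectory taken to reach it, which is why a non-differentiable spiking simulation can be used in the forward phase.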

Paper Structure

This paper contains 27 sections, 13 equations, 5 figures, 2 tables.

Figures (5)

  • Figure 1: High-level overview of the SpikingBERT model. During the "forward" phase of learning, the network is simulated over $T_{conv}$ time steps, i.e., until the average spiking rate (ASR) of the neurons of each layer converges to an equilibrium. Information flow both within and between two spiking encoders occurs using spikes instead of real values, thereby mimicking event-driven information flow in bio-inspired systems.
  • Figure 2: Results obtained after passing a randomly sampled input from the SST-2 dataset through SpikingBERT4. (a) Mean (over number of neurons) of the ASR of different sub-layers in an SE layer against the operating time steps. (b) The left y-axis shows the mean (over number of neurons) of the ASR of a randomly chosen spiking attention layer. The right y-axis shows the "Difference Norm" between the output of the steady-state equation of the chosen spiking attention layer and the calculated ASR. The time steps used for convergence are shown along the x-axis.
  • Figure 3: High-level overview of transformer layer based KD at equilibrium (following Eqn. \ref{eqn10}) from a "teacher" LM to a spiking "student" LM.
  • Figure 4: Results obtained on the SST-2 dataset. (a) Variation of accuracy and energy-efficiency factor ($e$) as $T_{conv}$ increases. (b) Variation in mean ASR per neuron in different sub-layers of SpikingBERT4 as $V_{th}$ changes.
  • Figure A1: (a) Graph illustrating mean (over number of neurons) of the ASR of the IL-1 sub-layer (in an SE layer) against the operating time steps. This visualization contrasts scenarios where the final model employs normalization with those where normalization is omitted. (b) Graph showing accuracy against time steps used for convergence during inference for SpikingBERT models with 2 and 4 SE layers.
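
The convergence behaviour described in the captions above (ASR settling over $T_{conv}$ time steps, and a "Difference Norm" between the steady-state value and the measured ASR) can be illustrated with a toy simulation. This is a sketch under stated assumptions, not the paper's implementation: it uses a single non-leaky integrate-and-fire neuron with a constant input and a subtractive reset, and takes `clip(x / v_th, 0, 1)` as the steady-state value for that toy neuron. The function names and the values of `x`, `v_th`, and `t_conv` are hypothetical.

```python
def simulate_asr(x, v_th=1.0, t_conv=200):
    """Simulate one integrate-and-fire neuron driven by a constant input x.

    Returns the running average spiking rate (ASR) after each time step.
    """
    u = 0.0        # membrane potential
    spikes = 0     # cumulative spike count
    asr_trace = []
    for t in range(1, t_conv + 1):
        u += x                        # integrate the input
        s = 1 if u >= v_th else 0     # emit a spike when the threshold is reached
        u -= v_th * s                 # subtractive reset after a spike
        spikes += s
        asr_trace.append(spikes / t)  # ASR = spikes so far / elapsed time steps
    return asr_trace


def steady_state_asr(x, v_th=1.0):
    """Steady-state rate assumed for this toy neuron: clip(x / v_th, 0, 1)."""
    return min(1.0, max(0.0, x / v_th))


# Hypothetical values for illustration only.
trace = simulate_asr(x=0.3, v_th=1.0, t_conv=200)
target = steady_state_asr(0.3, v_th=1.0)

# Scalar analogue of a "difference norm": gap between measured ASR and steady state.
diff = [abs(a - target) for a in trace]
```

For this toy neuron the gap shrinks roughly as 1/t, because the residual membrane potential stays bounded while the number of time steps grows. In a sketch like this, a larger `t_conv` gives a closer match to the steady state at the cost of more simulated time steps, and raising `v_th` lowers the rate the neuron settles to.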