SpikeBERT: A Language Spikformer Learned from BERT with Knowledge Distillation

Changze Lv, Tianlong Li, Jianhan Xu, Chenxi Gu, Zixuan Ling, Cenyuan Zhang, Xiaoqing Zheng, Xuanjing Huang

TL;DR

This work tackles the energy inefficiency of large language models by developing SpikeBERT, a spiking language model built on an adapted Spikformer. It introduces a two-stage knowledge distillation pipeline. Stage 1 pre-trains SpikeBERT on unlabeled text so that its embeddings and hidden features align with those of a BERT teacher. Stage 2 distills task-specific knowledge from a fine-tuned BERT, using data augmentation together with logits and cross-entropy losses. Empirically, SpikeBERT achieves state-of-the-art results among SNNs and performance competitive with BERT on six English and Chinese text classification tasks, while consuming substantially less energy (on average about 27.8% of BERT's energy on 45 nm hardware). This demonstrates that knowledge from large language models can be transferred to spiking architectures, offering a path toward energy-efficient NLP on neuromorphic hardware.
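The energy comparison rests on the fact that binary spikes turn multiply-accumulates into plain accumulates, and only neurons that fire trigger work. A minimal back-of-the-envelope sketch follows. It assumes the per-operation costs commonly cited in the SNN literature for 32-bit floating-point arithmetic at 45 nm (4.6 pJ per multiply-accumulate, 0.9 pJ per accumulate), and assumes SNN work scales as FLOPs times firing rate times time steps. The function names and inputs are hypothetical; this is not the paper's exact accounting and will not reproduce the 27.8% figure.

```python
# Hypothetical energy-estimate sketch; not the paper's exact accounting.
# Per-operation costs are commonly cited 45 nm figures for 32-bit floats
# (assumed here): one multiply-accumulate (MAC) vs. one accumulate (AC).
E_MAC_PJ = 4.6  # picojoules per MAC, assumed
E_AC_PJ = 0.9   # picojoules per AC, assumed


def ann_energy_pj(flops: float) -> float:
    """Dense ANN layer: every operation is a full multiply-accumulate."""
    return flops * E_MAC_PJ


def snn_energy_pj(flops: float, firing_rate: float, time_steps: int) -> float:
    """Spiking layer: binary spikes turn multiplications into additions,
    and only neurons that fire trigger work, repeated at every time step."""
    synaptic_ops = flops * firing_rate * time_steps
    return synaptic_ops * E_AC_PJ


def energy_ratio(flops: float, firing_rate: float, time_steps: int) -> float:
    """SNN energy as a fraction of the equivalent dense ANN energy."""
    return snn_energy_pj(flops, firing_rate, time_steps) / ann_energy_pj(flops)


if __name__ == "__main__":
    # Hypothetical layer: 1e9 FLOPs, 20% firing rate, 4 time steps.
    print(f"{energy_ratio(1e9, 0.2, 4):.3f}")
```

The sketch makes the trade-off visible: the saving shrinks as the firing rate or the number of time steps grows, since both multiply the count of accumulate operations.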

Abstract

Spiking neural networks (SNNs) offer a promising avenue to implement deep neural networks in a more energy-efficient way. However, the network architectures of existing SNNs for language tasks are still simplistic and relatively shallow, and deep architectures have not been fully explored, resulting in a significant performance gap compared to mainstream transformer-based networks such as BERT. To this end, we improve a recently proposed spiking Transformer (i.e., Spikformer) so that it can process language tasks, and propose a two-stage knowledge distillation method for training it: pre-training by distilling knowledge from BERT on a large collection of unlabelled texts, followed by fine-tuning on task-specific instances via knowledge distillation, again from a BERT fine-tuned on the same training examples. Through extensive experimentation, we show that the models trained with our method, named SpikeBERT, outperform state-of-the-art SNNs and even achieve results comparable to BERTs on text classification tasks for both English and Chinese, with much less energy consumption. Our code is available at https://github.com/Lvchangze/SpikeBERT.
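In the task-specific stage the student learns from both the gold labels and the fine-tuned teacher. A minimal sketch of such a stage-2 objective follows, assuming the logits loss is a KL divergence between temperature-softened teacher and student distributions and that the feature-alignment term is computed elsewhere and passed in. The loss weights, temperature, and function names are hypothetical choices for illustration, not SpikeBERT's published settings.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax of temperature-scaled logits."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max before exponentiating
    log_norm = m + math.log(sum(math.exp(z - m) for z in scaled))
    return [z - log_norm for z in scaled]


def cross_entropy(student_logits: Sequence[float], label: int) -> float:
    """Hard-label loss against the gold class."""
    return -log_softmax(student_logits)[label]


def logits_kd_loss(student_logits: Sequence[float],
                   teacher_logits: Sequence[float],
                   temperature: float = 2.0) -> float:
    """KL(teacher || student) between temperature-softened distributions.
    (Hinton-style scaling by temperature**2 is omitted for simplicity.)"""
    log_p_t = log_softmax(teacher_logits, temperature)
    log_p_s = log_softmax(student_logits, temperature)
    return sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))


def stage2_loss(student_logits: Sequence[float],
                teacher_logits: Sequence[float],
                label: int,
                feature_loss: float = 0.0,
                w_ce: float = 1.0,
                w_logits: float = 1.0,
                w_fea: float = 1.0,
                temperature: float = 2.0) -> float:
    """Weighted sum of cross-entropy, logits distillation, and a precomputed
    feature-alignment term. All weights are hypothetical defaults."""
    return (w_ce * cross_entropy(student_logits, label)
            + w_logits * logits_kd_loss(student_logits, teacher_logits, temperature)
            + w_fea * feature_loss)


if __name__ == "__main__":
    # Hypothetical 2-class example: the teacher is confident, the student less so.
    print(stage2_loss([0.2, 1.0], [0.0, 3.0], label=1, feature_loss=0.1))
```

Computing everything through log-softmax avoids taking the logarithm of a probability that has underflowed to zero.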

Paper Structure

This paper contains 34 sections, 19 equations, 4 figures, and 5 tables.

Figures (4)

  • Figure 1: (a) Architecture of Spikformer (Zhou et al., 2022) with $L$ encoder blocks. Spikformer is designed specifically for image classification, where the spiking patch splitting (SPS) module and the "convolution layer + batch normalization" module process visual signals well. The spiking self-attention (SSA) module in Spikformer models attention between every pair of feature dimensions, so we denote it as "D-SSA". (b) Architecture of SpikeBERT with $L'$ encoder blocks. To improve the model's ability to process text, we adopt "linear layer + layer normalization" and replace the SPS module with a word embedding layer. Furthermore, we modify the SSA module so that SpikeBERT attends to the interrelations between all pairs of words (or tokens) instead of dimensions.
  • Figure 2: Overview of our two-stage distillation method (pre-training + task-specific distillation) for training SpikeBERT. $T$ is the number of time steps of features in every layer. Note that the logits loss and cross-entropy loss are only used in stage $2$. Shades of color represent the magnitude of floating-point values. The dotted line under $L_{fea}^i$ indicates that the features of some hidden layers can be skipped when computing the feature alignment loss. If the student model has a different number of layers from the teacher model, we align features every few layers.
  • Figure 3: (a) Accuracy versus the number of time steps. (b) Accuracy versus the depth of networks. (c) Accuracy versus the decay rate $\beta$.
  • Figure 4: Attention map examples of SSA. (a) Sample text: "pays earnest homage to turntablists and beat jugglers, old schoolers and current innovators" (label: positive). (b) Sample text: "is a pan-american movie, with moments of genuine insight into the urban heart ." (label: positive).
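Figure 1 contrasts attention over feature dimensions (D-SSA) with SpikeBERT's attention over tokens. A minimal sketch of the difference in map shapes follows, assuming a Spikformer-style formulation in which Q, K and V are binary spike matrices and no softmax is applied (spikes are non-negative). The scale, the threshold, and the single-step threshold standing in for a spiking neuron are hypothetical simplifications, not SpikeBERT's implementation.

```python
from typing import List, Sequence

Matrix = List[List[int]]


def matmul(a: Sequence[Sequence[int]], b: Sequence[Sequence[int]]) -> Matrix:
    """Plain matrix product of an (n x k) and a (k x m) matrix."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b)))
             for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(m: Sequence[Sequence[int]]) -> Matrix:
    return [list(col) for col in zip(*m)]


def token_ssa(q: Matrix, k: Matrix, v: Matrix,
              scale: float = 0.125, threshold: float = 1.0):
    """Token-level spiking self-attention for one head and one time step.
    q, k, v are N x D binary spike matrices; no softmax is applied to the
    N x N token-token map Q K^T."""
    attn = matmul(q, transpose(k))   # N x N: interaction between tokens
    out = matmul(attn, v)            # N x D
    # Simplified spiking neuron: fire where the scaled input reaches threshold.
    spikes = [[1 if x * scale >= threshold else 0 for x in row] for row in out]
    return attn, spikes


def dimension_attention(q: Matrix, k: Matrix) -> Matrix:
    """For contrast, a D x D map relating feature dimensions (D-SSA-style)."""
    return matmul(transpose(q), k)


if __name__ == "__main__":
    q = [[1, 0, 1], [0, 1, 1]]  # hypothetical: 2 tokens, 3 feature dimensions
    attn, spikes = token_ssa(q, q, q, scale=0.5)
    print(attn, spikes, dimension_attention(q, q))
```

For a sentence of N tokens with D-dimensional features, the token map is N x N, so each entry relates two words, which is what the attention maps in Figure 4 visualize; the dimension map is D x D and says nothing directly about word pairs.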
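Figure 2 describes aligning hidden features between a student and a teacher that may differ in depth ("every few layers"), with some layers optionally skipped. A minimal sketch follows, assuming evenly spaced teacher layers, averaging the student's spiking features over the $T$ time steps, and mean squared error as the distance. These choices and all function names are hypothetical; the paper's actual feature alignment loss may differ.

```python
from typing import AbstractSet, Dict, List, Sequence


def layer_mapping(num_student: int, num_teacher: int) -> Dict[int, int]:
    """Map each student layer (0-based) to an evenly spaced teacher layer."""
    if not 0 < num_student <= num_teacher:
        raise ValueError("expected 0 < num_student <= num_teacher")
    step = num_teacher // num_student
    return {i: (i + 1) * step - 1 for i in range(num_student)}


def time_average(features: Sequence[Sequence[float]]) -> List[float]:
    """Collapse a T x D sequence of student features into one D vector."""
    steps = len(features)
    return [sum(col) / steps for col in zip(*features)]


def mse(a: Sequence[float], b: Sequence[float]) -> float:
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def feature_alignment_loss(student_layers: Sequence[Sequence[Sequence[float]]],
                           teacher_layers: Sequence[Sequence[float]],
                           skip: AbstractSet[int] = frozenset()) -> float:
    """student_layers[i]: T x D features of student layer i (one token).
    teacher_layers[j]: D features of teacher layer j (one token).
    Student layers listed in `skip` are ignored, like the dotted line in Fig. 2."""
    mapping = layer_mapping(len(student_layers), len(teacher_layers))
    losses = [mse(time_average(student_layers[i]), teacher_layers[j])
              for i, j in mapping.items() if i not in skip]
    return sum(losses) / len(losses) if losses else 0.0


if __name__ == "__main__":
    # Hypothetical 6-layer student distilled from a 12-layer teacher.
    print(layer_mapping(6, 12))
```

With a 6-layer student and a 12-layer teacher, each student layer is matched to every second teacher layer, ending at the teacher's last layer.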
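Figure 3 varies the number of time steps and the decay rate $\beta$. To make these two knobs concrete, here is a minimal sketch of a leaky integrate-and-fire neuron, assuming reset-by-subtraction and a threshold of 1.0. The neuron details and names are assumptions rather than SpikeBERT's exact configuration.

```python
from typing import List, Sequence


def lif_neuron(inputs: Sequence[float], beta: float = 0.9,
               threshold: float = 1.0) -> List[int]:
    """Leaky integrate-and-fire neuron unrolled over T = len(inputs) steps.

    U[t] = beta * U[t-1] + I[t] - S[t-1] * threshold   (reset by subtraction)
    S[t] = 1 if U[t] >= threshold else 0
    """
    membrane = 0.0
    spike = 0
    spikes: List[int] = []
    for current in inputs:
        # Leak a fraction of the old potential, integrate the new input,
        # and subtract the threshold if the neuron fired at the previous step.
        membrane = beta * membrane + current - spike * threshold
        spike = 1 if membrane >= threshold else 0
        spikes.append(spike)
    return spikes


if __name__ == "__main__":
    for beta in (0.5, 1.0):
        print(beta, lif_neuron([0.6] * 4, beta=beta))
```

With a constant input of 0.6 over four time steps, $\beta = 0.5$ leaks enough potential that the neuron fires once, while $\beta = 1.0$ (no leak) fires twice. A smaller $\beta$ makes the neuron forget faster, and more time steps give it more opportunities to integrate input and spike.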