MD-SNN: Membrane Potential-aware Distillation on Quantized Spiking Neural Network
Donghyun Lee, Abhishek Moitra, Youngeun Kim, Ruokai Yin, Priyadarshini Panda
TL;DR
MD-SNN addresses accuracy loss in quantized spiking neural networks by transferring membrane potential distributions from a full-precision teacher to quantized students. It introduces membrane-aware distillation with two pathways (membrane potentials and logits) and a versatile teacher framework in which a single teacher supports students at multiple timesteps. The approach achieves competitive accuracy on static and neuromorphic datasets and delivers substantial hardware efficiency gains on SpikeSim: up to 14.85x lower EDAP, 2.64x higher TOPS/W, and 6.19x higher TOPS/mm² than floating-point SNNs at iso-accuracy on N-Caltech101. This enables flexible, energy-efficient deployment of quantized SNNs across varying latency-accuracy requirements without retraining multiple models.
Abstract
Spiking Neural Networks (SNNs) offer a promising, energy-efficient alternative to conventional neural networks, thanks to their sparse binary activations. However, they incur memory and computation overhead because of their complex spatio-temporal dynamics and the need to backpropagate across multiple timesteps during training. To mitigate this overhead, compression techniques such as quantization are applied to SNNs. Yet naively applying quantization to SNNs introduces a mismatch in membrane potential, a crucial factor in the firing of spikes, which degrades accuracy. In this paper, we introduce Membrane-aware Distillation on quantized Spiking Neural Network (MD-SNN), which leverages membrane potential to mitigate the discrepancies that arise after quantizing weights, membrane potentials, and batch normalization. To our knowledge, this study represents the first application of membrane potential knowledge distillation in SNNs. We validate our approach on various datasets, including CIFAR10, CIFAR100, N-Caltech101, and TinyImageNet, demonstrating its effectiveness on both static and dynamic data. Furthermore, to assess hardware efficiency, we evaluate MD-SNN on the SpikeSim platform and find that MD-SNNs achieve 14.85x lower energy-delay-area product (EDAP), 2.64x higher TOPS/W, and 6.19x higher TOPS/mm² compared to floating-point SNNs at iso-accuracy on the N-Caltech101 dataset.
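To make the membrane potential mismatch concrete, the following is a minimal illustrative sketch, not the authors' implementation. It runs one leaky integrate-and-fire neuron with full-precision and quantized weights, then combines a membrane potential term with a logit term into a single distillation loss. The neuron model, the uniform quantizer, the MSE and KL terms, and every name and value here (`quantize`, `lif_membrane_trace`, `md_loss`, `alpha`, `temperature`, the example logits) are assumptions chosen for illustration; the abstract does not specify them.

```python
import math

def quantize(x, bits, x_max=1.0):
    """Uniform symmetric quantizer (an assumed scheme, not taken from the paper)."""
    levels = 2 ** (bits - 1) - 1
    scale = x_max / levels
    q = max(-levels, min(levels, round(x / scale)))
    return q * scale

def lif_membrane_trace(spikes_in, weights, leak=0.9, threshold=1.0):
    """Simulate one leaky integrate-and-fire neuron.

    Returns the pre-reset membrane potential at every timestep.
    """
    u, trace = 0.0, []
    for x_t in spikes_in:
        u = leak * u + sum(w * x for w, x in zip(weights, x_t))
        trace.append(u)
        if u >= threshold:
            u = 0.0  # hard reset after the neuron fires
    return trace

def softmax(logits, temperature=1.0):
    m = max(logits)
    exps = [math.exp((z - m) / temperature) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def md_loss(student_u, teacher_u, student_logits, teacher_logits,
            alpha=0.5, temperature=2.0):
    """Hypothetical two-pathway loss: membrane MSE plus logit KL divergence."""
    mse = sum((s - t) ** 2 for s, t in zip(student_u, teacher_u)) / len(teacher_u)
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    return alpha * mse + (1 - alpha) * kl

# Binary input spikes over four timesteps for a neuron with three synapses.
spikes_in = [[1, 0, 1], [0, 1, 1], [1, 1, 0], [1, 0, 1]]
fp_weights = [0.37, -0.52, 0.81]                       # teacher (full precision)
q_weights = [quantize(w, bits=4) for w in fp_weights]  # student (4-bit weights)

teacher_u = lif_membrane_trace(spikes_in, fp_weights)
student_u = lif_membrane_trace(spikes_in, q_weights)

# Made-up logits, used only to exercise the second pathway.
loss = md_loss(student_u, teacher_u, [1.8, 0.4, -0.9], [2.0, 0.5, -1.0])
print("teacher membrane trace:", teacher_u)
print("student membrane trace:", student_u)
print("distillation loss:", loss)
```

In this toy setting the quantized weights shift the membrane trace at every timestep, and the membrane term of the loss penalizes that shift directly rather than only its downstream effect on the logits.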
