Learn To be Efficient: Build Structured Sparsity in Large Language Models
Haizhong Zheng, Xiaoyan Bai, Xueshen Liu, Z. Morley Mao, Beidi Chen, Fan Lai, Atul Prakash
TL;DR
This work tackles the high inference cost of large language models by exploiting activation sparsity: rather than relying only on naturally occurring sparsity, it trains models to exhibit structured contextual sparsity. It introduces Learn-To-be-Efficient (LTE), a two-stage training method that groups FFN neurons into 32-neuron experts and trains a threshold-based sigmoid router with efficiency and separability losses, so that the level of sparsity can adapt across layers. LTE delivers FLOPs speedups of up to 1.83x–2.59x on LLaMA-7B. Paired with a hardware-aware Triton kernel, it reduces wall-clock latency by about 25% at 50% sparsity. It outperforms the MoEfication and Deja Vu baselines across NLU, NLG, and instruction-tuning tasks. The approach learns effective sparsity even in models with soft (non-ReLU) activations and offers a practical path toward efficient, scalable LLM inference.
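The routing step described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the authors' implementation. Experts are assumed to be contiguous blocks of 32 FFN neurons, and the router is a single linear layer followed by a sigmoid. An expert is active when its score exceeds a threshold (0.5 here). The two losses take simple hypothetical forms: the mean router score for efficiency, and an inverse squared distance from the threshold for separability. The paper's exact loss definitions and two-stage schedule may differ. All function names are hypothetical.

```python
import math
import random

EXPERT_SIZE = 32  # neurons per expert, as stated in the TL;DR


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def group_neurons(num_neurons, expert_size=EXPERT_SIZE):
    """Partition FFN neuron indices into experts (assumed contiguous grouping)."""
    if num_neurons % expert_size != 0:
        raise ValueError("num_neurons must be a multiple of expert_size")
    return [list(range(i, i + expert_size)) for i in range(0, num_neurons, expert_size)]


def router_scores(x, w_router):
    """Sigmoid router: one independent score per expert (no softmax, no fixed top-k)."""
    return [sigmoid(sum(xi * wi for xi, wi in zip(x, w_e))) for w_e in w_router]


def select_experts(scores, threshold=0.5):
    # Threshold-based selection: the number of active experts varies with the input.
    return [i for i, s in enumerate(scores) if s > threshold]


def efficiency_loss(scores):
    # Assumed form: mean router score; minimizing it pushes experts toward "off".
    return sum(scores) / len(scores)


def separability_loss(scores, threshold=0.5, eps=1e-6):
    # Assumed form: large when scores sit near the threshold, so training
    # pushes them clearly above or below it.
    return sum(1.0 / ((s - threshold) ** 2 + eps) for s in scores) / len(scores)


if __name__ == "__main__":
    random.seed(0)
    d_model, num_neurons = 8, 128  # toy sizes; real LLaMA-7B layers are far larger
    experts = group_neurons(num_neurons)  # 4 experts of 32 neurons each
    w_router = [[random.gauss(0.0, 1.0) for _ in range(d_model)] for _ in experts]
    x = [random.gauss(0.0, 1.0) for _ in range(d_model)]
    scores = router_scores(x, w_router)
    active = select_experts(scores)
    print(f"active experts: {active} of {len(experts)}")
    print(f"efficiency loss: {efficiency_loss(scores):.3f}")
    print(f"separability loss: {separability_loss(scores):.3f}")
```

Each expert receives an independent sigmoid score rather than a softmax share, so no top-k is fixed in advance. The number of active experts can therefore differ by input and by layer. In training, the two auxiliary terms would presumably be weighted and added to the task loss.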
Abstract
Large Language Models (LLMs) have achieved remarkable success with billions of parameters, yet they incur high inference overheads. The emergence of activation sparsity in LLMs provides a natural way to reduce this cost by involving only part of the parameters in inference. However, existing methods focus only on exploiting this naturally formed activation sparsity in a post-training setting, overlooking the potential to further amplify it. In this paper, we hypothesize that LLMs can learn to be efficient by achieving more structured activation sparsity. To this end, we introduce a novel training algorithm, Learn-To-be-Efficient (LTE), which trains efficiency-aware LLMs to activate fewer neurons and achieve a better trade-off between sparsity and performance. Furthermore, unlike SOTA MoEfication methods, which mainly target ReLU-based models, LTE can also be applied to LLMs such as LLaMA that use non-ReLU activations. Extensive evaluation on language understanding, language generation, and instruction tuning tasks shows that LTE consistently outperforms SOTA baselines. Combined with our hardware-aware custom kernel implementation, LTE reduces LLaMA2-7B inference latency by 25% at 50% sparsity.
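To see how involving only part of the parameters translates into savings, the following sketch evaluates a LLaMA-style SwiGLU FFN over only the neurons of the selected experts and counts the FLOPs spent. It assumes expert groups and an active-expert list like those a router would produce. The function names and toy sizes are hypothetical, and this is not the paper's Triton kernel.

```python
import math
import random


def sigmoid(z):
    # Numerically stable logistic function.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def silu(z):
    # SiLU (Swish), the gate nonlinearity in LLaMA's SwiGLU FFN (a non-ReLU activation).
    return z * sigmoid(z)


def dot(a, b):
    return sum(p * q for p, q in zip(a, b))


def sparse_ffn(x, w_gate, w_up, w_down, experts, active):
    """SwiGLU FFN evaluated only over the neurons of the active experts.

    w_gate, w_up: one weight row (length d) per intermediate neuron.
    w_down: one output column (length d) per intermediate neuron.
    experts: list of neuron-index lists; active: indices into experts.
    Returns (output vector, FLOPs spent), counting a multiply-add as 2 FLOPs.
    """
    d = len(x)
    out = [0.0] * d
    flops = 0
    for e in active:
        for n in experts[e]:
            h = silu(dot(w_gate[n], x)) * dot(w_up[n], x)
            for i in range(d):
                out[i] += h * w_down[n][i]
            flops += 3 * 2 * d  # gate, up and down projections for one neuron
    return out, flops


def rand_matrix(rows, cols):
    return [[random.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


if __name__ == "__main__":
    random.seed(0)
    d, num_neurons, expert_size = 4, 8, 2  # toy sizes; LTE uses 32-neuron experts
    experts = [list(range(i, i + expert_size)) for i in range(0, num_neurons, expert_size)]
    w_gate, w_up, w_down = (rand_matrix(num_neurons, d) for _ in range(3))
    x = [random.uniform(-1.0, 1.0) for _ in range(d)]
    _, dense_flops = sparse_ffn(x, w_gate, w_up, w_down, experts, range(len(experts)))
    _, sparse_flops = sparse_ffn(x, w_gate, w_up, w_down, experts, [0, 2])
    print(dense_flops, sparse_flops)  # prints "192 96": 50% of experts, 50% of FFN FLOPs
```

The down projection is linear, so the outputs for disjoint expert sets sum to the dense output. Skipping an expert is therefore exactly equivalent to zeroing its neurons. The FLOPs count covers only the FFN projections. Wall-clock latency also depends on attention, memory traffic, and kernel efficiency, which is why a hardware-aware kernel matters.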
