UniAttn: Reducing Inference Costs via Softmax Unification for Post-Training LLMs
Yizhe Xiong, Wei Huang, Xin Ye, Hui Chen, Zijia Lin, Haoran Lian, Zhenpeng Su, Jungong Han, Guiguang Ding
TL;DR
This paper tackles the latency and memory overhead in post-training LLMs by identifying Softmax activations as a key bottleneck with high cross-layer redundancy. It introduces UniAttn, a Softmax Unification in Attention method that groups transformer blocks into SuperBlocks and reuses a shared Softmax activation, supplemented by a compensating linear projection $W_c$ and a two-stage post-training procedure. Theoretical analysis shows UniAttn preserves the depth-driven capabilities that cross-layer KV sharing can undermine, and extensive experiments across multiple models and datasets demonstrate substantial inference-cost reductions with performance on par with standard post-training. The approach is compatible with intra-layer KV sharing and additional memory-compression techniques, offering a practical path to deployment-ready, inference-efficient post-trained LLMs.
Abstract
Post-training is essential for adapting Large Language Models (LLMs) to real-world applications. Deploying post-trained models faces significant challenges due to substantial memory overhead and noticeable inference latency. Existing work has identified significant redundancies in LLMs and proposed efficient architectures, namely intra-layer KV sharing and cross-layer KV sharing. However, these methods still incur high inference-time overhead and remain suboptimal when post-training pre-trained LLMs. In this paper, we identify that the Softmax operation is a primary bottleneck for LLM inference and discover that it is highly redundant during post-training. We propose Softmax Unification in Attention (UniAttn), a novel post-training method that unifies Softmax activations across transformer blocks to reduce LLM inference costs. Additionally, UniAttn adopts a linear projection to compensate for the errors induced by Softmax unification. Experiments show that UniAttn matches the performance of standard post-training while significantly reducing inference costs, outperforming existing efficient architectures during post-training.
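To make the idea of Softmax unification concrete, the following is a minimal NumPy sketch of one SuperBlock, not the authors' implementation. It assumes a single attention head with a causal mask, omits MLPs and normalization, and assumes that the compensating projection $W_c$ is applied to the block input and added to the attention output; the exact placement of $W_c$ is not specified in the text above, so this is one plausible reading. The names `attention_probs`, `superblock_forward`, and the dictionary keys are hypothetical.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def attention_probs(x, w_q, w_k):
    """Causal single-head attention weights of shape (seq, seq)."""
    seq = x.shape[0]
    scores = (x @ w_q) @ (x @ w_k).T / np.sqrt(w_q.shape[1])
    # Mask out future positions (strictly upper triangle).
    future = np.triu(np.ones((seq, seq), dtype=bool), k=1)
    scores = np.where(future, -np.inf, scores)
    return softmax(scores)


def superblock_forward(x, blocks):
    """Run one SuperBlock (attention sub-layers only).

    blocks[0] holds w_q, w_k, w_v, w_o and computes Softmax once.
    Later blocks hold w_v, w_o, w_c and reuse that Softmax activation.
    Returns the output and the number of Softmax evaluations.
    """
    probs = None
    softmax_calls = 0
    for i, blk in enumerate(blocks):
        if i == 0:
            probs = attention_probs(x, blk["w_q"], blk["w_k"])
            softmax_calls += 1
            attn_out = probs @ (x @ blk["w_v"]) @ blk["w_o"]
        else:
            # Reuse the shared activation; no Q/K projections or Softmax here.
            # ASSUMPTION: W_c acts on the block input as an additive correction.
            attn_out = probs @ (x @ blk["w_v"]) @ blk["w_o"] + x @ blk["w_c"]
        x = x + attn_out  # residual connection
    return x, softmax_calls
```

In this sketch, a SuperBlock of three blocks evaluates Softmax once instead of three times, and the later blocks need neither query nor key projections, which is where the inference savings described above would come from under these assumptions.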
