UniAttn: Reducing Inference Costs via Softmax Unification for Post-Training LLMs

Yizhe Xiong, Wei Huang, Xin Ye, Hui Chen, Zijia Lin, Haoran Lian, Zhenpeng Su, Jungong Han, Guiguang Ding

TL;DR

This paper tackles the latency and memory overhead of deploying post-trained LLMs by identifying Softmax activations as a key inference bottleneck with high cross-layer redundancy. It introduces UniAttn (Softmax Unification in Attention), a method that groups transformer blocks into SuperBlocks whose layers reuse a shared Softmax activation, supplemented by a compensating linear projection $W_c$ and a two-stage post-training procedure. Theoretical analysis shows that UniAttn preserves the depth-driven capabilities that cross-layer KV sharing can undermine, and experiments across multiple models and datasets show substantial inference-cost reductions with performance on par with standard post-training. The approach is compatible with intra-layer KV sharing and other memory-compression techniques, offering a practical path to deployment-ready, inference-efficient post-trained LLMs.
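
The SuperBlock mechanism can be sketched in NumPy as follows. This is a minimal single-head illustration under stated assumptions, not the authors' implementation: the function names, the weight keys (`w_q`, `w_k`, `w_v`, `w_o`), the omission of MLPs and normalization, and the placement of $W_c$ (applied to each later layer's input and added to its attention output) are all hypothetical choices.

```python
import numpy as np


def softmax(scores: np.ndarray) -> np.ndarray:
    # Numerically stable row-wise softmax.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    e = np.exp(shifted)
    return e / e.sum(axis=-1, keepdims=True)


def causal_attention_probs(x: np.ndarray, w_q: np.ndarray, w_k: np.ndarray) -> np.ndarray:
    """Single-head Softmax(Q K^T / sqrt(d)) with a causal mask; x has shape (T, d)."""
    q, k = x @ w_q, x @ w_k
    scores = q @ k.T / np.sqrt(q.shape[-1])
    future = np.triu(np.ones(scores.shape, dtype=bool), k=1)  # True above the diagonal
    return softmax(np.where(future, -np.inf, scores))


def superblock_forward(x: np.ndarray, layers: list, w_c_list: list) -> np.ndarray:
    """Hypothetical UniAttn-style SuperBlock (single head, no MLP or normalization).

    Only layers[0] computes Softmax activations; layers[1:] reuse them with their
    own value/output projections. w_c_list[j] is the assumed compensation matrix
    for layers[j + 1], applied to that layer's input as an additive correction.
    """
    probs = None
    for idx, layer in enumerate(layers):
        if idx == 0:
            # The only Softmax computed in this SuperBlock.
            probs = causal_attention_probs(x, layer["w_q"], layer["w_k"])
        v = x @ layer["w_v"]  # every layer keeps its own values
        attn_out = (probs @ v) @ layer["w_o"]
        if idx > 0:
            # Assumed placement of the linear compensation W_c.
            attn_out = attn_out + x @ w_c_list[idx - 1]
        x = x + attn_out  # residual connection
    return x


# Usage: a 3-layer SuperBlock with random weights and zero-initialized W_c.
rng = np.random.default_rng(0)
T, d = 5, 8
layers = [{"w_q": 0.1 * rng.standard_normal((d, d)),
           "w_k": 0.1 * rng.standard_normal((d, d)),
           "w_v": 0.1 * rng.standard_normal((d, d)),
           "w_o": 0.1 * rng.standard_normal((d, d))} for _ in range(3)]
w_c_list = [np.zeros((d, d)) for _ in range(len(layers) - 1)]
y = superblock_forward(rng.standard_normal((T, d)), layers, w_c_list)
print(y.shape)  # (5, 8)
```

In this sketch only the first layer's `w_q` and `w_k` are ever used, so later layers in the SuperBlock skip the query-key product and the Softmax entirely while still mixing their own values. Cross-layer KV sharing, by contrast, keeps per-layer queries but reuses an earlier layer's keys and values.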

Abstract

Post-training is essential for adapting Large Language Models (LLMs) to real-world applications. Deploying post-trained models faces significant challenges due to substantial memory overhead and noticeable inference latency. Existing work has identified significant redundancies in LLMs and proposed efficient architectures, namely intra-layer KV sharing and cross-layer KV sharing. However, these methods still result in high inference-time overhead, remaining suboptimal for post-training pre-trained LLMs. In this paper, we identify that the Softmax operation is a primary bottleneck for LLM inference and discover that it is actually highly redundant during post-training. We propose Softmax Unification in Attention (UniAttn), a novel post-training method that unifies Softmax activations across transformer blocks to reduce LLM inference costs. Additionally, UniAttn adopts a linear projection to compensate for the errors induced by Softmax unification. Experiments show that UniAttn matches the performance of standard post-training while significantly reducing inference costs, outperforming existing efficient architectures during post-training.

Paper Structure

This paper contains 25 sections, 12 theorems, 41 equations, 7 figures, 13 tables, and 1 algorithm.

Key Result

Theorem 3.1

The initialization for $W_{c}$ that incurs minimal error when compensating $\mathbb{E}(\epsilon)$ is given in closed form in terms of $U\Sigma V^T$, the singular value decomposition (SVD) of $\mathbb{E}(\mathbf{x}_{i+b})$, and $\Sigma^{+}$, the pseudoinverse of $\Sigma$.
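
To make the closed form concrete, here is a minimal NumPy sketch. It assumes a row-vector convention in which $\mathbb{E}(\mathbf{x}_{i+b})$ is stacked as an $n \times d_{in}$ matrix and $W_c$ is the minimum-norm least-squares solution of $\mathbb{E}(\mathbf{x}_{i+b})\,W_c \approx \mathbb{E}(\epsilon)$, which gives $W_c = V\Sigma^{+}U^T\mathbb{E}(\epsilon)$. The objective, the convention, and the helper name `init_compensation` are assumptions for illustration rather than the paper's exact statement.

```python
import numpy as np


def init_compensation(x_mean: np.ndarray, eps_mean: np.ndarray, rcond: float = 1e-10) -> np.ndarray:
    """Minimum-norm least-squares W_c for x_mean @ W_c ~= eps_mean (assumed convention).

    x_mean:   (n, d_in)  stand-in for E(x_{i+b}), e.g. inputs averaged over calibration data
    eps_mean: (n, d_out) stand-in for E(eps), the unification error to compensate
    """
    # Thin SVD: x_mean = U @ diag(s) @ Vt
    U, s, Vt = np.linalg.svd(x_mean, full_matrices=False)
    # Sigma^+: invert singular values above a relative cutoff, zero out the rest.
    s_pinv = np.zeros_like(s)
    np.divide(1.0, s, out=s_pinv, where=s > rcond * s.max())
    # W_c = V Sigma^+ U^T E(eps)
    return Vt.T @ (s_pinv[:, None] * (U.T @ eps_mean))


# Usage with random stand-in statistics.
rng = np.random.default_rng(0)
x_mean = rng.standard_normal((64, 16))
eps_mean = rng.standard_normal((64, 16))
W_c = init_compensation(x_mean, eps_mean)
print(np.allclose(W_c, np.linalg.pinv(x_mean) @ eps_mean))  # True
```

Truncating small singular values through `rcond` plays the same role as the pseudoinverse $\Sigma^{+}$: directions that $\mathbb{E}(\mathbf{x}_{i+b})$ barely spans receive no correction instead of an unstable one.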

Figures (7)

  • Figure 1: Comparison of our UniAttn with directly applying cross-layer KV sharing (CLA) during post-training. "A-X" denotes applying A with a total of X layers modified.
  • Figure 2: Cosine similarity results of average Softmax activations. Across all settings, the average Softmax activations of the top half of layers (i.e., the right half of each heatmap bar) share a high cosine similarity.
  • Figure 3: Pipeline comparison between standard decoder-based transformers, CLA (block size 3), and UniAttn (Superblock size 3). UniAttn shares the Softmax activations across layers in grouped Superblocks and adds a linear transformation $W_c$ to compensate for the unification error. For simplicity, only the self-attention calculation is illustrated for CLA and UniAttn.
  • Figure 4: Comparison between baselines, UniAttn, and UniAttn without ("w/o") linear compensation.
  • Figure 5: Hyperparameter analysis results on average accuracy (%) and TTFT latency (s). Left: Different total numbers of grouped Superblocks (each with size 4). Right: Different Superblock sizes (total of 12 layers that utilize unified Softmax activation).
  • ...and 2 more figures
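
The redundancy analysis behind Figure 2 can be approximated with a short NumPy routine that averages each layer's Softmax activations and compares layers by cosine similarity. The input layout, the averaging over samples and heads, and the function name are assumptions; the paper's exact protocol (for instance, which layer pairs each heatmap cell compares) may differ.

```python
import numpy as np


def layerwise_cosine_similarity(attn_probs: np.ndarray) -> np.ndarray:
    """Cosine similarity between per-layer average Softmax activations.

    attn_probs: (num_samples, num_layers, num_heads, T, T) attention probabilities,
    assumed to share one sequence length T. Averaging over samples and heads is
    an assumption about how the "average Softmax activation" is formed.
    """
    avg = attn_probs.mean(axis=(0, 2))        # (num_layers, T, T)
    flat = avg.reshape(avg.shape[0], -1)      # one vector per layer
    unit = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    return unit @ unit.T                      # (num_layers, num_layers)


# Usage with synthetic attention maps where layers 2 and 3 are identical.
rng = np.random.default_rng(0)
logits = rng.standard_normal((3, 4, 2, 6, 6))
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
probs[:, 3] = probs[:, 2]
sim = layerwise_cosine_similarity(probs)
print(np.round(np.diag(sim, k=1), 3))  # adjacent-layer similarities
```

`np.diag(sim, k=1)` extracts the similarity between each layer and the next, which is one way to read off the kind of adjacent-layer redundancy that motivates grouping layers into Superblocks.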

Theorems & Definitions (22)

  • Theorem 3.1
  • Proposition 3.3
  • Proposition 3.4
  • Lemma 5.1
  • Proof
  • Proposition 5.2
  • Proposition 5.3
  • Proposition 5.4
  • Proof
  • Proof
  • ...and 12 more