FlattenGPT: Depth Compression for Transformer with Layer Flattening
Ruihan Xu, Qingpei Guo, Yao Zhu, Xiangyang Ji, Ming Yang, Shiliang Zhang
TL;DR
FlattenGPT addresses redundancy in deep transformer stacks with a two-stage depth-compression scheme. First, iterative layer flattening merges adjacent blocks while preserving cross-layer information. Second, channel pruning on the flattened blocks removes the remaining redundancy. The method keeps the architecture consistent with the original transformer while significantly accelerating inference. Empirically, it retains 90-96% of zero-shot performance at a 20% compression ratio on LLaMA-2/3 and Qwen-1.5 models, outperforming prior depth-compression and pruning methods on WikiText-2 perplexity and zero-shot benchmarks. By bridging depth compression and channel pruning, FlattenGPT offers a practical, hardware-friendly way to deploy efficient LLMs without substantial performance loss, and recovery fine-tuning further improves robustness. The work lays a foundation for fine-grained depth reduction in transformers that preserves information across layers and eases deployment on resource-constrained hardware.
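The summary does not spell out how two blocks become one. The following is a minimal sketch under an explicit assumption: flattening is taken to mean running the residual branches of two adjacent blocks in parallel on the same input, which only approximates the original sequential stack. Under that assumption, two MLP branches collapse exactly into one wider MLP by concatenating their hidden channels. The helper names (`mlp`, `flatten_mlps`), the GELU activation, and the random weights are illustrative assumptions. Normalization layers and attention are omitted.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU, applied elementwise
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def mlp(x, w_up, w_down):
    # Two-layer MLP branch: hidden = act(x @ w_up), output = hidden @ w_down
    return gelu(x @ w_up) @ w_down

def flatten_mlps(w_up1, w_down1, w_up2, w_down2):
    # Hypothetical helper: concatenate hidden channels of two MLPs.
    # Up-projections go side by side (columns), down-projections are stacked (rows).
    w_up = np.concatenate([w_up1, w_up2], axis=1)        # (d, h1 + h2)
    w_down = np.concatenate([w_down1, w_down2], axis=0)  # (h1 + h2, d)
    return w_up, w_down

rng = np.random.default_rng(0)
d, h = 16, 64
x = rng.standard_normal((4, d))
w_up1 = rng.standard_normal((d, h)) / np.sqrt(d)
w_down1 = rng.standard_normal((h, d)) / np.sqrt(h)
w_up2 = rng.standard_normal((d, h)) / np.sqrt(d)
w_down2 = rng.standard_normal((h, d)) / np.sqrt(h)

# Assumption: the flattened (parallel) form feeds the same input x to both branches
parallel = x + mlp(x, w_up1, w_down1) + mlp(x, w_up2, w_down2)

# One wider MLP with concatenated channels reproduces the parallel form exactly
w_up, w_down = flatten_mlps(w_up1, w_down1, w_up2, w_down2)
merged = x + mlp(x, w_up, w_down)
print(np.allclose(parallel, merged))  # True

# Original sequential form, for comparison: block 2 reads block 1's output
h1 = x + mlp(x, w_up1, w_down1)
sequential = h1 + mlp(h1, w_up2, w_down2)
```

The merged MLP has the combined hidden width of both blocks. Depth therefore drops by one without discarding any channels, and a subsequent channel-pruning step can trim the widened layer. The gap between `sequential` and `merged` is the approximation error that a scheme like this would have to keep small.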
Abstract
Recent works have indicated redundancy across transformer blocks, prompting research on depth compression that prunes less crucial blocks. However, current approaches that prune entire blocks risk discarding meaningful cues learned in those blocks, leading to substantial performance degradation. Channel pruning, another line of model compression, better preserves performance. However, it cannot reduce model depth and is challenged by inconsistent pruning ratios across individual layers. To pursue better model compression and acceleration, this paper proposes FlattenGPT, a novel way to detect and reduce depth-wise redundancies. By flattening two adjacent blocks into one, it compresses the network depth while enabling more effective detection and removal of parameter redundancy. FlattenGPT preserves the knowledge learned in all blocks and remains consistent with the original transformer architecture. Extensive experiments demonstrate that FlattenGPT improves model efficiency with a favorable trade-off in performance. It outperforms existing pruning methods in both zero-shot accuracy and WikiText-2 perplexity across various model types and parameter sizes. On LLaMA-2/3 and Qwen-1.5 models, FlattenGPT retains 90-96% of zero-shot performance at a compression ratio of 20%. It also outperforms other pruning methods in accelerating LLM inference, making it a promising approach for improving the efficiency of transformers.
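The abstract does not state which criterion drives redundancy removal in a flattened block. The sketch below illustrates channel pruning on a widened MLP only under stated assumptions. The score is a swapped-in heuristic and not FlattenGPT's criterion: mean absolute activation on calibration inputs, multiplied by the norm of the channel's down-projection row. The function name `prune_mlp_channels`, the `keep_ratio` parameter, and the random `calib_x` data are all hypothetical.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU, applied elementwise
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def prune_mlp_channels(w_up, w_down, calib_x, keep_ratio):
    """Keep the highest-scoring hidden channels of an MLP (hypothetical criterion)."""
    acts = gelu(calib_x @ w_up)  # (n_samples, hidden)
    # Assumed channel score: mean |activation| on calibration data times the L2 norm
    # of the channel's row in the down-projection (its contribution to the output)
    scores = np.abs(acts).mean(axis=0) * np.linalg.norm(w_down, axis=1)
    k = max(1, int(round(keep_ratio * w_up.shape[1])))
    keep = np.sort(np.argsort(scores)[-k:])  # indices of the top-k channels, in order
    return w_up[:, keep], w_down[keep, :], keep

rng = np.random.default_rng(0)
d, hidden = 16, 128  # e.g. a flattened MLP built from two 64-channel MLPs
w_up = rng.standard_normal((d, hidden)) / np.sqrt(d)
w_down = rng.standard_normal((hidden, d)) / np.sqrt(hidden)
calib_x = rng.standard_normal((256, d))  # stand-in for calibration activations

w_up_p, w_down_p, keep = prune_mlp_channels(w_up, w_down, calib_x, keep_ratio=0.5)

params_before = w_up.size + w_down.size     # 2 * 16 * 128 = 4096
params_after = w_up_p.size + w_down_p.size  # 2 * 16 * 64 = 2048
removed = 1 - params_after / params_before  # 0.5 of this MLP's parameters

# Relative output error introduced by dropping the low-scoring channels
full = gelu(calib_x @ w_up) @ w_down
pruned = gelu(calib_x @ w_up_p) @ w_down_p
rel_err = np.linalg.norm(full - pruned) / np.linalg.norm(full)
```

With `keep_ratio=0.5`, the flattened pair shrinks back to the width of a single original block. That halves this MLP's parameters while the network is one block shallower. The 50% figure is only an arithmetic illustration and is unrelated to the 20% compression ratio reported above.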
