
FlattenGPT: Depth Compression for Transformer with Layer Flattening

Ruihan Xu, Qingpei Guo, Yao Zhu, Xiangyang Ji, Ming Yang, Shiliang Zhang

TL;DR

FlattenGPT addresses redundancy in deep transformer stacks with a two-stage depth-compression scheme: iterative layer flattening merges adjacent blocks while preserving cross-layer information, and channel pruning on the flattened blocks then removes the remaining redundancy. The method keeps the architecture consistent with the original transformer while significantly accelerating inference. Empirically, it retains 90-96% of zero-shot performance at a compression ratio of 20% on models such as LLaMA-2/3 and Qwen-1.5, outperforming prior depth-compression and pruning methods in WikiText-2 perplexity and on zero-shot benchmarks. By bridging depth compression and channel pruning, FlattenGPT offers a practical, hardware-friendly way to deploy efficient LLMs without substantial performance loss, and recovery fine-tuning further improves robustness. The work provides a foundation for fine-grained depth reduction in transformers that preserves information across layers and eases deployment on resource-constrained hardware.
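To make "merging two adjacent blocks" concrete, the following NumPy sketch shows one way flattening can work for a pair of pre-norm MLP blocks. It rests on our own assumptions, not on the authors' code: because adjacent hidden states are highly similar, the second block is fed the first block's input instead of its output, so the two residual updates can be computed in parallel and fused into one wider block (up-projection rows and down-projection columns concatenated, with each block's RMSNorm gain folded into its up-projection). ReLU stands in for SwiGLU, attention is omitted (its heads could presumably be concatenated the same way), and `rms_norm`, `mlp_block`, `flatten_pair`, and all sizes are hypothetical.

```python
import numpy as np

def rms_norm(h, eps=1e-6):
    # Normalize each token (row) by its root-mean-square; learnable gains are applied by the caller.
    return h / np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)

def mlp_block(h, gain, w_up, w_down):
    """Residual update of one pre-norm MLP block: W_down @ ReLU(W_up @ (gain * norm(h)))."""
    x = rms_norm(h) * gain              # (tokens, d)
    a = np.maximum(0.0, x @ w_up.T)     # (tokens, m); ReLU stands in for SwiGLU
    return a @ w_down.T                 # (tokens, d)

def flatten_pair(block1, block2):
    """Hypothetical merge of two MLP blocks into one wider block computing f1(h) + f2(h)."""
    (g1, up1, down1), (g2, up2, down2) = block1, block2
    # Fold each RMSNorm gain into its up-projection so both branches share one plain norm.
    up = np.concatenate([up1 * g1, up2 * g2], axis=0)    # (m1 + m2, d)
    down = np.concatenate([down1, down2], axis=1)         # (d, m1 + m2)
    return np.ones_like(g1), up, down

rng = np.random.default_rng(0)
d, m, tokens = 16, 32, 8

def random_block(scale=0.02):
    # Small output scale mimics block updates that are small relative to the residual stream.
    gain = 1.0 + 0.1 * rng.standard_normal(d)
    w_up = rng.standard_normal((m, d)) / np.sqrt(d)
    w_down = scale * rng.standard_normal((d, m)) / np.sqrt(m)
    return gain, w_up, w_down

b1, b2 = random_block(), random_block()
h = rng.standard_normal((tokens, d))

# Original depth-2 stack: block 2 reads the output of block 1.
h_mid = h + mlp_block(h, *b1)
h_seq = h_mid + mlp_block(h_mid, *b2)

# Flattened depth-1 block: both branches read the same input h.
merged = flatten_pair(b1, b2)
h_flat = h + mlp_block(h, *merged)

print("merged width:", merged[1].shape[0])
print("max relative deviation from sequential stack:",
      np.abs(h_flat - h_seq).max() / np.abs(h_seq).max())
```

In this sketch the fused block reproduces the parallel sum f1(h) + f2(h) exactly; the only approximation is giving block 2 the same input as block 1, and the deviation from the sequential stack stays small when block 1's update is small relative to the residual stream. The fused MLP is twice as wide, so its extra channels are natural targets for a subsequent channel-pruning stage.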

Abstract

Recent works have indicated redundancy across transformer blocks, prompting research on depth compression that prunes less crucial blocks. However, current approaches that prune entire blocks risk discarding meaningful cues learned in those blocks, leading to substantial performance degradation. Channel pruning, another line of model compression, better preserves performance, but it cannot reduce model depth and is challenged by inconsistent pruning ratios across individual layers. To pursue better model compression and acceleration, this paper proposes FlattenGPT, a novel way to detect and reduce depth-wise redundancies. By flattening two adjacent blocks into one, it compresses the network depth while enabling more effective detection and removal of redundant parameters. FlattenGPT preserves the knowledge learned in all blocks and remains consistent with the original transformer architecture. Extensive experiments demonstrate that FlattenGPT improves model efficiency with a reasonable trade-off in performance. It outperforms existing pruning methods in both zero-shot accuracy and WikiText-2 perplexity across various model types and parameter sizes. On LLaMA-2/3 and Qwen-1.5 models, FlattenGPT retains 90-96% of zero-shot performance at a compression ratio of 20%. It also outperforms other pruning methods in accelerating LLM inference, making it a promising approach for improving the efficiency of transformers.
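The channel-pruning stage can be illustrated on a single widened MLP. The sketch below is a minimal illustration under assumptions of our own: it scores each intermediate channel by its mean activation on calibration tokens times the L2 norm of its down-projection column (a generic activation-aware criterion, not necessarily the one FlattenGPT uses), and it prunes to a fixed width so every pruned block keeps the same shape. The random weights, the calibration batch, and the helper names `mlp_update`, `channel_scores`, and `prune_channels` are hypothetical.

```python
import numpy as np

def rms_norm(h, eps=1e-6):
    return h / np.sqrt(np.mean(h * h, axis=-1, keepdims=True) + eps)

def mlp_update(h, w_up, w_down):
    # Residual update of a pre-norm MLP; the norm gain is assumed folded into w_up
    # and ReLU stands in for SwiGLU.
    return np.maximum(0.0, rms_norm(h) @ w_up.T) @ w_down.T

def channel_scores(calib, w_up, w_down):
    """Importance per intermediate channel: mean activation x norm of its output column."""
    act = np.maximum(0.0, rms_norm(calib) @ w_up.T)           # (tokens, m), non-negative
    return act.mean(axis=0) * np.linalg.norm(w_down, axis=0)   # (m,)

def prune_channels(calib, w_up, w_down, keep):
    """Keep the `keep` highest-scoring channels (keep >= 1), preserving their original order."""
    order = np.argsort(channel_scores(calib, w_up, w_down))
    idx = np.sort(order[-keep:])
    return w_up[idx], w_down[:, idx]

rng = np.random.default_rng(0)
d, m = 16, 32
# A flattened block of width 2m, e.g. two m-wide MLPs fused side by side.
w_up = rng.standard_normal((2 * m, d)) / np.sqrt(d)
w_down = 0.02 * rng.standard_normal((d, 2 * m)) / np.sqrt(2 * m)
calib = rng.standard_normal((256, d))

# Prune back to width m, i.e. the width of one original block.
small_up, small_down = prune_channels(calib, w_up, w_down, keep=m)
test_h = rng.standard_normal((8, d))
err = np.abs(mlp_update(test_h, small_up, small_down) - mlp_update(test_h, w_up, w_down)).max()
print("pruned shapes:", small_up.shape, small_down.shape, "max output change:", err)
```

Pruning a 2m-wide flattened block back to width m yields a block with the same shape as one original block, which is one way to read the claim that the compressed model stays consistent with the original transformer architecture.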

Paper Structure (21 sections, 3 theorems, 14 equations, 3 figures, 3 tables)

Key Result

Lemma 2.1

Let $\bm{H}^{\ell}$ denote the hidden state at layer $\ell$ and $\sigma_{\bm{H}^{\ell}}^{2}$ its variance. The variance $\sigma_{\bm{H}^{\ell}}^{2}$ can increase quadratically with depth $\ell$.
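For intuition about Lemma 2.1, consider a simplified scalar model of our own, which is not the lemma's statement or proof: $H^{\ell} = H^{0} + \sum_{i=1}^{\ell} f_i$, where each block output $f_i$ has variance $s^2$, every pair of outputs has correlation $\rho > 0$, and the outputs are independent of $H^{0}$. Then $\mathrm{Var}(H^{\ell}) = \sigma_0^2 + \ell s^2 + \ell(\ell-1)\rho s^2$, which is quadratic in $\ell$, while each new block contributes only $s^2$ to it, so deep hidden states become dominated by the accumulated residual path (compare Figure 2 (b)). The NumPy sketch below checks this closed form against a Monte Carlo simulation; `residual_variance` and all numeric values are hypothetical.

```python
import numpy as np

def residual_variance(depth, var0, var_f, rho):
    """Var(H^l) for H^l = H^0 + f_1 + ... + f_l, where each f_i has variance var_f,
    every pair (f_i, f_j) has correlation rho, and all f_i are independent of H^0."""
    return var0 + depth * var_f + depth * (depth - 1) * rho * var_f

rng = np.random.default_rng(0)
n, depth, var0, var_f, rho = 50_000, 32, 1.0, 0.1, 0.3

h = rng.normal(0.0, np.sqrt(var0), n)   # n independent samples of H^0
shared = rng.standard_normal(n)          # component common to every block output
simulated = []
for _ in range(depth):
    # Mixing a shared and a private Gaussian gives Var(f) = var_f and Corr(f_i, f_j) = rho.
    f = np.sqrt(var_f) * (np.sqrt(rho) * shared + np.sqrt(1.0 - rho) * rng.standard_normal(n))
    h = h + f
    simulated.append(h.var())

closed = np.array([residual_variance(l, var0, var_f, rho) for l in range(1, depth + 1)])
share_of_newest_block = var_f / closed   # shrinks roughly like 1 / l^2 for rho > 0

for l in (1, 8, 16, 32):
    print(l, round(simulated[l - 1], 2), round(closed[l - 1], 2), round(share_of_newest_block[l - 1], 4))
```

With $\rho = 0$ the same formula grows only linearly, so in this toy model the quadratic term comes entirely from positive correlation between block outputs.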

Figures (3)

  • Figure 1: Comparison of pruning methods. (a) The original architecture. (b) Layer pruning discards all knowledge in the removed blocks. (c) Channel pruning cannot compress model depth and leads to an inconsistent architecture across layers. (d) Our method bridges the gap, producing a compact model with marginal performance degradation. (e) Comparison of efficiency, architecture consistency, and performance.
  • Figure 2: Illustration of redundancy in transformer blocks. (a) LLaMA-2 7B exhibits high cross-layer similarity. (b) The scale of the residual path grows faster than the outputs of the MHA/MLP blocks and dominates the hidden states of deep layers. (c) The unraveled view of the transformer architecture, where the residual path traversing the entire network leads to considerable cross-layer similarity. (d) Acceleration comparison between different pruning methods.
  • Figure 3: Framework of FlattenGPT, which consists of two stages. (a) Original stacks of transformer blocks with high similarity. (b) Layer flattening merges two adjacent blocks into a single block with marginal performance degradation. (c) Flattening bridges the gap between depth compression and channel compression.
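The cross-layer similarity in Figure 2 (a) and Figure 3 (a) can be measured as the cosine similarity between per-layer hidden states, averaged over tokens. The sketch below is a minimal version under our own assumptions: in practice the hidden states would be collected on calibration text (for example with `output_hidden_states=True` in Hugging Face Transformers), whereas here a synthetic residual stream stands in for a real model. Picking the most similar adjacent pair as the next candidate to flatten is a hypothetical selection rule, not necessarily the paper's; `cross_layer_similarity` and `most_similar_adjacent_pair` are illustrative names.

```python
import numpy as np

def cross_layer_similarity(hidden):
    """hidden: (layers, tokens, d) hidden states. Returns a (layers, layers) matrix whose
    entry [i, j] is the cosine similarity between layers i and j, averaged over tokens."""
    unit = hidden / np.linalg.norm(hidden, axis=-1, keepdims=True)
    return np.einsum("itd,jtd->ij", unit, unit) / hidden.shape[1]

def most_similar_adjacent_pair(sim):
    """Hypothetical selection rule: return l such that layers (l, l + 1) are the most similar."""
    return int(np.argmax(np.diag(sim, k=1)))

# Synthetic residual stream: each layer adds an independent update of fixed scale.
rng = np.random.default_rng(0)
layers, tokens, d = 12, 256, 32
h = rng.standard_normal((tokens, d))
states = [h]
for _ in range(layers - 1):
    h = h + 0.3 * rng.standard_normal((tokens, d))
    states.append(h)

sim = cross_layer_similarity(np.stack(states))
print(np.round(np.diag(sim, k=1), 3))   # adjacent-layer similarity tends to grow with depth
print("candidate pair to flatten:", most_similar_adjacent_pair(sim))
```

Because each fixed-size update is a shrinking fraction of the growing residual stream, deeper adjacent layers look increasingly alike in this toy setting, mirroring the redundancy that motivates flattening.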

Theorems & Definitions (3)

  • Lemma 2.1: The growth of the hidden state variance
  • Lemma 2.2: The norm of the gradient
  • Lemma 3.1