
STAT: Shrinking Transformers After Training

Megan Flynn, Alexander Wang, Dean Edward Alvarez, Christopher De Sa, Anil Damle

TL;DR

STAT addresses the problem of compressing transformer models without fine-tuning. It prunes attention heads and neurons and compensates with a dense, data-driven correction to the next layer. The method centers on interpolative decompositions (IDs), which select components via pivoted QR on intermediate activations. Each ID produces an interpolation matrix $T$ such that $A \approx A_{:, \mathcal{I}} T$, with the provable error bound $\|A - A_{:, \mathcal{I}} T\|_2 \le \epsilon \|A\|_2$. Scalability to large models comes from randomized projections and sketching (e.g., CountSketch), which make IDs cheap to compute and allow models such as Llama-2 7B to be pruned on a single GPU in under three hours. Experiments on BERT, DistilBERT, and Llama-2 show strong FLOPs/accuracy tradeoffs without fine-tuning, often outperforming retraining-free baselines and approaching heavily fine-tuned methods. Because IDs keep a subset of the original heads and neurons, the pruned model remains structured. This work offers a hardware-friendly, one-shot compression approach for encoder and decoder transformers, with explicit error control and scalable implementation details for modern large-scale models.
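The pivoted-QR route to an ID can be sketched in a few lines of NumPy/SciPy. This is a minimal illustration under our own assumptions, not the authors' code: the helper name `interpolative_decomposition`, the synthetic near-low-rank matrix standing in for a matrix of activations, and the rank `k = 8` are all hypothetical choices made for the example.

```python
import numpy as np
from scipy.linalg import qr, solve_triangular


def interpolative_decomposition(A, k):
    """Hypothetical helper: rank-k column ID, A ≈ A[:, I] @ T, via column-pivoted QR."""
    # A[:, perm] = Q @ R, with columns ordered by greedy (largest residual norm) pivoting.
    _, R, perm = qr(A, mode="economic", pivoting=True)
    R11, R12 = R[:k, :k], R[:k, k:]
    # Express the non-selected columns in terms of the selected ones: R11 @ X = R12.
    X = solve_triangular(R11, R12, lower=False)
    T = np.zeros((k, A.shape[1]))
    T[:, perm[:k]] = np.eye(k)  # selected columns reproduce themselves exactly
    T[:, perm[k:]] = X          # remaining columns are interpolated
    return perm[:k], T


# Synthetic stand-in for an activation matrix: 256 samples x 64 features,
# rank 8 plus a small amount of noise (an assumption for illustration).
rng = np.random.default_rng(0)
A = rng.standard_normal((256, 8)) @ rng.standard_normal((8, 64))
A += 1e-6 * rng.standard_normal(A.shape)

I, T = interpolative_decomposition(A, k=8)
# Smallest epsilon for which ||A - A[:, I] T||_2 <= epsilon ||A||_2 holds here.
rel_err = np.linalg.norm(A - A[:, I] @ T, 2) / np.linalg.norm(A, 2)
print(f"kept columns: {sorted(I.tolist())}")
print(f"relative spectral error: {rel_err:.2e}")
```

Because the selected columns of `T` form an identity block, the kept components pass through unchanged and only the discarded ones are approximated.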

Abstract

We present STAT: a simple algorithm to prune transformer models without any fine-tuning. STAT eliminates both attention heads and neurons from the network, while preserving accuracy by calculating a correction to the weights of the next layer. Each layer block in the network is compressed using a series of principled matrix factorizations that preserve the network structure. Our entire algorithm takes minutes to compress BERT, and less than three hours to compress models with 7B parameters using a single GPU. Using only several hundred data examples, STAT preserves the output of the network and improves upon existing gradient-free pruning methods. It is even competitive with methods that include significant fine-tuning. We demonstrate our method on both encoder and decoder architectures, including BERT, DistilBERT, and Llama-2, using benchmarks such as GLUE, SQuAD, and WikiText2.
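To make the "correction to the weights of the next layer" concrete, here is a minimal sketch for a single feed-forward block. It assumes a ReLU activation (BERT and Llama-2 actually use GELU and SwiGLU), random weights, and random calibration inputs standing in for the several hundred data examples; the helper names `interpolative_decomposition` and `prune_ffn` are hypothetical. The ID chooses which hidden neurons to keep from their activations, and the interpolation matrix `T` is folded into the following weight matrix.

```python
import numpy as np
from scipy.linalg import qr, solve_triangular


def interpolative_decomposition(A, k):
    """Hypothetical helper: rank-k column ID (A ≈ A[:, I] @ T) via column-pivoted QR."""
    _, R, perm = qr(A, mode="economic", pivoting=True)
    X = solve_triangular(R[:k, :k], R[:k, k:], lower=False)
    T = np.zeros((k, A.shape[1]))
    T[:, perm[:k]] = np.eye(k)
    T[:, perm[k:]] = X
    return perm[:k], T


def prune_ffn(W1, b1, W2, X_calib, k):
    """Hypothetical sketch: keep k hidden neurons and fold the ID correction into W2."""
    H = np.maximum(X_calib @ W1 + b1, 0.0)  # hidden activations on calibration data
    I, T = interpolative_decomposition(H, k)
    # H ≈ H[:, I] @ T, so H @ W2 ≈ H[:, I] @ (T @ W2): only the next layer changes.
    return W1[:, I], b1[I], T @ W2, I


rng = np.random.default_rng(0)
d_model, d_ff, n_calib, k = 32, 128, 512, 64
W1 = rng.standard_normal((d_model, d_ff)) / np.sqrt(d_model)
b1 = np.zeros(d_ff)
W2 = rng.standard_normal((d_ff, d_model)) / np.sqrt(d_ff)
X = rng.standard_normal((n_calib, d_model))

W1_p, b1_p, W2_p, I = prune_ffn(W1, b1, W2, X, k)

# Compare outputs on the calibration data (the output bias cancels, so it is omitted).
Y = np.maximum(X @ W1 + b1, 0.0) @ W2
H_kept = np.maximum(X @ W1_p + b1_p, 0.0)
Y_corrected = H_kept @ W2_p
Y_naive = H_kept @ W2[I]  # same neurons kept, but no correction to the next layer
err_corrected = np.linalg.norm(Y - Y_corrected) / np.linalg.norm(Y)
err_naive = np.linalg.norm(Y - Y_naive) / np.linalg.norm(Y)
print(f"relative error with correction:    {err_corrected:.3f}")
print(f"relative error without correction: {err_naive:.3f}")
```

Since `T` from a pivoted-QR ID is the least-squares fit of the discarded activations onto the kept ones, the corrected error on the calibration data (in Frobenius norm) can never exceed the error of simply dropping the same neurons in this setup. The sketch covers only feed-forward neurons, not attention heads.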

Paper Structure

This paper contains 33 sections, 6 equations, 8 figures, and 1 table.

Figures (8)

  • Figure 1: Pruning results for BERT and DistilBERT models. We begin with the exact same base models as rtfp. Error bars show standard deviations across 10 trials with randomly selected pruning sets. For modest pruning levels, our method performs within the error bounds of the baseline method, and at higher compression ratios it substantially outperforms rtfp.
  • Figure 2: Pruning results for the BERT model, compared against methods that include substantial fine-tuning (at least 4 epochs and 5 hours; stars), namely Wang_2020, lin2020pruning, Sajjad_2023, hou2020dynabert, liu-etal-2021-ebert, and lagunas2021block, and against rtfp, which does not include fine-tuning.
  • Figure 3: Left: Accuracy of the pruned model on the QQP dataset for different amounts of pruning data. Accuracy improves up to 512 data examples and then remains fairly steady. Right: Ablation of the two-step head compression process.
  • Figure 4: First 64 normalized singular values of the $K_iQ_i$ matrix for 4 heads randomly selected from the first 4 layers of a network.
  • Figure 5: We prune 3 layers in the network and compare the error induced on the output of each layer on the pruning set. Shifting the norms from the next layer reduces the error slightly.

Theorems & Definitions (1)

  • Definition 2.1: Interpolative Decomposition
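Definition 2.1 is the interpolative decomposition used throughout the method. The TL;DR notes that, for large models, IDs are computed from randomized sketches such as CountSketch. The following is a minimal illustration of that idea under our own assumptions, not the paper's implementation: the helper names are hypothetical, the sketch compresses only the data (row) dimension so that the selected columns and `T` can be reused on the full matrix, and the test matrix is exactly low rank so the recovered ID is exact up to round-off.

```python
import numpy as np
from scipy.linalg import qr, solve_triangular


def interpolative_decomposition(A, k):
    """Hypothetical helper: rank-k column ID (A ≈ A[:, I] @ T) via column-pivoted QR."""
    _, R, perm = qr(A, mode="economic", pivoting=True)
    X = solve_triangular(R[:k, :k], R[:k, k:], lower=False)
    T = np.zeros((k, A.shape[1]))
    T[:, perm[:k]] = np.eye(k)
    T[:, perm[k:]] = X
    return perm[:k], T


def countsketch_rows(A, s, rng):
    """Hypothetical helper: compute S @ A for an s x m CountSketch matrix S.

    Each row of A is added, with a random sign, to one of s random buckets,
    so the cost is proportional to the number of nonzeros of A and S is
    never formed explicitly.
    """
    m = A.shape[0]
    buckets = rng.integers(0, s, size=m)
    signs = rng.choice([-1.0, 1.0], size=m)
    SA = np.zeros((s, A.shape[1]))
    np.add.at(SA, buckets, signs[:, None] * A)  # unbuffered scatter-add into buckets
    return SA


rng = np.random.default_rng(0)
# Tall matrix (many data rows, few columns) of exact rank 10.
A = rng.standard_normal((5000, 10)) @ rng.standard_normal((10, 200))

SA = countsketch_rows(A, s=40, rng=rng)  # 40 x 200 instead of 5000 x 200
I, T = interpolative_decomposition(SA, k=10)
# The column indices and T found on the sketch are reused on the full matrix.
rel_err = np.linalg.norm(A - A[:, I] @ T, 2) / np.linalg.norm(A, 2)
print(f"sketch shape: {SA.shape}, relative error on full A: {rel_err:.2e}")
```

Real activation matrices are only approximately low rank, so the sketch size `s` trades accuracy against cost; the value 40 here is arbitrary.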