STAT: Shrinking Transformers After Training
Megan Flynn, Alexander Wang, Dean Edward Alvarez, Christopher De Sa, Anil Damle
TL;DR
STAT compresses transformer models without fine-tuning by pruning attention heads and neurons and compensating with a dense, data-driven correction to the next layer. The method uses interpolative decompositions (IDs), computed via pivoted QR on intermediate activations, to select which components to keep. Each ID yields an index set $\mathcal{I}$ and an interpolation matrix $T$ such that $A \approx A_{:, \mathcal{I}} T$, with error bound $\|A - A_{:, \mathcal{I}} T\|_2 \le \epsilon \|A\|_2$. Randomized projections and sketching (e.g., CountSketch) make the IDs cheap enough to compute at scale, so models such as Llama-2 7B can be pruned on a single GPU in under three hours. Experiments on BERT, DistilBERT, and Llama-2 show strong FLOPs/accuracy tradeoffs without fine-tuning, often outperforming retraining-free baselines and approaching heavily fine-tuned methods, while the pruning remains structured. The result is a hardware-friendly, one-shot compression approach for encoder and decoder transformers, with explicit error control and an implementation that scales to modern large models.
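To make the decomposition $A \approx A_{:, \mathcal{I}} T$ concrete, the following is a minimal sketch of an interpolative decomposition built from column-pivoted QR with SciPy. The function name `interpolative_decomposition`, the fixed target rank `k`, and the synthetic matrix are illustrative assumptions, not the authors' implementation, and the randomized sketching step (e.g., CountSketch) is omitted.

```python
import numpy as np
from scipy.linalg import qr, solve_triangular


def interpolative_decomposition(A, k):
    """Return (idx, T) such that A ~= A[:, idx] @ T (hypothetical helper)."""
    # Column-pivoted QR: A[:, piv] = Q @ R, with the most "important" columns first.
    _, R, piv = qr(A, mode="economic", pivoting=True)
    idx = piv[:k]  # indices of the k selected columns
    # Express the remaining columns through the selected ones: R11 @ X = R12.
    X = solve_triangular(R[:k, :k], R[:k, k:])
    T = np.zeros((k, A.shape[1]))
    T[:, piv[:k]] = np.eye(k)  # selected columns are reproduced exactly
    T[:, piv[k:]] = X          # other columns are interpolated
    return idx, T


# Synthetic example (assumption): a matrix of exact rank 8.
rng = np.random.default_rng(0)
A = rng.standard_normal((200, 8)) @ rng.standard_normal((8, 32))
idx, T = interpolative_decomposition(A, k=8)
rel_err = np.linalg.norm(A - A[:, idx] @ T, 2) / np.linalg.norm(A, 2)
```

Here `rel_err` plays the role of $\epsilon$ in the bound above. In a pruning setting, the columns of `A` would correspond to neurons or heads and the rows to calibration examples.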
Abstract
We present STAT: a simple algorithm to prune transformer models without any fine-tuning. STAT eliminates both attention heads and neurons from the network, while preserving accuracy by calculating a correction to the weights of the next layer. Each layer block in the network is compressed using a series of principled matrix factorizations that preserve the network structure. Our entire algorithm takes minutes to compress BERT, and less than three hours to compress models with 7B parameters using a single GPU. Using only several hundred data examples, STAT preserves the output of the network and improves upon existing gradient-free pruning methods. It is even competitive with methods that include significant fine-tuning. We demonstrate our method on both encoder and decoder architectures, including BERT, DistilBERT, and Llama-2, on benchmarks such as GLUE, SQuAD, and WikiText2.
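The idea of removing neurons and folding a correction into the next layer's weights can be illustrated on a toy two-layer ReLU network. This is a sketch under stated assumptions rather than the paper's procedure for transformer blocks: the helper `prune_hidden_neurons`, the network shapes, and the random calibration data are all hypothetical, and the interpolation matrix is obtained here by a least-squares fit on the selected columns.

```python
import numpy as np
from scipy.linalg import qr


def prune_hidden_neurons(W1, b1, W2, X_calib, k):
    """Keep k hidden neurons of y = relu(X @ W1 + b1) @ W2 and correct W2."""
    # Hidden activations on the calibration data, shape (examples, neurons).
    Z = np.maximum(X_calib @ W1 + b1, 0.0)
    # Pivoted QR ranks neurons; the first k pivots are the ones to keep.
    _, _, piv = qr(Z, mode="economic", pivoting=True)
    idx = np.sort(piv[:k])
    # Interpolation matrix T from a least-squares fit: Z ~= Z[:, idx] @ T.
    T, *_ = np.linalg.lstsq(Z[:, idx], Z, rcond=None)
    # Drop neurons in the first layer and absorb T into the next layer.
    return W1[:, idx], b1[idx], T @ W2, idx


# Toy network and calibration set (assumptions for illustration only).
rng = np.random.default_rng(0)
X = rng.standard_normal((256, 16))
W1, b1 = rng.standard_normal((16, 64)), rng.standard_normal(64)
W2 = rng.standard_normal((64, 8))

W1p, b1p, W2p, idx = prune_hidden_neurons(W1, b1, W2, X, k=32)
Y = np.maximum(X @ W1 + b1, 0.0) @ W2        # original output
Zp = np.maximum(X @ W1p + b1p, 0.0)          # pruned hidden layer
err_corrected = np.linalg.norm(Y - Zp @ W2p)     # pruned, corrected next layer
err_naive = np.linalg.norm(Y - Zp @ W2[idx])     # pruned, rows of W2 simply dropped
```

Because the ReLU acts elementwise, slicing `W1` and `b1` reproduces the kept activations exactly, so all approximation error comes from the interpolation. On the calibration data the corrected weights `W2p` cannot do worse in Frobenius norm than simply dropping rows of `W2`, since the least-squares fit projects out the part of the removed activations that the kept neurons can explain.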
