EcoSpa: Efficient Transformer Training with Coupled Sparsity
Jinqi Xiao, Cheng Luo, Lingyi Huang, Cheng Yang, Yang Sui, Huy Phan, Xiao Zang, Yibiao Ying, Zhexiang Tang, Anima Anandkumar, Bo Yuan
TL;DR
Transformer training is computationally intensive. EcoSpa addresses this by jointly estimating the importance of coupled weight matrices and sparsifying them in aligned pairs. It uses empirical Fisher information to assess the joint importance of each matrix pair, then applies a tSVD-inspired aligned removal strategy that preserves interaction patterns in both MHA and FFN layers, with extensions to GQA and RoPE. Across LLaMA, GPT-2, and DeiT, the method reduces memory use, speeds up training, and improves perplexity and throughput, using only standard PyTorch operations on commodity hardware. This makes efficient transformer training more accessible in both pre-training and fine-tuning while preserving model performance.
Abstract
Transformers have become the backbone of modern AI, yet their high computational demands pose critical system challenges. While sparse training offers efficiency gains, existing methods fail to preserve critical structural relationships between weight matrices that interact multiplicatively in attention and feed-forward layers. This oversight leads to performance degradation at high sparsity levels. We introduce EcoSpa, an efficient structured sparse training method that jointly evaluates and sparsifies coupled weight matrix pairs, preserving their interaction patterns through aligned row/column removal. EcoSpa introduces a new granularity for calibrating structural component importance and performs coupled estimation and sparsification across both pre-training and fine-tuning scenarios. Evaluations demonstrate substantial improvements: EcoSpa enables efficient training of LLaMA-1B with 50\% memory reduction and 21\% faster training, achieves $2.2\times$ model compression on GPT-2-Medium with $2.4$ lower perplexity, and delivers $1.6\times$ inference speedup. The approach uses standard PyTorch operations, requiring no custom hardware or kernels, making efficient transformer training accessible on commodity hardware.
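To make the idea of aligned row/column removal concrete, the following is a minimal sketch, not the authors' implementation. It assumes a coupled pair $A \in \mathbb{R}^{m \times n}$ and $B \in \mathbb{R}^{n \times p}$ that interact through the product $AB$ (for example, the two projections of an FFN), so that shared index $j$ ties column $j$ of $A$ to row $j$ of $B$. The scoring rule (squared gradient-times-weight, summed over both halves of the pair) is an illustrative stand-in for an empirical-Fisher-style joint importance, and the function names `coupled_scores`, `prune_coupled`, and `matmul` are hypothetical. Plain Python lists are used for readability; the paper itself relies on standard PyTorch operations.

```python
def coupled_scores(A, B, grad_A, grad_B):
    """Joint score per shared index j: column j of A plus row j of B.

    Assumed scoring rule (illustrative only): sum of (gradient * weight)^2.
    """
    scores = []
    for j in range(len(B)):  # len(B) is the shared inner dimension n
        s_a = sum((grad_A[i][j] * A[i][j]) ** 2 for i in range(len(A)))
        s_b = sum((g * w) ** 2 for g, w in zip(grad_B[j], B[j]))
        scores.append(s_a + s_b)
    return scores


def prune_coupled(A, B, scores, keep):
    """Keep the `keep` highest-scoring shared indices in BOTH matrices."""
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    kept = sorted(ranked[:keep])  # restore original ordering of survivors
    A_pruned = [[row[j] for j in kept] for row in A]  # drop columns of A
    B_pruned = [B[j] for j in kept]                   # drop matching rows of B
    return A_pruned, B_pruned, kept


def matmul(A, B):
    """Dense matrix product for small list-of-lists matrices."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


# Toy coupled pair: A is 2x3, B is 3x2; gradients are made up for illustration.
A = [[1, 2, 3], [4, 5, 6]]
B = [[1, 0], [0, 1], [2, 2]]
grad_A = [[1, 0, 1], [1, 0, 1]]
grad_B = [[1, 1], [0, 0], [1, 1]]

scores = coupled_scores(A, B, grad_A, grad_B)
A_p, B_p, kept = prune_coupled(A, B, scores, keep=2)
product = matmul(A_p, B_p)
```

Because the same index is removed from both matrices, the pruned product `matmul(A_p, B_p)` equals the full product $AB$ minus the rank-one term contributed by the dropped index, and the two smaller matrices remain shape-compatible dense operands. Pruning each matrix independently could instead remove mismatched indices and break this correspondence.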
