Beyond Linear Approximations: A Novel Pruning Approach for Attention Matrix
Yingyu Liang, Jiangxuan Long, Zhenmei Shi, Zhao Song, Yufa Zhou
TL;DR
This work tackles pruning of large language models by directly optimizing a pruning mask on the attention weights so that the pruned model approximates the non-linear Softmax attention, addressing the shortcomings of linear pruning methods that only approximate the $XW$ mapping. It develops a Gradient Descent-based pruning method with a closed-form gradient, proves Lipschitz continuity, establishes convergence via a Polyak–Łojasiewicz (PL) inequality, and provides a finite-time guarantee of reaching a near-optimal pruning mask. Empirically, the approach preserves model performance while achieving substantial sparsity and reduced computation cost, outperforming state-of-the-art baselines such as SparseGPT and Wanda on synthetic data and on real models (e.g., Llama 3.2-1B with C4 data), including improved end-to-end perplexity. The results establish a theoretical and practical foundation for attention-focused pruning, with the potential to enable efficient LLM inference on edge devices.
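For readers unfamiliar with the PL condition named above, its standard textbook form is sketched below. The symbols $f$, $f^*$, $\mu$, $L$, and $x_k$ are generic placeholders, not the paper's notation, and the paper's own constants and assumptions may differ. A differentiable function $f$ with minimum value $f^*$ satisfies the PL inequality with constant $\mu > 0$ if

$$\tfrac{1}{2}\,\|\nabla f(x)\|_2^2 \;\ge\; \mu\,\bigl(f(x) - f^*\bigr) \quad \text{for all } x.$$

If, in addition, $\nabla f$ is $L$-Lipschitz, Gradient Descent with step size $1/L$ converges linearly without requiring convexity:

$$f(x_k) - f^* \;\le\; \Bigl(1 - \frac{\mu}{L}\Bigr)^{k}\,\bigl(f(x_0) - f^*\bigr).$$

This is why a Lipschitz bound and a PL inequality together suffice for a finite-time guarantee: reaching error $\epsilon$ takes on the order of $(L/\mu)\log(1/\epsilon)$ iterations.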
Abstract
Large Language Models (LLMs) have shown immense potential in enhancing various aspects of our daily lives, from conversational AI to search and AI assistants. However, their growing capabilities come at the cost of extremely large model sizes, making deployment on edge devices challenging due to memory and computational constraints. This paper introduces a novel approach to LLM weight pruning that directly optimizes for approximating the attention matrix, a core component of transformer architectures. Unlike existing methods that focus on linear approximations, our approach accounts for the non-linear nature of the Softmax attention mechanism. We provide theoretical guarantees for the convergence of our Gradient Descent-based optimization method to a near-optimal pruning mask solution. Our empirical results demonstrate the effectiveness of our non-linear pruning approach in maintaining model performance while significantly reducing computational costs, outperforming the current state-of-the-art methods, SparseGPT and Wanda, by a large margin. This work establishes a new theoretical foundation for pruning algorithm design in LLMs, potentially paving the way for more efficient LLM inference on resource-constrained devices.
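The idea of optimizing a pruning mask against the Softmax attention matrix, rather than against a linear map, can be illustrated with a minimal NumPy sketch. Everything specific here is an assumption made for illustration and is not taken from the paper: a single weight matrix `W` stands in for the query-key weights, the objective is a squared Frobenius error plus an L2 penalty on a relaxed mask `M` in $[0, 1]$, and the names `objective`, `objective_grad`, `prune_mask`, and the parameter `lam` are hypothetical. The step-halving rule is a simple safeguard, not the paper's step-size schedule.

```python
import numpy as np

def softmax_rows(Z):
    """Row-wise softmax with max-subtraction for numerical stability."""
    Z = Z - Z.max(axis=1, keepdims=True)
    E = np.exp(Z)
    return E / E.sum(axis=1, keepdims=True)

def attention(X, W):
    """Attention matrix Softmax(X W X^T); W is an assumed stand-in for the query-key weights."""
    return softmax_rows(X @ W @ X.T)

def objective(X, W, M, lam):
    """Assumed objective: 0.5*||Attn(M*W) - Attn(W)||_F^2 + 0.5*lam*||M||_F^2."""
    diff = attention(X, M * W) - attention(X, W)
    return 0.5 * np.sum(diff ** 2) + 0.5 * lam * np.sum(M ** 2)

def objective_grad(X, W, M, lam):
    """Closed-form gradient of `objective` with respect to the relaxed mask M."""
    A = attention(X, M * W)
    G = A - attention(X, W)                              # d loss / d A
    dZ = A * (G - np.sum(G * A, axis=1, keepdims=True))  # back through the row-wise softmax
    return (X.T @ dZ @ X) * W + lam * M                  # back through X (M*W) X^T, plus penalty

def prune_mask(X, W, lam=1e-3, steps=200, lr=1.0):
    """Projected gradient descent on M in [0, 1], halving the step until the objective drops."""
    M = np.ones_like(W)
    best = objective(X, W, M, lam)
    for _ in range(steps):
        g = objective_grad(X, W, M, lam)
        step = lr
        while step > 1e-8:
            cand = np.clip(M - step * g, 0.0, 1.0)  # project back onto [0, 1]
            val = objective(X, W, cand, lam)
            if val < best:
                M, best = cand, val
                break
            step *= 0.5
        else:
            break  # no improving step found; stop early
    return M, best

# Toy data (synthetic, for illustration only).
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))   # 8 tokens, hidden size 4
W = rng.normal(size=(4, 4))
initial = objective(X, W, np.ones_like(W), 1e-3)
M, final = prune_mask(X, W)
print(f"objective: {initial:.6f} -> {final:.6f}")
```

A hard mask could then be obtained by thresholding the relaxed `M` (for example, `M > 0.5`); the threshold is likewise an assumption here. The contrast with linear pruning is that the error is measured after the Softmax, so the gradient depends on how each weight affects the normalized attention rows rather than only on $XW$.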
