Beyond Linear Approximations: A Novel Pruning Approach for Attention Matrix

Yingyu Liang, Jiangxuan Long, Zhenmei Shi, Zhao Song, Yufa Zhou

TL;DR

This work tackles pruning of large language models by directly optimizing the attention weights to approximate the non-linear Softmax attention, addressing the shortcomings of linear pruning on $XW$ mappings. It develops a Gradient Descent–based pruning method with a closed-form gradient, proves Lipschitz continuity and convergence via a Polyak–Łojasiewicz (PL) inequality, and provides a finite-time guarantee for reaching a near-optimal pruning mask. Empirically, the approach preserves model performance while achieving substantial sparsity and reduced computation cost, outperforming state-of-the-art baselines such as SparseGPT and Wanda on both synthetic data and real models (e.g., Llama 3.2-1B with C4 data), including gains in end-to-end perplexity. The results lay a theoretical and practical foundation for attention-focused pruning, with potential to enable efficient LLM inference on edge devices.

Abstract

Large Language Models (LLMs) have shown immense potential in enhancing various aspects of our daily lives, from conversational AI to search and AI assistants. However, their growing capabilities come at the cost of extremely large model sizes, making deployment on edge devices challenging due to memory and computational constraints. This paper introduces a novel approach to LLM weight pruning that directly optimizes for approximating the attention matrix, a core component of transformer architectures. Unlike existing methods that focus on linear approximations, our approach accounts for the non-linear nature of the Softmax attention mechanism. We provide theoretical guarantees for the convergence of our Gradient Descent-based optimization method to a near-optimal pruning mask solution. Our empirical results demonstrate that our non-linear pruning approach maintains model performance while significantly reducing computational costs, outperforming the current state-of-the-art methods, i.e., SparseGPT and Wanda, by a large margin. This work establishes a new theoretical foundation for pruning algorithm design in LLMs, potentially paving the way for more efficient LLM inference on resource-constrained devices.
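To make the pruning objective concrete, here is a minimal pure-Python sketch that computes the attention matrix $D^{-1}A$, assuming $A = \exp(XWX^\top)$ and $D = \mathrm{diag}(A\mathbf{1}_n)$, and scores a mask $M$ applied elementwise to $W$. The exact loss of Definition 1.2 is not reproduced on this page, so `pruning_loss`, its squared Frobenius reconstruction term, the $\frac{\lambda}{2}\|M\|_F^2$ penalty, and the toy inputs are all hypothetical stand-ins.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Plain triple-loop matrix product a @ b."""
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def hadamard(a: Matrix, b: Matrix) -> Matrix:
    """Elementwise product, used to apply the pruning mask M to W."""
    return [[p * q for p, q in zip(ra, rb)] for ra, rb in zip(a, b)]


def attention_matrix(x: Matrix, w: Matrix) -> Matrix:
    """D^{-1} A with A = exp(X W X^T) and D = diag(A 1_n), i.e. a row-wise softmax."""
    rows = []
    for row in matmul(matmul(x, w), transpose(x)):
        shift = max(row)  # subtracting the row max cancels in D^{-1} A but avoids overflow
        e = [math.exp(s - shift) for s in row]
        total = sum(e)
        rows.append([v / total for v in e])
    return rows


def pruning_loss(x: Matrix, w: Matrix, mask: Matrix, lam: float) -> float:
    """Hypothetical objective: 0.5*||f(W o M) - f(W)||_F^2 + 0.5*lam*||M||_F^2."""
    target = attention_matrix(x, w)
    approx = attention_matrix(x, hadamard(w, mask))
    recon = sum((p - q) ** 2 for rp, rq in zip(approx, target) for p, q in zip(rp, rq))
    return 0.5 * recon + 0.5 * lam * sum(m * m for row in mask for m in row)


# Toy example with n = 3 tokens and d = 2 features (made-up numbers).
x = [[1.0, 0.5], [0.2, -0.3], [-0.7, 0.9]]
w = [[0.8, -0.4], [0.3, 0.6]]
dense = [[1.0, 1.0], [1.0, 1.0]]
half = [[1.0, 0.0], [0.0, 1.0]]  # keeps only the diagonal of W
print(pruning_loss(x, w, dense, lam=0.0))  # 0.0: the unpruned mask reproduces f(W)
print(pruning_loss(x, w, half, lam=0.0))
```

The row-max shift is a standard numerical-stability trick; it leaves $D^{-1}A$ unchanged because the common factor $e^{-\text{shift}}$ cancels between numerator and row sum.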

Paper Structure

This paper contains 55 sections, 38 theorems, 138 equations, 3 figures, 1 table, 1 algorithm.

Key Result

Theorem 1.3

For any $\epsilon > 0$, our Algorithm 1 converges to a near-optimal pruning mask for the Attention Weights Pruning problem (Definition 1.2) in $O(d \cdot \mathrm{poly}(n)/\epsilon)$ time with $O(\xi + \epsilon)$ error, where $\xi$ is a small term depending on int…
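Theorem 1.3 concerns gradient descent on the pruning mask. The toy sketch below illustrates such a loop under loud assumptions: the loss is taken to be $\frac{1}{2}\|\widetilde{D}^{-1}\widetilde{A} - D^{-1}A\|_F^2 + \frac{\lambda}{2}\|M\|_F^2$ with $A = \exp(XWX^\top)$, the mask is relaxed to $[0,1]^{d\times d}$ and clipped after each step, central finite differences stand in for the paper's closed-form gradient, and the result is rounded to a binary mask by keeping the `k` largest entries. The names `prune_by_gd`, `fd_gradient`, `lr`, and `steps` are hypothetical and not taken from Algorithm 1.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def attention_matrix(x: Matrix, w: Matrix) -> Matrix:
    """Row-wise softmax of X W X^T, i.e. D^{-1} A with A = exp(X W X^T)."""
    xt = [list(col) for col in zip(*x)]
    out = []
    for row in matmul(matmul(x, w), xt):
        shift = max(row)
        e = [math.exp(s - shift) for s in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out


def loss(x: Matrix, w: Matrix, mask: Matrix, lam: float) -> float:
    """Hypothetical objective: 0.5*||f(W o M) - f(W)||_F^2 + 0.5*lam*||M||_F^2."""
    target = attention_matrix(x, w)
    pruned = [[wij * mij for wij, mij in zip(rw, rm)] for rw, rm in zip(w, mask)]
    approx = attention_matrix(x, pruned)
    recon = sum((p - q) ** 2 for rp, rq in zip(approx, target) for p, q in zip(rp, rq))
    return 0.5 * recon + 0.5 * lam * sum(v * v for row in mask for v in row)


def fd_gradient(x: Matrix, w: Matrix, mask: Matrix, lam: float, h: float = 1e-5) -> Matrix:
    """Central finite differences; a stand-in for the paper's closed-form gradient."""
    grad = [[0.0] * len(mask[0]) for _ in mask]
    for i in range(len(mask)):
        for j in range(len(mask[0])):
            plus = [row[:] for row in mask]
            minus = [row[:] for row in mask]
            plus[i][j] += h
            minus[i][j] -= h
            grad[i][j] = (loss(x, w, plus, lam) - loss(x, w, minus, lam)) / (2 * h)
    return grad


def prune_by_gd(x: Matrix, w: Matrix, k: int, lam: float = 0.1, lr: float = 0.2,
                steps: int = 100) -> Tuple[Matrix, List[float]]:
    """Projected GD on a relaxed mask in [0, 1], then keep the k largest entries."""
    mask = [[1.0] * len(w[0]) for _ in w]  # start from the dense (unpruned) mask
    history = [loss(x, w, mask, lam)]
    for _ in range(steps):
        g = fd_gradient(x, w, mask, lam)
        # Gradient step followed by clipping back into the box [0, 1].
        mask = [[min(1.0, max(0.0, m - lr * gij)) for m, gij in zip(rm, rg)]
                for rm, rg in zip(mask, g)]
        history.append(loss(x, w, mask, lam))
    ranked = sorted(((mask[i][j], i, j) for i in range(len(mask)) for j in range(len(mask[0]))),
                    reverse=True)
    keep = {(i, j) for _, i, j in ranked[:k]}
    binary = [[1.0 if (i, j) in keep else 0.0 for j in range(len(w[0]))] for i in range(len(w))]
    return binary, history


x = [[1.0, 0.5], [0.2, -0.3], [-0.7, 0.9]]  # n = 3, d = 2 (made-up numbers)
w = [[0.8, -0.4], [0.3, 0.6]]
binary_mask, history = prune_by_gd(x, w, k=2)
print(binary_mask, history[0], history[-1])
```

Finite differences cost two loss evaluations per mask entry, which is only workable at toy sizes; a closed-form gradient avoids that cost.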

Figures (3)

  • Figure 1: Comparison of our Attention Weights Pruning method with Linear Pruning methods such as Wanda and SparseGPT. The top figure illustrates our proposed attention-matrix approximation, where pruning is applied directly to the fused attention weight matrix $W$ using a single pruning mask $M$. The bottom figure describes Linear Pruning via linear-function approximation, where pruning is applied separately to the query weight matrix $W_Q$ and the key weight matrix $W_K$ using two different pruning masks, $M_Q$ and $M_K$, respectively.
  • Figure 2: Comparison among our Algorithm 1, Wanda, and SparseGPT. The $y$-axis is the relative error, defined as $\frac{\|\widetilde{D}^{-1}\widetilde{A} - D^{-1}A\|_F^2}{\|D^{-1}A\|_F^2}$, where $D^{-1}A$ is the original attention matrix and $\widetilde{D}^{-1}\widetilde{A}$ is the approximated attention matrix produced by each of the three methods. We always use $d=64$, with $k=16$ for the first row and $k=64$ for the second row. The $x$-axis is (a) the regularization coefficient $\lambda$ for the left column; (b) the input sequence length $n$ for the middle column; (c) the pruning ratio $\rho$ for the right column.
  • Figure 3: The comparison among our algorithm, Wanda, and SparseGPT on Llama 3.2-1B.
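Figure 1 contrasts one mask on the fused weight $W = W_Q W_K^\top$ with two masks on $W_Q$ and $W_K$, and Figure 2 scores methods by the relative error of the attention matrix. The sketch below, assuming $A = \exp(XWX^\top)$ and made-up toy weights, checks that the fused and split forms produce identical scores and computes that relative error for both masking schemes. `magnitude_prune` is a hypothetical keep-the-largest-$|w|$ rule used only for illustration; it is not Wanda, SparseGPT, or the paper's algorithm, so the printed errors are not comparable to Figure 2.

```python
import math
from typing import List

Matrix = List[List[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def transpose(a: Matrix) -> Matrix:
    return [list(col) for col in zip(*a)]


def softmax_rows(scores: Matrix) -> Matrix:
    """Row-wise softmax, i.e. D^{-1} A for A = exp(scores)."""
    out = []
    for row in scores:
        shift = max(row)
        e = [math.exp(s - shift) for s in row]
        total = sum(e)
        out.append([v / total for v in e])
    return out


def relative_error(approx: Matrix, exact: Matrix) -> float:
    """||approx - exact||_F^2 / ||exact||_F^2, the y-axis metric of Figure 2."""
    num = sum((p - q) ** 2 for rp, rq in zip(approx, exact) for p, q in zip(rp, rq))
    den = sum(q * q for rq in exact for q in rq)
    return num / den


def magnitude_prune(w: Matrix, keep: int) -> Matrix:
    """Hypothetical baseline: zero all but the `keep` largest-|w| entries."""
    ranked = sorted(((abs(v), i, j) for i, row in enumerate(w) for j, v in enumerate(row)),
                    reverse=True)
    kept = {(i, j) for _, i, j in ranked[:keep]}
    return [[v if (i, j) in kept else 0.0 for j, v in enumerate(row)]
            for i, row in enumerate(w)]


x = [[1.0, 0.5], [0.2, -0.3], [-0.7, 0.9]]  # n = 3, d = 2 (made-up numbers)
wq = [[0.9, -0.2], [0.4, 0.7]]
wk = [[0.5, 0.3], [-0.6, 0.8]]

# The fused weight W = W_Q W_K^T gives the same scores as (X W_Q)(X W_K)^T.
w_fused = matmul(wq, transpose(wk))
fused_scores = matmul(matmul(x, w_fused), transpose(x))
split_scores = matmul(matmul(x, wq), transpose(matmul(x, wk)))
max_gap = max(abs(a - b) for ra, rb in zip(fused_scores, split_scores) for a, b in zip(ra, rb))
exact = softmax_rows(fused_scores)

# One mask on the fused W (top of Figure 1) vs. two masks on W_Q and W_K (bottom).
one_mask = softmax_rows(matmul(matmul(x, magnitude_prune(w_fused, 2)), transpose(x)))
two_masks = softmax_rows(matmul(matmul(x, magnitude_prune(wq, 2)),
                                transpose(matmul(x, magnitude_prune(wk, 2)))))
err_one = relative_error(one_mask, exact)
err_two = relative_error(two_masks, exact)
print(max_gap, err_one, err_two)
```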

Theorems & Definitions (87)

  • Definition 1.1: Attention Matrix
  • Definition 1.2: Attention Weights Pruning
  • Theorem 1.3: Main result, informal version of Theorem 4.1
  • Definition 3.1: Causal attention mask, lss+24
  • Definition 3.2: Attention Weights Pruning with Causal Attention Mask
  • Theorem 4.1: Main result, formal version of Theorem 1.3
  • proof
  • Remark 4.2
  • Definition 5.1: $g$-proxy, $\xi$-optimal PL inequality, Definition 1.2 in fg21
  • Theorem 5.2: Theorem 3.1 in fg21
  • ...and 77 more
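The TL;DR credits convergence to Lipschitz continuity and a PL inequality. For orientation, the textbook template of that argument is sketched below for a loss $L$ whose gradient is $\beta$-Lipschitz ($\beta$-smooth) and which satisfies the standard PL condition with constant $\mu > 0$. Definition 5.1 and Theorem 5.2 use a relaxed $g$-proxy, $\xi$-optimal variant from fg21 whose exact form is not shown on this page, so this is the standard version, not the paper's statement.

```latex
\begin{align*}
&\text{PL condition:} && \tfrac{1}{2}\|\nabla L(w)\|_2^2 \;\ge\; \mu\bigl(L(w) - L^*\bigr) \quad \text{for all } w,\\
&\text{GD step:} && w_{t+1} = w_t - \tfrac{1}{\beta}\nabla L(w_t),\\
&\beta\text{-smoothness:} && L(w_{t+1}) \;\le\; L(w_t) - \tfrac{1}{2\beta}\|\nabla L(w_t)\|_2^2
  \;\le\; L(w_t) - \tfrac{\mu}{\beta}\bigl(L(w_t) - L^*\bigr),\\
&\text{unrolling:} && L(w_t) - L^* \;\le\; \Bigl(1 - \tfrac{\mu}{\beta}\Bigr)^{t}\bigl(L(w_0) - L^*\bigr).
\end{align*}
```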