Enhancing In-Context Learning Performance with just SVD-Based Weight Pruning: A Theoretical Perspective

Xinhao Yao, Xiaolin Hu, Shenzhi Yang, Yong Liu

TL;DR

SVD-based weight pruning can enhance ICL performance. More surprisingly, pruning weights in deep layers often yields more stable performance improvements than pruning weights in shallow layers.

Abstract

Pre-trained large language models (LLMs) based on the Transformer architecture have demonstrated striking in-context learning (ICL) abilities. Given a few demonstration input-label pairs, they can predict the label for an unseen input without any parameter updates. In this paper, we show an exciting phenomenon: SVD-based weight pruning can enhance ICL performance and, more surprisingly, pruning weights in deep layers often yields more stable performance improvements than pruning weights in shallow layers. However, the mechanism underlying these findings remains an open question. To explain them, we conduct an in-depth theoretical analysis by presenting the implicit gradient descent (GD) trajectories of ICL and deriving mutual-information-based generalization bounds for ICL via the full implicit GD trajectories. This allows us to give a reasonable explanation of the surprising experimental findings. In addition, based on our experimental and theoretical insights, we propose a simple, derivative-free model-compression algorithm that enhances ICL inference on downstream tasks. Experiments on benchmark datasets and open-source LLMs demonstrate the method's effectiveness. The code is available at https://github.com/chen123CtrlS/EnhancingICL_SVDPruning.
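As a rough illustration of what SVD-based weight pruning means, the sketch below truncates the singular value decomposition of a single weight matrix. It is a minimal sketch and not the authors' released code: the function name `svd_prune`, the random test matrix, and the convention that `clipping_rate` is the fraction of singular directions discarded are all assumptions made for this example.

```python
import numpy as np


def svd_prune(weight: np.ndarray, clipping_rate: float) -> np.ndarray:
    """Return a low-rank approximation of `weight` via truncated SVD.

    Assumption for this sketch: `clipping_rate` is the fraction of singular
    directions that is discarded, so 0.0 leaves the matrix unchanged and
    0.995 keeps roughly 0.5% of the original rank.
    """
    if not 0.0 <= clipping_rate < 1.0:
        raise ValueError("clipping_rate must lie in [0, 1)")
    # Thin SVD: weight = u @ diag(s) @ vt, singular values sorted descending.
    u, s, vt = np.linalg.svd(weight, full_matrices=False)
    # Keep the largest singular values; always keep at least one.
    keep = max(1, int(round((1.0 - clipping_rate) * s.size)))
    return (u[:, :keep] * s[:keep]) @ vt[:keep, :]


# Hypothetical stand-in for one MLP or attention weight matrix.
rng = np.random.default_rng(0)
w = rng.standard_normal((64, 48))
w_pruned = svd_prune(w, clipping_rate=0.75)
print(w_pruned.shape, np.linalg.matrix_rank(w_pruned))
```

The pruned matrix has the same shape as the original, so it can replace the original weights without changing the rest of the model. No gradients are needed to compute it.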

Paper Structure

This paper contains 31 sections, 8 theorems, 37 equations, 5 figures, 8 tables, 1 algorithm.

Key Result

Lemma 1

Consider a Transformer consisting of a single linear attention layer with a residual connection, parameterized by $\mathbf{W}_V$, $\mathbf{W}_K$, $\mathbf{W}_Q$ as in Eq. (eq3). As in Section icl, let $\mathbf{H}=[\mathbf{H}_s,\mathbf{h}_{N+1}]$ be the input, where $\mathbf{H}_s=[\mathbf{h}_1,\mathbf{h}_2,\ldots,\mathbf{h}_N]$ … where $\mathbf{W}_V\mathbf{H}_s$ is regarded as the meta-gradient of ICL, which is used to generate …
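The lemma statement is truncated in this copy, so the following does not reproduce it. It is only a numerical check of the standard algebraic identity that dual-form arguments of this kind rely on: softmax-free attention over $[\mathbf{H}_s,\mathbf{h}_{N+1}]$ splits into a term from the query token plus a sum of rank-one outer products, one per demonstration token. The dimensions, random weights, and the names `W_0` and `delta_W` are assumptions for illustration, and the residual connection is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
d, n = 8, 5  # hypothetical hidden size and number of demonstration tokens

W_V, W_K, W_Q = (rng.standard_normal((d, d)) for _ in range(3))
H_s = rng.standard_normal((d, n))       # demonstration tokens as columns
h_query = rng.standard_normal((d, 1))   # query token h_{N+1}
H = np.hstack([H_s, h_query])


def linear_attention(H: np.ndarray, h: np.ndarray) -> np.ndarray:
    """Attention without softmax: W_V H (W_K H)^T W_Q h."""
    return W_V @ H @ (W_K @ H).T @ W_Q @ h


full = linear_attention(H, h_query)

# Dual form: each demonstration token contributes one rank-one outer product,
# which has the same shape as a gradient update to a linear layer.
delta_W = sum(np.outer(W_V @ H_s[:, i], W_K @ H_s[:, i]) for i in range(n))
# Contribution of the query token alone (no demonstrations).
W_0 = np.outer(W_V @ h_query[:, 0], W_K @ h_query[:, 0])
dual = (W_0 + delta_W) @ (W_Q @ h_query)

print(np.allclose(full, dual))
```

In this reading, `delta_W` equals $\mathbf{W}_V\mathbf{H}_s(\mathbf{W}_K\mathbf{H}_s)^T$, which is where the factor $\mathbf{W}_V\mathbf{H}_s$ appears.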

Figures (5)

  • Figure 1: The effect of weight pruning across different layer types. The figure shows the phenomenon observed on the benchmark datasets (SST-2, RTE, COPA) and open-source LLMs (GPT-J-6B and LLAMA2-7B). Each sub-figure corresponds only to the indicated dataset, model and module. This figure focuses on the impact of weight pruning on the first two and the last two layers of the model, and different colors distinguish these layers. The dashed line represents the performance of the pretrained model without SVD. We operate on the whole MLP or ATTN module and mark the points of highest performance. The amount of weight pruning is severe: for instance, the highest model performance sometimes occurs at a clipping rate of 0.995, that is, with about 99.5% of the matrix’s original rank removed. For the definitions of “deep” and “shallow”, please refer to Appendix \ref{def_deep_}.
  • Figure 2: The effect of different ICL shot numbers is not uniform. Here we show how the number of ICL shots affects the phenomenon described in Section \ref{ep11}, as studied on the SST-2 dataset. Each row shows the results for the same shot number across different layers and modules, and each column shows the results for different shot numbers across different layers of the same module. We also mark the points of highest performance.
  • Figure 3: 2-norm condition numbers for GPT-J-6B and LLAMA2-7B. The condition numbers for MLP are significantly lower than those for ATTN. In deeper layers, condition numbers tend to be higher. These matrices are ill-conditioned because they satisfy $\sigma_{max}\gg \sigma_{min}$.
  • Figure 4: Model performance on the test set for different tasks. The results compare four scenarios: ICL (GPT-J-6B), ICL+Algorithm1 (GPT-J-6B), ICL (LLAMA2-7B) and ICL+Algorithm1 (LLAMA2-7B). ICL+Algorithm1 outperforms ICL alone on different tasks. See Appendix \ref{result_algorithm} for detailed numbers.
  • Figure 5: Evaluating the trained Transformer on in-context learning of linear functions. (a) garg2023transformers consider the class of linear functions $\mathcal{F} = \{f | f(x) = w^T x, w \in \mathbb{R}^d\}$ in $d$ dimensions, where $d = 20$. They sample $x_1, \ldots, x_k, x_{\text{query}},$ and $w$ independently from the isotropic Gaussian distribution $N(0, I_d)$. They then compute each $y_i = w^T x_i$ and construct the prompt as $P = (x_1, y_1, x_2, y_2, \ldots, x_k, y_k, x_{\text{query}})$. This figure plots the normalized squared error of the Transformer, $(M(P)-w^Tx_{\text{query}})^2/d$; the errors are normalized so that the trivial zero estimator achieves an error of 1 (dashed line). When the number of in-context examples reaches the problem dimension $d$ (here 20), least squares achieves 0 error while the Transformer achieves an error of 0.02. (b) We follow the same setting as garg2023transformers in (a) to compare trained Transformers at different numbers of training steps.
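The data setup in the Figure 5 caption can be simulated for the two reference estimators it mentions, the trivial zero estimator and least squares. The sketch below is an assumption-laden illustration: it does not train or evaluate a Transformer $M$, and the function name `normalized_errors`, the trial count, and the seed are hypothetical choices.

```python
import numpy as np


def normalized_errors(k: int, d: int = 20, trials: int = 2000, seed: int = 0):
    """Average normalized squared error of least squares and the zero estimator.

    Follows the sampling described in the caption: w, x_1..x_k and x_query are
    drawn independently from N(0, I_d), and y_i = w^T x_i (noise-free).
    """
    rng = np.random.default_rng(seed)
    ls_err, zero_err = 0.0, 0.0
    for _ in range(trials):
        w = rng.standard_normal(d)
        X = rng.standard_normal((k, d))   # k in-context inputs as rows
        x_query = rng.standard_normal(d)
        y = X @ w                         # in-context labels
        # Minimum-norm least-squares fit on the k in-context examples.
        w_hat, *_ = np.linalg.lstsq(X, y, rcond=None)
        target = w @ x_query
        ls_err += (w_hat @ x_query - target) ** 2 / d
        zero_err += target ** 2 / d       # the zero estimator predicts 0
    return ls_err / trials, zero_err / trials


for k in (5, 10, 20):
    ls, zero = normalized_errors(k)
    print(f"k={k:2d}  least squares={ls:.4f}  zero estimator={zero:.4f}")
```

Dividing by $d$ makes the zero estimator's expected error equal to 1, since $\mathbb{E}[(w^T x_{\text{query}})^2] = d$ under this sampling. With $k = d$ noise-free examples, the linear system is almost surely invertible, so least squares recovers $w$ up to numerical error.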

Theorems & Definitions (21)

  • Lemma 1: The Implicit Gradient Descent of ICL in a Single Linear Attention Layer
  • Remark 1
  • Theorem 1: The Implicit Gradient Descent Trajectories of ICL
  • Remark 2
  • Remark 3
  • Theorem 2: The Generalization Bounds of ICL via Full Implicit Gradient Descent Trajectories
  • Remark 4: Deal with \ref{q1}
  • Example 1: Prune $\mathbf{W}_Q,\mathbf{W}_K,\mathbf{W}_V$
  • Remark 5: Deal with \ref{q2}
  • Remark 6: How should Theorem 2 be interpreted?
  • ...and 11 more