Cross-layer Attention Sharing for Pre-trained Large Language Models

Yongyu Mu, Yuzhang Wu, Yuchun Fan, Chenglong Wang, Hengyu Li, Jiali Zeng, Qiaozhi He, Murun Yang, Fandong Meng, Jie Zhou, Tong Xiao, Jingbo Zhu

TL;DR

This work reveals substantial inter-layer redundancy in self-attention patterns within pre-trained LLMs, with adjacent layers sharing highly similar attention weights. It introduces LiSA, a lightweight sharing mechanism that first aligns attention heads across layers and then compensates for residual differences with low-rank projections, enabling shared attention across a majority of layers while preserving accuracy. LiSA achieves significant efficiency gains, including a 6x compression of Q and K matrices and 19.5%-40.1% throughput improvements across multiple models, with only 0.46%-1.64% of parameters being trained. The approach is applicable to existing well-trained LLMs, supports pre-training from scratch and task-specific adaptations, and provides a principled framework for reducing inter-layer attention redundancy with minimal performance loss.

Abstract

To enhance the efficiency of the attention mechanism within large language models (LLMs), previous works primarily compress the KV cache or group attention heads, while largely overlooking redundancy between layers. Our comprehensive analyses across various LLMs show that highly similar attention patterns persist within most layers. It is intuitive to reduce this redundancy by sharing attention weights across layers. However, further analysis reveals two challenges: (1) directly sharing the weight matrix without carefully rearranging the attention heads proves ineffective; (2) shallow layers are vulnerable to small deviations in attention weights. Driven by these insights, we introduce LiSA, a lightweight substitute for self-attention in well-trained LLMs. LiSA employs tiny feed-forward networks to align attention heads between adjacent layers and low-rank matrices to approximate differences in layer-wise attention weights. Evaluations encompassing 13 typical benchmarks demonstrate that LiSA maintains high response quality in terms of accuracy and perplexity while reducing redundant attention calculations within 53%-84% of the total layers. Our implementations of LiSA achieve a 6x compression of Q and K matrices within the attention mechanism, with maximum throughput improvements of 19.5%, 32.3%, and 40.1% for LLaMA3-8B, LLaMA2-7B, and LLaMA2-13B, respectively.
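
The two components named in the abstract (head alignment with a tiny feed-forward network, plus a low-rank correction for layer-wise differences) can be illustrated with a minimal pure-Python sketch. Everything specific here is an assumption made for illustration and is not taken from the paper: the function names, the single-hidden-layer ReLU network applied over the head axis, the additive combination of aligned weights with low-rank scores, and the final causal softmax.

```python
import math
import random

def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def causal_softmax(scores):
    """Row-wise softmax where query i only sees key positions j <= i."""
    out = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]
        peak = max(visible)
        exps = [math.exp(v - peak) for v in visible]
        total = sum(exps)
        out.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return out

def align_heads(prev_attn, w1, w2):
    """Hypothetical tiny FFN over the head axis: re-mix previous-layer heads at each (i, j)."""
    n_prev, seq = len(prev_attn), len(prev_attn[0])
    aligned = [[[0.0] * seq for _ in range(seq)] for _ in range(len(w2[0]))]
    for i in range(seq):
        for j in range(seq):
            per_head = [[prev_attn[h][i][j] for h in range(n_prev)]]   # 1 x H
            hidden = [[max(0.0, z) for z in matmul(per_head, w1)[0]]]  # ReLU, 1 x F
            for h, value in enumerate(matmul(hidden, w2)[0]):          # 1 x H
                aligned[h][i][j] = value
    return aligned

def shared_attention(x, prev_attn, w1, w2, wq_lr, wk_lr):
    """Aligned previous-layer weights plus a low-rank Q/K correction (assumed additive)."""
    aligned = align_heads(prev_attn, w1, w2)
    seq = len(x)
    out = []
    for h, base in enumerate(aligned):
        q = matmul(x, wq_lr[h])  # T x r, replaces a full-size query projection
        k = matmul(x, wk_lr[h])  # T x r, replaces a full-size key projection
        rank = len(q[0])
        scores = [[base[i][j] + sum(a * b for a, b in zip(q[i], k[j])) / math.sqrt(rank)
                   for j in range(seq)] for i in range(seq)]
        out.append(causal_softmax(scores))
    return out

def rand(rows, cols):
    """Small random matrix standing in for trained parameters or activations."""
    return [[random.uniform(-0.5, 0.5) for _ in range(cols)] for _ in range(rows)]

random.seed(0)
T, D, H, F, R = 4, 8, 2, 4, 2  # tokens, model dim, heads, FFN width, low rank (toy sizes)
x = rand(T, D)
prev_attn = [causal_softmax(rand(T, T)) for _ in range(H)]
attn = shared_attention(x, prev_attn, rand(H, F), rand(F, H),
                        [rand(D, R) for _ in range(H)], [rand(D, R) for _ in range(H)])
print(len(attn), len(attn[0]), len(attn[0][0]))
```

In this toy setup the layer never forms full D x D query and key projections; it only needs the previous layer's attention weights, the small head-mixing matrices, and the D x R low-rank factors. The exact ranks and combination rule behind the reported 6x compression are not given in this summary.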

Paper Structure (69 sections, 6 equations, 15 figures, 10 tables)

Figures (15)

  • Figure 1: Comparison of different attention models. Layer$_n$ stands for a Transformer layer while $h_1$, $h_2$, and $h_3$ represent three attention heads. Standard attention individually calculates attention scores at each layer by employing $Q_n$ and $K_n$ matrices. Average attention assigns uniform weights across all token positions, thus eliminating $Q$ and $K$ matrices. Directly sharing attention reuses the raw weight matrix from the preceding layer but overlooks varied head weights across different layers. Our method, LiSA attention, not only aligns attention heads but also compensates for layer-wise weight differences by leveraging low-rank $Q_n^{LR}$ and $K_n^{LR}$ matrices, thus maximally preserving the original performance while introducing only a few additional training parameters.
  • Figure 2: An illustration of strategies for measuring the similarity of attention weights across different layers. The attention mechanism in each layer is assumed to have three heads, represented by blue, red, and green colors, corresponding to their positions within the attention weight matrices. We propose two similarity calculation settings: S1 computes the average attention weights across heads first and then calculates similarity via $\text{Sim}(\cdot)$, e.g., JS divergence; S2 calculates pairwise similarity scores between aligned heads individually and then averages the scores. Specifically, three head-alignment strategies are considered: (1) Position-based alignment matches heads according to their positional indices; (2) Random-pair alignment randomly matches heads from two layers; (3) Similarity-based alignment pairs each head with its most similar counterpart in the preceding layer, without enforcing a strict one-to-one correspondence.
  • Figure 3: The JS divergence scores for the attention weights between every pair of layers, calculated under setting S1. For all figures, both the horizontal and vertical coordinates stand for layer indices. A deeper red color indicates a lower JS divergence score, corresponding to higher similarity. For instance, the cell located at the third row and fourth column of the top-left figure indicates that the JS divergence between the third and fourth layers of LLaMA3-8B is less than $0.05$. See the corresponding appendix figure (label app:heatmap2) for results of LLaMA2-7B and Gemma-7B.
  • Figure 4: Figure (a) presents the JS divergence between the attention distributions of two distinct sentences, each with an equal number of tokens, across all pairs of layers in LLaMA3-8B. Figure (b) shows the JS divergence of attention weights for LLaMA3-8B on the PIQA dataset, excluding the first token, which is a special token that receives the majority of the attention.
  • Figure 5: Figure (a) displays the cosine similarity scores for sub-modules within the attention mechanism across each pair of adjacent layers. Figures (b), (c), and (d) present the average JS divergence of attention weights between adjacent layers under three different alignment strategies in setting S2: position-based, random-pair, and similarity-based, respectively. Lines are added to improve the visual clarity of trends between discrete layers, even though the x-axis represents discrete layer indices.
  • ...and 10 more figures
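
The similarity settings described in the Figure 2 caption can be sketched in a few lines of Python. This is an illustrative reading of the caption, not the authors' code: the function names are hypothetical, the JS divergence uses the natural logarithm (the base used in the paper is not stated here), and the divergence between two heads is assumed to be the mean over query positions.

```python
import math

def js_divergence(p, q):
    """Jensen-Shannon divergence (natural log) between two discrete distributions."""
    def kl(a, b):
        return sum(x * math.log(x / y) for x, y in zip(a, b) if x > 0)
    m = [(x + y) / 2 for x, y in zip(p, q)]
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

def head_divergence(head_a, head_b):
    """Mean JS divergence over query positions for two heads' attention matrices."""
    return sum(js_divergence(ra, rb) for ra, rb in zip(head_a, head_b)) / len(head_a)

def mean_over_heads(layer):
    """Average the attention matrices of all heads in a layer, entry by entry."""
    n = len(layer)
    return [[sum(vals) / n for vals in zip(*rows)] for rows in zip(*layer)]

def s1(prev_layer, next_layer):
    """Setting S1: average heads first, then compare the two averaged matrices."""
    return head_divergence(mean_over_heads(prev_layer), mean_over_heads(next_layer))

def s2_position(prev_layer, next_layer):
    """Setting S2, position-based: compare heads that share the same index."""
    pairs = list(zip(prev_layer, next_layer))
    return sum(head_divergence(a, b) for a, b in pairs) / len(pairs)

def s2_similarity(prev_layer, next_layer):
    """Setting S2, similarity-based: each head pairs with its closest head in the
    preceding layer, without enforcing a one-to-one match."""
    return sum(min(head_divergence(a, b) for a in prev_layer)
               for b in next_layer) / len(next_layer)

# Toy data: two tokens, two heads; the next layer holds the same heads in swapped order.
h0 = [[1.0, 0.0], [0.9, 0.1]]
h1 = [[1.0, 0.0], [0.2, 0.8]]
prev_layer = [h0, h1]
next_layer = [h1, h0]

print("S1:", s1(prev_layer, next_layer))
print("S2 position-based:", s2_position(prev_layer, next_layer))
print("S2 similarity-based:", s2_similarity(prev_layer, next_layer))
```

With the heads swapped, the position-based score is clearly above zero while the S1 and similarity-based scores are zero, which mirrors the caption's point that head order can hide an otherwise identical set of attention patterns.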