DHA: Learning Decoupled-Head Attention from Transformer Checkpoints via Adaptive Heads Fusion

Yilong Chen, Linhao Zhang, Junyuan Shang, Zhenyu Zhang, Tingwen Liu, Shuohuan Wang, Yu Sun

TL;DR

This work targets the efficiency bottlenecks of Multi-Head Attention (MHA) in large transformers by introducing Decoupled-Head Attention (DHA), learned from transformer checkpoints via Adaptive Heads Fusion. DHA adaptively allocates separate key and value heads across layers, forming fusible head groups and using a linear heads fusion operator to map from a larger MHA to a compact DHA while preserving functionality. The method combines a principled fusion loss with an Augmented Lagrangian constraint, enabling progressive, staged transformations (Search, Fusion, Continued Pre-training) on models like LLaMA, achieving up to 97.6% of original performance with only 0.25% of the pre-training budget and up to 75% KV-cache reduction; it also offers a 5× training acceleration over comparable baselines. Overall, DHA provides a scalable, checkpoint-driven pathway to deploy efficient, high-performing decoupled-head transformers across existing architectures with significant resource savings.

Abstract

Large language models (LLMs) with billions of parameters demonstrate impressive performance. However, the widely used Multi-Head Attention (MHA) in LLMs incurs substantial computational and memory costs during inference. While some efforts have optimized attention mechanisms by pruning heads or sharing parameters among heads, these methods often lead to performance degradation or necessitate substantial continued pre-training costs to restore performance. Based on the analysis of attention redundancy, we design a Decoupled-Head Attention (DHA) mechanism. DHA adaptively configures group sharing for key heads and value heads across various layers, achieving a better balance between performance and efficiency. Inspired by the observation of clustering similar heads, we propose to progressively transform the MHA checkpoint into the DHA model through linear fusion of similar head parameters step by step, retaining the parametric knowledge of the MHA checkpoint. We construct DHA models by transforming various scales of MHA checkpoints given target head budgets. Our experiments show that DHA remarkably requires a mere 0.25% of the original model's pre-training budgets to achieve 97.6% of performance while saving 75% of KV cache. Compared to Grouped-Query Attention (GQA), DHA achieves a 5× training acceleration, a maximum of 13.93% performance improvement under 0.01% pre-training budget, and 4% relative improvement under 0.05% pre-training budget.
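
As a rough illustration of the 75% KV-cache figure, the sketch below counts cached bytes per token when key and value head counts may differ by layer. The helper name `kv_cache_bytes_per_token`, the fp16 element size, the LLaMA2-7B-like shape (32 layers, 32 heads, head dimension 128), and above all the per-layer DHA allocation are assumptions made for this example; the paper's searched allocations are not reproduced here.

```python
# Hypothetical helper (not from the paper): bytes of KV cache stored per token.
def kv_cache_bytes_per_token(key_heads, value_heads, head_dim=128, bytes_per_elem=2):
    """key_heads / value_heads are per-layer head counts (lists of equal length)."""
    assert len(key_heads) == len(value_heads)
    return sum((k + v) * head_dim * bytes_per_elem
               for k, v in zip(key_heads, value_heads))

NUM_LAYERS, NUM_HEADS = 32, 32  # LLaMA2-7B-like shape, fp16 cache assumed

# MHA: every layer caches one key head and one value head per query head.
mha = kv_cache_bytes_per_token([NUM_HEADS] * NUM_LAYERS, [NUM_HEADS] * NUM_LAYERS)

# GQA-style budget: the same 8 key heads and 8 value heads in every layer.
gqa = kv_cache_bytes_per_token([8] * NUM_LAYERS, [8] * NUM_LAYERS)

# DHA-style budget: key and value head counts are decoupled and vary by layer.
# This allocation is made up for illustration; it only has to hit the same
# overall 25% head budget as the GQA configuration above.
dha_keys = [4] * 16 + [8] * 16
dha_values = [12] * 16 + [8] * 16
dha = kv_cache_bytes_per_token(dha_keys, dha_values)

saving = 1 - dha / mha
print(mha, gqa, dha, saving)
```

Under these assumptions both compressed variants cache a quarter of the MHA head slots, so the saving comes out at 75%. The difference is that the DHA-style budget can be spent unevenly across layers and across keys versus values.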

Paper Structure

This paper contains 74 sections, 14 equations, 13 figures, 6 tables, and 5 algorithms.

Figures (13)

  • Figure 1: Upper: Overview of the decoupled-head method. Multi-Head Attention (MHA) has equal numbers of query, key, and value heads. Grouped-Query Attention (GQA) instead shares a single key head and value head within each group of query heads. Decoupled-Head Attention (DHA) shares key heads and value heads across different groups of query heads in different layers. Lower: GQA initialization: heads are mean-pooled into a single head; DHA initialization: DHA searches for head groupings and progressively fuses heads to preserve parameter functions.
  • Figure 2: Visualization of the similarity between heads within the MHA of the LLaMA2-7B model at the 0th layer (a) and the 21st layer (b). Details are given in the paper's appendix. Key heads and value heads exhibit decoupled distributions.
  • Figure 3: (a) Model loss with heads proportions in linear fusion. (b) Layer Redundancy of the query, key, value head parameter matrices in the LLaMA2-7B model MHA.
  • Figure 4: An illustration of DHA. First, we reconstruct a single head's forward pass as a linear combination of multiple heads' forward passes with proportions $\mathbf{\omega}$, grouping heads with similar functions based on multi-step optimization. Next, we initialize and optimize the fusion operators. $\Leftrightarrow$ indicates that the optimization narrows the distance between proportions $\mathbf{\omega}$. Finally, we fuse heads within groups and continue pre-training the DHA model.
  • Figure 5: LM Loss with Fusion Training (B) between GQA-7B-25% and DHA-7B-25%.
  • ...and 8 more figures
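
The captions for Figures 1 and 4 contrast mean pooling with linear fusion under proportions $\mathbf{\omega}$. The toy sketch below shows how the two relate. The function names, the 2×2 example matrices, and the interpolation parameter `t` are all assumptions for illustration. In particular, `t` stands in for the paper's learned optimization of the proportions, which is not reproduced here.

```python
def fuse_heads(heads, weights):
    """Linear fusion: weighted sum of per-head parameter matrices of equal shape."""
    rows, cols = len(heads[0]), len(heads[0][0])
    return [[sum(w * h[r][c] for w, h in zip(weights, heads)) for c in range(cols)]
            for r in range(rows)]

def fusion_weights(num_heads, t):
    """Hypothetical schedule: identity proportions at t=0, uniform proportions at t=1."""
    uniform = 1.0 / num_heads
    return [[(1.0 - t) * (1.0 if i == j else 0.0) + t * uniform
             for j in range(num_heads)]
            for i in range(num_heads)]

# Two toy key-head matrices assumed to belong to the same group.
group = [
    [[1.0, 2.0], [3.0, 4.0]],
    [[3.0, 6.0], [5.0, 8.0]],
]

# t=0: each fused head equals its original, so the model is unchanged.
start = [fuse_heads(group, row) for row in fusion_weights(2, 0.0)]

# t=1: every row of proportions is identical, so the fused heads coincide
# and the group can be stored as a single shared head.
end = [fuse_heads(group, row) for row in fusion_weights(2, 1.0)]

# Mean pooling is the special case of linear fusion with uniform weights.
mean_pooled = fuse_heads(group, [0.5, 0.5])
print(start == group, end[0] == end[1] == mean_pooled)
```

In this toy setting the uniform endpoint coincides with mean pooling. A learned fusion operator is not restricted to uniform proportions, and the intermediate values of `t` show how a group of heads can be merged gradually instead of in a single step.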