
Correlation-Aware Select and Merge Attention for Efficient Fine-Tuning and Context Length Extension

Ning Wang, Zekun Li, Tongxin Bai, Guoqi Li

TL;DR

This paper proposes an efficient and flexible attention architecture that extends the context length of large language models with fewer computational resources and less fine-tuning time than existing methods, and introduces correlation-aware selection and merging mechanisms to enable efficient sparse attention.

Abstract

Modeling long sequences is crucial for many large-scale models; however, extending existing architectures to handle longer sequences poses significant technical and resource challenges. In this paper, we propose an efficient and flexible attention architecture that extends the context length of large language models while requiring fewer computational resources and less fine-tuning time than existing methods. Specifically, we introduce correlation-aware selection and merging mechanisms to enable efficient sparse attention. We also propose a novel data augmentation technique for positional encodings that improves generalization to unseen positions. Our results are as follows. First, using a single A100 GPU, we fine-tune Llama2-7B with a sequence length of 32K, which is more efficient than other methods that rely on subsets for regression. Second, we present a comprehensive method for extending context length across the pre-training, fine-tuning, and inference phases. During pre-training, our attention mechanism partially breaks translation invariance during token selection, so we apply positional encodings only to the selected tokens; this approach achieves relatively high performance and strong extrapolation. For fine-tuning, we introduce Cyclic, Randomly Truncated, and Dynamically Growing NTK Positional Embedding (CRD NTK). This design allows fine-tuning with a sequence length of only 16K while enabling models such as Llama2-7B and Mistral-7B to perform inference with context lengths of up to 1M tokens, or even arbitrary lengths. Our method achieves 100% accuracy on the passkey task at a context length of 4M and maintains stable perplexity at a context length of 1M. This represents at least a 64-fold reduction in resource requirements compared with traditional full attention, while still achieving competitive performance.
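The CRD NTK name refers to NTK-aware scaling of the RoPE base. As a hedged illustration of that underlying building block only, the sketch below shows the standard NTK-aware rule base' = base · s^(d/(d-2)) together with a simple dynamic scale factor s = max(1, L / L_train). The cyclic, randomly truncated, and dynamically growing schedules of CRD NTK are not specified in this section and are not modeled; the function names and the scale rule are assumptions for illustration.

```python
def ntk_scaled_base(base: float, scale: float, head_dim: int) -> float:
    """Standard NTK-aware RoPE rescaling: base' = base * scale ** (d / (d - 2))."""
    return base * scale ** (head_dim / (head_dim - 2))


def rope_inv_freq(base: float, head_dim: int) -> list[float]:
    """RoPE frequencies theta_i = base ** (-2i / d) for i = 0 .. d/2 - 1."""
    return [base ** (-i / head_dim) for i in range(0, head_dim, 2)]


def dynamic_scale(seq_len: int, train_len: int) -> float:
    """Assumed (hypothetical) rule: grow the scale only past the training length."""
    return max(1.0, seq_len / train_len)


def rope_angles(position: int, base: float, head_dim: int) -> list[float]:
    """Rotation angle of each channel pair for a single token position."""
    return [position * f for f in rope_inv_freq(base, head_dim)]


# Hypothetical usage: RoPE base 10000, head dim 128, fine-tuned at 16K, run at 64K.
scale = dynamic_scale(64 * 1024, 16 * 1024)  # 4.0
base_64k = ntk_scaled_base(10000.0, scale, 128)
angles = rope_angles(60_000, base_64k, 128)
```

With this rule the highest frequency (i = 0) is unchanged, while the lowest is divided by exactly s, because (base · s^(d/(d-2)))^(-(d-2)/d) = base^(-(d-2)/d) / s. High-frequency channels therefore keep their local resolution, while the slowest channels stretch to cover the longer range.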

Paper Structure (46 sections, 61 equations, 3 figures, 14 tables, 1 algorithm)

Figures (3)

  • Figure 1: Overview of Efficient Fine-Tuning and Context Length Extension. Efficient fine-tuning is achieved using the Merge and Select Attention Mechanism (MS Attention) combined with LoRA. For length extension during fine-tuning, we build upon this efficient approach by employing Cyclic, Randomly Truncated, and Dynamically Growing NTK Positional Embedding (CRD NTK) for high-order extrapolation. The pre-training method capable of high-order length extrapolation utilizes MS Attention along with CRD NTK. Finally, for direct inference without fine-tuning, we adopt the methodology from InfLLM.
  • Figure 2: Overview of the Merging and Selection Attention Mechanism (MS Attention). MS Attention involves two main steps. In the first step, the QKV tensors are split into regions, and each region is represented by a single token. The regional representatives are then compared by dot product or another similarity measure to select the most relevant KV regions for each Q region. For example, Q regions 5, 6, 7, and 8 select KV regions $\{3, 4, 2, 5\}$, $\{5, 4, 2, 6\}$, $\{3, 5, 4, 7\}$, and $\{5, 4, 7, 8\}$, respectively. In the second step, after permutation, the tokens of each Q region are merged with those of adjacent or related Q regions, and the union of their selected KV regions is taken. For example, merging Q regions 6 and 8 along with their selected KV regions yields $\{5, 4, 2, 7, 6, 8\}$. To keep tensor shapes consistent, we select the top $k$ regions from this merged set; with $k = 4$, the final selection is $\{5, 4, 6, 8\}$. Regions 6 and 8 are retained because local regions are considered important by default: they are always preserved and are not scored. Finally, each merged Q region computes attention over its selected KV regions.
  • Figure 3: Sequence Length Extension Using Recursive Methods. By combining multi-scale MS Attention with recursive methods, we extend the sequence length. The model is fine-tuned on 16K length sequences, and during evaluation, lengths less than 16K are interpolated to 16K positions, while lengths greater than 16K use 32K position interpolation.
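The two steps described in the Figure 2 caption can be sketched in plain Python. This is a minimal, hypothetical sketch that only computes which KV regions each (merged) Q region attends to, not the attention itself. The helper names `region_means`, `select_regions`, and `merge_selections` are illustrative, and three details are assumptions not given here: mean pooling as the single representative token, restricting scored candidates to earlier regions (causal), and ranking merged candidates by vote count with ties broken by best rank.

```python
from collections import Counter


def region_means(x, region):
    """Represent each block of `region` consecutive rows by its mean (assumed pooling)."""
    n, dim = len(x) // region, len(x[0])
    return [
        [sum(x[r * region + t][c] for t in range(region)) / region for c in range(dim)]
        for r in range(n)
    ]


def select_regions(q, k, region, top_k):
    """Step 1: rank earlier KV regions for each Q region by representative dot product.

    The local (diagonal) region is always kept and never scored.
    Returns {q_region: [kv regions, best first, local region last]}.
    """
    q_rep, k_rep = region_means(q, region), region_means(k, region)
    selections = {}
    for i, qr in enumerate(q_rep):
        scores = {j: sum(a * b for a, b in zip(qr, k_rep[j])) for j in range(i)}
        ranked = sorted(scores, key=lambda j: -scores[j])
        selections[i] = ranked[: top_k - 1] + [i]
    return selections


def merge_selections(selections, group, k):
    """Step 2: merge the Q regions in `group` and keep k KV regions.

    The group's own (local) regions are always kept; the remaining slots go to the
    other selected regions, ordered by votes, then by best rank (assumed rule).
    """
    votes, best_rank = Counter(), {}
    for qi in group:
        for rank, r in enumerate(selections[qi]):
            if r in group:  # local regions are preserved, not scored
                continue
            votes[r] += 1
            best_rank[r] = min(best_rank.get(r, rank), rank)
    ranked = sorted(votes, key=lambda r: (-votes[r], best_rank[r]))
    return ranked[: max(k - len(group), 0)] + list(group)


# Worked example from Figure 2: merge Q regions 6 and 8 with k = 4.
fig2 = {5: [3, 4, 2, 5], 6: [5, 4, 2, 6], 7: [3, 5, 4, 7], 8: [5, 4, 7, 8]}
print(merge_selections(fig2, (6, 8), k=4))  # [5, 4, 6, 8]
```

Regions 5 and 4 win the two free slots because both merged Q regions selected them, matching the final selection $\{5, 4, 6, 8\}$ in the caption; regions 2 and 7, each chosen by only one member, are dropped.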