SinkLoRA: Enhanced Efficiency and Chat Capabilities for Long-Context Large Language Models

Hengyu Zhang

TL;DR

LongLoRA proposed shifted sparse attention (S\(^2\)-Attn), which enables context extension and yields non-trivial computation savings while performing similarly to fine-tuning with vanilla attention. It is nevertheless still not as efficient as vanilla attention.

Abstract

Extending the functionality of the Transformer model to accommodate longer sequence lengths has become a critical challenge. This extension is crucial not only for improving tasks such as language translation and long-context processing but also for enabling novel applications like chatbots, code generation, and multimedia content creation. The primary obstacle is the self-attention mechanism, whose computation time and memory requirements scale quadratically with sequence length. LongLoRA proposed shifted sparse attention (S\(^2\)-Attn), which enables context extension and yields non-trivial computation savings while performing similarly to fine-tuning with vanilla attention. However, LongLoRA is still not as efficient as vanilla attention, reaching only 39\% of the perplexity improvement compared to full attention. This inefficiency is due to the cyclic shift applied within different attention head patterns, which causes either chaos in the attention head structure or unnecessary information exchange between token groups. To address these issues, we propose \textbf{SinkLoRA}, which features better work partitioning. Specifically, (1) we developed SF-Attn, which uses a segmentation and reassembly algorithm to proportionally return cyclically shifted groups of attention heads to their un-shifted state, together with global attention of "sink attention tokens", achieving 92\% of the perplexity improvement compared to full attention after fine-tuning, and (2) we applied a state-of-the-art KV cache compression algorithm, H$_2$O, to accelerate inference. Furthermore, we conducted supervised fine-tuning with SinkLoRA using a self-collected LongAlpaca-plus dataset. All our code, models, datasets, and demos are available at \url{https://github.com/Dexter-GT-86/SinkLoRA}.
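
To make the attention pattern described in the abstract more concrete, the following is a minimal sketch of one possible reading of SF-Attn: causal attention restricted to un-shifted token groups, plus a few initial "sink" tokens that every later token may attend to. It is not the authors' implementation. The function names `sink_fixed_mask` and `count_pairs`, and the values chosen for `seq_len`, `group_size`, and `num_sink`, are hypothetical and picked only for illustration.

```python
def sink_fixed_mask(seq_len, group_size, num_sink):
    """Return mask[i][j] = True when query token i may attend to key token j.

    Assumed pattern (illustrative only): causal attention inside each
    un-shifted group, plus global visibility of the first `num_sink` tokens.
    """
    mask = []
    for i in range(seq_len):
        row = []
        for j in range(seq_len):
            causal = j <= i                                   # no looking ahead
            same_group = (i // group_size) == (j // group_size)
            is_sink = j < num_sink                            # initial tokens are global
            row.append(causal and (same_group or is_sink))
        mask.append(row)
    return mask


def count_pairs(mask):
    """Count the (query, key) pairs that the mask allows."""
    return sum(sum(row) for row in mask)


# Toy sizes, chosen for readability rather than taken from the paper.
mask = sink_fixed_mask(seq_len=16, group_size=4, num_sink=2)
sparse_pairs = count_pairs(mask)
full_causal_pairs = 16 * 17 // 2  # every token attends to itself and all earlier tokens
print(sparse_pairs, full_causal_pairs)
```

In this toy setting the grouped pattern keeps fewer than half of the query-key pairs of full causal attention, and the gap widens as the sequence grows, because the grouped cost grows linearly in sequence length for a fixed group size while full attention grows quadratically.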

Paper Structure (30 sections, 1 equation, 9 figures, 9 tables, 1 algorithm)

This paper contains 30 sections, 1 equation, 9 figures, 9 tables, 1 algorithm.

Figures (9)

  • Figure 1: Evaluation of SinkLoRA in bridging the accuracy gap between sparse shifted attention and full attention during supervised fine-tuning, while maintaining the memory efficiency of LongLoRA, which utilizes 1.8 times less memory compared to full fine-tuning. Furthermore, SinkLoRA retains the training speed of LongLoRA, being 1.8 times faster than full fine-tuning, due to the implementation of Sink Fixed Attention. The Llama2-7B models touvron2023llama are fine-tuned to various context lengths using Flash-Attention2 dao2023flashattention and DeepSpeed stage 2 rasley2020deepspeed, and are evaluated on the proof-pile test set azerbayevproof in terms of perplexity.
  • Figure 2: Overview of the SinkLoRA fine-tuning process, incorporating Sink Fixed Attention (SF-Attn). Panels (a), (b), and (c) depict the procedure to convert Sparse Shifted Attention into Short Window Attention and subsequently into Sink Fixed Attention. This conversion is executed in two stages: reassembly and making the initial tokens global. In addition to optimizing the LoRA weights within linear layers, SinkLoRA also enables training of the embedding and normalization layers, consistent with the methodology employed in LongLoRA.
  • Figure 3: Overview of the SinkLoRA inference process. Unlike LongLoRA, which retains the original standard self-attention during inference, SinkLoRA implements an optional KV cache compression method, H$_2$O zhang2024h2o. This extension enhances inference speed without significantly compromising performance.
  • Figure 4: Illustration of the Segmentation and Reassembly process in SF-Attn. The process involves three steps: (1) Splitting features along the head dimension into two chunks: one shifted and one unshifted. (2) Splitting tokens, where the tokens belonging to the shifted chunk are shifted by half of the group size, and reassembling them at the tail of the tokens to match the unshifted chunk. (3) Combining the two chunks of tokens together. This figure is adapted from chen2023longlora.
  • Figure 5: Accuracy comparison on passkey retrieval between Llama2 7B and our 7B model fine-tuned on a context length of 32,768. Our model shows no retrieval accuracy degradation up to 33k or 36k, which surpasses the fine-tuning context length, compared with 30k or 34k for LongLoRA.
  • ...and 4 more figures
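
The three steps in the Figure 4 caption can be sketched on token indices alone. This is a hedged illustration under stated assumptions, not the SinkLoRA code: it assumes the shift is a left rotation by half the group size (as in LongLoRA's shifted attention), that the second half of the heads is the shifted chunk, and that reassembly cuts the shifted sequence at the wrap point and swaps the two segments. The helper names `cyclic_shift`, `segment_and_reassemble`, and `sf_attn_token_layout` are hypothetical.

```python
def cyclic_shift(tokens, offset):
    """Rotate left by `offset`: the first `offset` tokens wrap around to the tail."""
    offset %= len(tokens)
    return tokens[offset:] + tokens[:offset]


def segment_and_reassemble(shifted, offset):
    """Cut the shifted sequence at the wrap point and swap the two segments.

    Assumption: `shifted` was produced by cyclic_shift(..., offset), so the
    last `offset` entries are the tokens that wrapped around.
    """
    cut = len(shifted) - offset
    body, wrapped = shifted[:cut], shifted[cut:]
    return wrapped + body


def sf_attn_token_layout(num_tokens, num_heads, group_size):
    """Return, for each head index, the token order seen after reassembly."""
    half = group_size // 2
    tokens = list(range(num_tokens))
    layout = {}
    # Step 1: split heads into an unshifted chunk and a shifted chunk.
    for head in range(num_heads // 2):
        layout[head] = list(tokens)
    for head in range(num_heads // 2, num_heads):
        # Step 2: shift by half the group size, then segment and reassemble.
        shifted = cyclic_shift(tokens, half)
        layout[head] = segment_and_reassemble(shifted, half)
    # Step 3: the combined layout covers both chunks of heads.
    return layout


layout = sf_attn_token_layout(num_tokens=8, num_heads=4, group_size=4)
```

Under these assumptions every head ends up with the same un-shifted token order, so all heads partition tokens into the same groups. This matches the caption's statement that the shifted chunk is reassembled to match the unshifted chunk.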