FlashEVA: Accelerating LLM inference via Efficient Attention

Juan Gabriel Kostelec, Qinghai Guo

TL;DR

FlashEVA tackles the memory and throughput bottlenecks of Transformer inference by reformulating EVA attention as standard softmax attention over an augmented key/value set, which makes it possible to use optimized FlashAttention kernels. By fine-tuning pretrained Transformers on as few as 1.5B tokens, it achieves substantial throughput gains (up to 6.7x) and peak memory reductions (up to 5x) while largely preserving downstream performance on general tasks. Retrieval-focused tasks remain challenging because the base FlashEVA lacks sliding-window dynamics, though a sliding-window variant and hybrid approaches show potential. The approach is hardware- and kernel-friendly (CUDA/Triton), can be combined with KV-cache compression, and offers tunable trade-offs between throughput and accuracy, making it practical for long-context and high-throughput deployments.
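The central idea, expressing EVA as ordinary softmax attention over an augmented key/value set, can be illustrated with a small NumPy sketch. Each query attends, with a single softmax, to the exact keys in its local block plus one summary key/value per remote chunk. The chunk summaries here are a hypothetical mean-pooling stand-in (`chunk_summaries`), not EVA's random-feature control-variate estimator; the attention is non-causal, and the per-query loop is written for readability. The only point being illustrated is that the augmented set can be handed to any standard softmax-attention routine, which is what makes fused kernels such as FlashAttention applicable.

```python
import numpy as np


def softmax(x: np.ndarray) -> np.ndarray:
    # Numerically stable softmax over the last axis.
    e = np.exp(x - x.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def chunk_summaries(k, v, chunk_size):
    # HYPOTHETICAL stand-in for EVA's per-chunk random-feature terms:
    # each chunk is compressed to its mean key and mean value.
    n_chunks = k.shape[0] // chunk_size
    m = n_chunks * chunk_size
    k_s = k[:m].reshape(n_chunks, chunk_size, -1).mean(axis=1)
    v_s = v[:m].reshape(n_chunks, chunk_size, -1).mean(axis=1)
    return k_s, v_s


def augmented_attention(q, k, v, window, chunk_size):
    """Non-causal sketch: each query attends, with one ordinary softmax,
    to the exact keys of its local block plus one summary per remote chunk."""
    n, d = q.shape
    k_s, v_s = chunk_summaries(k, v, chunk_size)
    chunk_starts = np.arange(k_s.shape[0]) * chunk_size
    out = np.empty((n, v.shape[1]))
    for i in range(n):
        start = (i // window) * window  # local block containing query i
        end = min(start + window, n)
        # Keep only summaries of chunks that do not overlap the local block,
        # so no token is counted both exactly and in compressed form.
        remote = (chunk_starts + chunk_size <= start) | (chunk_starts >= end)
        k_aug = np.concatenate([k[start:end], k_s[remote]], axis=0)
        v_aug = np.concatenate([v[start:end], v_s[remote]], axis=0)
        weights = softmax(q[i] @ k_aug.T / np.sqrt(d))
        out[i] = weights @ v_aug
    return out


rng = np.random.default_rng(0)
n, d = 16, 8
q, k, v = (rng.standard_normal((n, d)) for _ in range(3))
out = augmented_attention(q, k, v, window=4, chunk_size=4)
print(out.shape)  # (16, 8)
```

Setting `window` equal to the sequence length leaves no remote chunks, so the sketch reduces to exact full attention; shrinking `window` or growing `chunk_size` shortens the augmented set each query sees, in the spirit of the window/chunk-size trade-off described for Figure 2.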

Abstract

Transformer models have revolutionized natural language processing, achieving state-of-the-art performance and demonstrating remarkable scalability. However, their memory demands, particularly the need to keep the full context in memory, pose significant challenges for inference. In this paper, we present FlashEVA, an efficient implementation of EVA (Efficient Attention via Control Variates), and demonstrate how to fine-tune Transformers to adapt to FlashEVA attention. Our method enables fine-tuning of Transformer models with as few as 1.5B tokens while preserving effectiveness across various downstream tasks. Notably, FlashEVA achieves up to 6.7x higher throughput and 5x lower peak GPU memory usage during inference compared to standard Transformer implementations. Despite these improvements, we observe limitations in retrieval-focused tasks. Our implementation offers control over the trade-off between throughput and accuracy through adjustable hyperparameters, providing flexibility for diverse use cases. This work represents a significant step towards more efficient and adaptable Transformer-based models for inference.
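To make the memory argument concrete, the arithmetic below compares a full KV cache with a cache that keeps only a local window of exact tokens plus one compressed entry per completed chunk. The model shape (24 layers, 16 heads, head dimension 64, fp16), the window of 256, the chunk size of 64, and the helper names are hypothetical choices for illustration, and the assumption that a FlashEVA-style cache has this layout is ours. The numbers are cache sizes only and are not the peak-memory results reported in the paper.

```python
def kv_cache_bytes(entries, layers, heads, head_dim, bytes_per_elem=2):
    # Keys and values (factor 2) for every layer and head; fp16 by default.
    return 2 * layers * heads * head_dim * entries * bytes_per_elem


def compressed_entries(seq_len, window, chunk_size):
    # ASSUMPTION: exact keys/values for the last `window` tokens plus one
    # summary entry per completed chunk of `chunk_size` tokens.
    return min(window, seq_len) + seq_len // chunk_size


# Hypothetical model shape, roughly comparable to a ~400M-parameter decoder.
LAYERS, HEADS, HEAD_DIM = 24, 16, 64
seq_len, window, chunk_size = 8192, 256, 64

full = kv_cache_bytes(seq_len, LAYERS, HEADS, HEAD_DIM)
compressed = kv_cache_bytes(
    compressed_entries(seq_len, window, chunk_size), LAYERS, HEADS, HEAD_DIM
)
print(f"full KV cache:      {full / 2**20:.0f} MiB")        # 768 MiB
print(f"window + summaries: {compressed / 2**20:.0f} MiB")  # 36 MiB
```

The full cache grows linearly with `seq_len`, while the compressed cache grows only with `seq_len / chunk_size` once the window is filled; model weights and activations, which this count ignores, also contribute to peak memory.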

Paper Structure

This paper contains 27 sections, 14 equations, 8 figures, 6 tables.

Figures (8)

  • Figure 1: Comparative analysis of maximum throughput and peak GPU memory usage across different attention mechanisms and generation lengths.
  • Figure 2: Performance trade-offs between average downstream accuracy and generation throughput or peak GPU memory for various FlashEVA-410M configurations. Each point represents a unique combination of local attention window size and RFA chunk size.
  • Figure 3: Comparison of forward and backward pass execution times for (Flash)EVA attention layer and FlashAttention2, including both causal and non-causal attention variants. Results are reported for a constant number of chunks $C$ across different sequence lengths, varying chunk size accordingly. Additional results with fixed chunk size and varying $C$ are presented in the Appendix, yielding qualitatively similar outcomes.
  • Figure 4: Comparison of training loss and gradient norm during finetuning of CausalEVA with local attention and sliding window attention. The sliding window variant achieves marginally lower loss for the 70M model, but exhibits more unstable gradients with frequent large spikes in norm. The local attention variant demonstrates occasional loss spikes, potentially attributable to the random weight sampling in EVA attention.
  • Figure 5: Comparison of training loss trajectories for models initialized with clipped and unclipped weight distributions, demonstrating the stabilizing effect of our proposed sampling method.
  • ...and 3 more figures
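Figures 4 and 5 refer to random weight sampling in EVA attention and to clipped versus unclipped weight distributions. The sketch below shows one assumed reading: Gaussian weights for a Performer-style positive random feature map, clipped elementwise with `np.clip`. The feature map, the helper name `positive_random_features`, and the clipping threshold of 2.0 are illustrative assumptions, not the paper's sampling method. Clipping bounds each weight, and hence how large `exp(W x)` can grow for bounded inputs, at the cost of a slightly biased kernel estimate, since the clipped weights are no longer exactly standard normal.

```python
import numpy as np


def positive_random_features(x, w):
    # Performer-style positive random features for the softmax kernel:
    # phi(x) = exp(W x - |x|^2 / 2) / sqrt(m). With rows of W ~ N(0, I),
    # E[phi(q) . phi(k)] = exp(q . k); clipping W breaks this exactness.
    m = w.shape[0]
    sq_norm = 0.5 * np.sum(x**2, axis=-1, keepdims=True)
    return np.exp(x @ w.T - sq_norm) / np.sqrt(m)


rng = np.random.default_rng(0)
d, m = 16, 64
CLIP = 2.0  # HYPOTHETICAL threshold

w_unclipped = rng.standard_normal((m, d))
w_clipped = np.clip(w_unclipped, -CLIP, CLIP)  # ASSUMED form of clipping

x = 0.5 * rng.standard_normal((128, d))
for name, w in [("unclipped", w_unclipped), ("clipped", w_clipped)]:
    feats = positive_random_features(x, w)
    print(f"{name:>9}: max |w| = {np.abs(w).max():.2f}, max feature = {feats.max():.3f}")
```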