
Self-Selected Attention Span for Accelerating Large Language Model Inference

Tian Jin, Wanzin Yazar, Zifei Xu, Sayeh Sharify, Xin Wang

TL;DR

This work tackles the inefficiency of autoregressive LLM inference by letting the model self-identify the minimal attention span required at each generation step, so that sparse attention can be applied on the fly. It builds two annotated datasets, one for complex arithmetic evaluation and one for news summarization, fine-tunes LLMs to predict attention spans, and implements a blocked attention-span encoding together with a FlashAttention-based CUDA kernel to accelerate decoding. Empirically, the approach achieves up to 28% higher throughput on arithmetic evaluation with no loss in accuracy, while results on summarization are mixed and improve with more fine-tuning. Overall, the study demonstrates that LLMs can optimize their own inference-time computation, offering a path toward adaptive, interpretable, and more efficient large-scale inference.
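The Figure 4 caption describes the blocked attention-span encoding as a list of block indices, each paired with a binary mask marking which tokens inside that block should be attended to. The sketch below is a minimal Python illustration of that idea under stated assumptions: a fixed block size (the hypothetical `BLOCK_SIZE = 4`) and a plain list of `(block index, bitmask string)` pairs. The actual block size and the token-level format the fine-tuned model emits are not given here, and `encode_span` / `decode_span` are hypothetical names.

```python
from typing import Dict, List, Tuple

# Hypothetical block size; the value used in the paper is not stated here.
BLOCK_SIZE = 4


def encode_span(positions: List[int], block_size: int = BLOCK_SIZE) -> List[Tuple[int, str]]:
    """Encode selected token positions as (block index, per-token bitmask) pairs.

    Only blocks that contain at least one selected token are emitted.
    """
    blocks: Dict[int, List[str]] = {}
    for pos in sorted(set(positions)):
        block, offset = divmod(pos, block_size)
        bits = blocks.setdefault(block, ["0"] * block_size)
        bits[offset] = "1"  # mark this token as attended to
    return [(block, "".join(bits)) for block, bits in sorted(blocks.items())]


def decode_span(encoded: List[Tuple[int, str]], block_size: int = BLOCK_SIZE) -> List[int]:
    """Recover the selected token positions from the blocked encoding."""
    positions: List[int] = []
    for block, bits in encoded:
        for offset, bit in enumerate(bits):
            if bit == "1":
                positions.append(block * block_size + offset)
    return positions


selected = [1, 2, 9, 10, 11]
encoded = encode_span(selected)
print(encoded)                           # [(0, '0110'), (2, '0111')]
print(decode_span(encoded) == selected)  # True
```

Untouched blocks are simply omitted, so a sparse span costs one index plus `BLOCK_SIZE` bits per touched block. One plausible motivation for a blocked layout is that tile-based attention kernels such as FlashAttention already process keys in blocks, so blocks absent from the encoding can be skipped wholesale; this is an interpretation, not a description of the authors' kernel.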

Abstract

Large language models (LLMs) can solve challenging tasks. However, their inference on modern GPUs is highly inefficient because the number of tokens they must attend to grows with every new token they generate. To address this inefficiency, we use LLMs' own problem-solving capabilities to optimize their inference-time efficiency. We demonstrate this on two tasks: (a) evaluating complex arithmetic expressions and (b) summarizing news articles. For both tasks, we create custom datasets to fine-tune an LLM. Fine-tuning has two goals: first, to teach the LLM to solve the evaluation or summarization task, and second, to train it to identify the minimal attention span required at each step of the task. The fine-tuned model then converts these self-identified minimal attention spans into sparse attention masks on the fly during inference. We develop a custom CUDA kernel that exploits the reduced context the model must attend to, and we show that this kernel improves LLM inference throughput by 28%. Our work is an end-to-end demonstration that training LLMs to self-select their attention spans speeds up autoregressive inference on real-world tasks.
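The key mechanism in the abstract is that a self-selected span becomes a sparse attention mask, which a custom kernel exploits by attending only to the reduced context. The pure-Python sketch below illustrates why that is safe for one decoding step of single-head scaled dot-product attention: masking unselected positions with `-inf` scores gives the same output as gathering only the selected keys and values, and the gathered form touches fewer cache entries. It is an illustration under those assumptions, not the authors' FlashAttention-based CUDA kernel; `span` is a hypothetical list of unique, non-empty selected token positions.

```python
import math
from typing import List, Sequence

Vector = List[float]


def softmax(scores: Sequence[float]) -> Vector:
    """Numerically stable softmax; -inf scores receive zero weight."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot_scores(q: Vector, keys: Sequence[Vector]) -> Vector:
    """Scaled dot-product scores q.k / sqrt(d) for each key."""
    scale = 1.0 / math.sqrt(len(q))
    return [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]


def weighted_sum(weights: Sequence[float], values: Sequence[Vector]) -> Vector:
    """Combine value vectors using attention weights."""
    dim = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]


def attend_masked(q: Vector, keys: Sequence[Vector], values: Sequence[Vector], span: List[int]) -> Vector:
    """Dense form: score every cached position, then mask those outside `span` with -inf."""
    keep = set(span)
    scores = [s if i in keep else float("-inf") for i, s in enumerate(dot_scores(q, keys))]
    return weighted_sum(softmax(scores), values)


def attend_reduced(q: Vector, keys: Sequence[Vector], values: Sequence[Vector], span: List[int]) -> Vector:
    """Reduced-context form: gather only the selected keys/values, then attend.

    Assumes `span` is non-empty and has no duplicate positions.
    """
    sub_keys = [keys[i] for i in span]
    sub_values = [values[i] for i in span]
    return weighted_sum(softmax(dot_scores(q, sub_keys)), sub_values)


# Toy key/value cache for 4 previous tokens; suppose the model selected positions 0 and 2.
keys = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5]]
values = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
q = [0.3, 0.7]
span = [0, 2]

dense = attend_masked(q, keys, values, span)
reduced = attend_reduced(q, keys, values, span)
print(all(math.isclose(a, b) for a, b in zip(dense, reduced)))  # True
```

Both functions return the same vector; the difference is cost. The reduced form reads only `len(span)` key/value rows instead of the whole cache, which is where a sparse-attention kernel can save memory traffic and arithmetic during decoding.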

Paper Structure

This paper contains 27 sections, 5 equations, 9 figures, 1 table, and 1 algorithm.

Figures (9)

  • Figure 1: The human thought process is inherently sparse, as shown by the minimal dependencies in the attention matrix below.
  • Figure 2: Runtime breakdown of decoder execution during inference with the LLaMA-7B model on an A100 GPU. At a sequence length of 2048, 50% attention sparsity yields a maximum runtime reduction of 29.7%, assuming speedup scales linearly with sparsity.
  • Figure 3: Illustration of autoregressive inference with a reduced attention span. In steps (b) and (d), the LLM attends to all tokens to select a subset of important tokens for next-token prediction; the selected tokens are highlighted in green. During generation steps (c) and (e), the LLM attends only to the selected subset of tokens.
  • Figure 4: Illustration of attention span encoding. Numbers on a yellow background are block indices. Numbers on a gray background form a binary mask indicating whether each token within the block should be attended to.
  • Figure 5: Example model output for complex arithmetic. Numbers on a blue background are anchors; numbers on a green background reference these anchors. [-1] denotes a reference to the previous line.
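The 29.7% ceiling in the Figure 2 caption can be read as an Amdahl-style bound. Assume that attention accounts for a fraction f of decoder runtime at a given sequence length, that attention time falls linearly to (1 - s) of its dense value at sparsity s (the caption's linearity assumption), and that all other decoder work is unchanged. The worked equation below follows from those assumptions; the value of f is inferred from the caption's numbers, not reported directly.

```latex
\frac{T_{\text{sparse}}}{T_{\text{dense}}} = (1 - f) + f\,(1 - s) = 1 - f\,s,
\qquad
\Delta T = 1 - \frac{T_{\text{sparse}}}{T_{\text{dense}}} = f\,s

s = 0.5,\quad \Delta T = 0.297
\;\Longrightarrow\;
f = \frac{0.297}{0.5} = 0.594
```

Under these assumptions, attention would take roughly 59% of LLaMA-7B decoder runtime at 2048 tokens on the A100, and even with perfectly linear scaling, 70.3% of the dense runtime remains at 50% sparsity.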