LeetDecoding: A PyTorch Library for Exponentially Decaying Causal Linear Attention with CUDA Implementations

Jiaping Wang, Simiao Zhang, Qiao-Chu He, Yifan Chen

TL;DR

The paper addresses the quadratic bottleneck of decoder-style attention by focusing on exponentially decaying causal linear attention. It introduces LeetDecoding, a PyTorch/CUDA toolkit that unifies and implements multiple linear-attention algorithms, including a new FleetAttention method that achieves intrinsic linear complexity. The work provides formal complexity analyses (noting the suboptimality of recursive approaches) and advances GPU-friendly implementations (CUDA/Triton) to enable fast inference and rigorous benchmarking on long prompts. Practically, the library enables researchers and practitioners to compare methods, integrate with existing linear transformers, and deploy efficient attention for long-context LLMs, advancing scalable decoding in real-world applications.

Abstract

The machine learning and data science community has made significant but scattered progress in accelerating transformer-based large language models (LLMs). One promising approach is to replace the original causal attention in a generative pre-trained transformer (GPT) with exponentially decaying causal linear attention. In this paper, we present LeetDecoding, the first Python package that provides a large set of computation routines for this fundamental operator. LeetDecoding was motivated by the current lack of (1) a clear understanding of the complexity of this operator, (2) a comprehensive collection of existing computation methods (usually scattered across seemingly unrelated fields), and (3) CUDA implementations for fast inference on GPUs. LeetDecoding is designed to integrate easily with existing linear-attention LLMs, and it allows researchers to benchmark and evaluate new computation methods for exponentially decaying causal linear attention. Using LeetDecoding requires no knowledge of GPU programming or of the underlying complexity analysis, which intentionally makes it accessible to LLM practitioners. The source code is available at https://github.com/Computational-Machine-Intelligence/LeetDecoding, and users can install LeetDecoding with the command pip install leet-decoding.
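
To make the operator concrete, the following is a minimal pure-Python sketch. It is not LeetDecoding's API. It assumes the common formulation in which output row i is the sum over j <= i of gamma^(i-j) (q_i . k_j) v_j, for a decay factor gamma in (0, 1]. The function names and the list-of-lists layout are hypothetical and chosen only for illustration. The first function is a quadratic-time reference; the second carries a running r x d state and is linear in the sequence length.

```python
def decayed_attention_quadratic(Q, K, V, gamma):
    """O(N^2) reference: out[i] = sum_{j<=i} gamma^(i-j) * (Q[i] . K[j]) * V[j]."""
    n, d = len(Q), len(V[0])
    out = [[0.0] * d for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):  # causal: only positions j <= i contribute
            score = gamma ** (i - j) * sum(q * k for q, k in zip(Q[i], K[j]))
            for c in range(d):
                out[i][c] += score * V[j][c]
    return out


def decayed_attention_recurrent(Q, K, V, gamma):
    """Linear in N: keep a state S = gamma * S + K[i]^T V[i], then out[i] = Q[i] S."""
    n, r, d = len(Q), len(K[0]), len(V[0])
    state = [[0.0] * d for _ in range(r)]  # r x d summary of all tokens seen so far
    out = []
    for i in range(n):
        for a in range(r):
            for c in range(d):
                # Decay the old summary, then add the outer product of K[i] and V[i].
                state[a][c] = gamma * state[a][c] + K[i][a] * V[i][c]
        out.append([sum(Q[i][a] * state[a][c] for a in range(r)) for c in range(d)])
    return out
```

Both functions take Q and K as N x r nested lists and V as an N x d nested list. The recurrent form never materializes the N x N score matrix, which is what makes this operator attractive for long prompts.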

Paper Structure (35 sections, 1 theorem, 22 equations, 2 figures, 9 tables, 2 algorithms)

This paper contains 35 sections, 1 theorem, 22 equations, 2 figures, 9 tables, 2 algorithms.

Key Result

Lemma 3.1

Consider the matrices $\bm{B}$, $\bm{C} \in \mathbb{R}^{N \times r}$, $\bm{V} \in \mathbb{R}^{N \times d}$, $\bm{\tilde{A}} = \bm{B} \bm{C}^T$, and the causal mask matrix $\bm{M}$. The time complexity of using the recursive computation method to compute $(\tilde{\bm{A}}\odot \bm{M}) \bm{V}$, even eq

Figures (2)

  • Figure 1: Visualization of the Recursion algorithm. The causal attention matrix can be divided into three equal-sized non-zero sections: $\bm{B}_{(1)}\bm{C}_{(1)}\odot \bm{M}$ and $\bm{B}_{(2)}\bm{C}_{(2)} \odot \bm{M}$ are both masked attention score matrices, and $\bm{B}_{(2)} \bm{C}_{(1)}$ is an unmasked attention score matrix.
  • Figure 2: Matrix splitting in FleetAttention
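
The block splitting in the Figure 1 caption can be sketched as a divide-and-conquer routine. This is an illustrative pure-Python sketch, not the paper's or the library's implementation. It assumes a plain lower-triangular mask of ones (diagonal included) and omits the exponential decay for brevity. The helper names are hypothetical. The two diagonal blocks are handled by recursion, and the unmasked off-diagonal block is applied as B_(2) (C_(1)^T V_(1)), so no score matrix is ever formed.

```python
def matmul(X, Y):
    """Plain matrix product of two nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def transpose(X):
    return [list(col) for col in zip(*X)]


def masked_product_recursive(B, C, V):
    """Compute ((B C^T) masked by a lower-triangular ones mask) V by halving the sequence."""
    n = len(B)
    if n == 1:
        score = sum(b * c for b, c in zip(B[0], C[0]))
        return [[score * v for v in V[0]]]
    mid = n // 2
    # Diagonal blocks: both are masked, so recurse on each half.
    top = masked_product_recursive(B[:mid], C[:mid], V[:mid])
    bottom = masked_product_recursive(B[mid:], C[mid:], V[mid:])
    # Off-diagonal block: unmasked, so use associativity B2 (C1^T V1).
    summary = matmul(transpose(C[:mid]), V[:mid])  # r x d
    cross = matmul(B[mid:], summary)               # (n - mid) x d
    bottom = [[x + y for x, y in zip(rb, rc)] for rb, rc in zip(bottom, cross)]
    return top + bottom


def masked_product_naive(B, C, V):
    """Quadratic reference used only to check the recursive version."""
    n, d = len(B), len(V[0])
    out = [[0] * d for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            score = sum(b * c for b, c in zip(B[i], C[j]))
            for k in range(d):
                out[i][k] += score * V[j][k]
    return out
```

In this sketch, each recursion level spends work proportional to N r d on the off-diagonal blocks, and halving the sequence gives roughly log2(N) levels. That extra logarithmic factor is one way to see why a recursive scheme can fall short of a method with intrinsically linear complexity.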

Theorems & Definitions (2)

  • Lemma 3.1
  • proof