Tensor Product Attention Is All You Need

Yifan Zhang, Yifeng Liu, Huizhuo Yuan, Zhen Qin, Yang Yuan, Quanquan Gu, Andrew Chi-Chih Yao

TL;DR

The paper tackles the memory bottleneck of key-value caches in long-context language models by proposing Tensor Product Attention (TPA), which factorizes queries, keys, and values into contextual, low-rank components to dramatically shrink KV caches at inference. TPA is instantiated in the Tensor ProducT ATTenTion Transformer (T6) and complemented by FlashTPA decoding for efficient autoregressive inference, with RoPE compatibility preserved through pre-rotated factor representations. The authors show MHA, MQA, and GQA are special cases of TPA, unify existing attention variants under a single framework, and demonstrate through large-scale pretraining and downstream benchmarks that TPA either matches or improves perplexity and task performance while reducing memory. This approach enables longer context windows under fixed hardware, offering practical scalability for modern LLMs and motivating efficient implementations like FlashTPA in real-world systems.

Abstract

Scaling language models to handle longer input sequences typically necessitates large key-value (KV) caches, resulting in substantial memory overhead during inference. In this paper, we propose Tensor Product Attention (TPA), a novel attention mechanism that uses tensor decompositions to represent queries, keys, and values compactly, substantially shrinking the KV cache size at inference time. By factorizing these representations into contextual low-rank components and seamlessly integrating with Rotary Position Embedding (RoPE), TPA achieves improved model quality alongside memory efficiency. Based on TPA, we introduce the Tensor ProducT ATTenTion Transformer (T6), a new model architecture for sequence modeling. Through extensive empirical evaluation on language modeling tasks, we demonstrate that T6 surpasses or matches the performance of standard Transformer baselines including Multi-Head Attention (MHA), Multi-Query Attention (MQA), Grouped-Query Attention (GQA), and Multi-Head Latent Attention (MLA) across various metrics, including perplexity and a range of established evaluation benchmarks. Notably, TPA's memory efficiency and computational efficiency at the decoding stage enable processing longer sequences under fixed resource constraints, addressing a critical scalability challenge in modern language models. Project Page: https://github.com/tensorgi/TPA.
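
To make the memory claim concrete, the following is a rough back-of-the-envelope sketch, not a figure from the paper. It assumes that a standard cache stores full keys and values ($2 \cdot h \cdot d_h$ numbers per token per layer), while a TPA-style cache stores only the key and value factors, whose shapes are $R \times h$ and $R \times d_h$. The function name and the ranks `r_k = r_v = 2` are hypothetical, chosen only for illustration; `h = 32` with `d_h = 64` corresponds to an embedding dimension of 2048.

```python
def kv_cache_numbers_per_token(h, d_h, r_k=None, r_v=None):
    """Count cached numbers per token, per layer (hypothetical helper).

    Without ranks: full keys and values, 2 * h * d_h.
    With ranks: factor matrices A_K (r_k x h), B_K (r_k x d_h),
    A_V (r_v x h), B_V (r_v x d_h).
    """
    if r_k is None or r_v is None:
        return 2 * h * d_h
    return (r_k + r_v) * (h + d_h)


h, d_h = 32, 64  # 32 heads of width 64, i.e. embedding dimension 2048
full = kv_cache_numbers_per_token(h, d_h)
factored = kv_cache_numbers_per_token(h, d_h, r_k=2, r_v=2)  # assumed ranks
print(full, factored, full / factored)
```

Under these assumptions the factored cache grows with $h + d_h$ rather than $h \cdot d_h$, which is where the savings come from; the actual ratio depends on the ranks used.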

Paper Structure

This paper contains 54 sections, 2 theorems, 95 equations, 11 figures, 18 tables, 3 algorithms.

Key Result

Theorem 3.1

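Let $\mathbf{Q}_t$ be factorized by TPA as $\mathbf{Q}_t = \frac{1}{R_Q} \mathbf{A}_{Q}(\mathbf{x}_t)^\top \mathbf{B}_{Q}(\mathbf{x}_t) \in \mathbb{R}^{h \times d_h}$, where $\mathbf{A}_{Q}(\mathbf{x}_t) \in \mathbb{R}^{R_Q \times h}$ and $\mathbf{B}_{Q}(\mathbf{x}_t) \in \mathbb{R}^{R_Q \times d_h}$. Then we have: $\operatorname{RoPE}(\mathbf{Q}_t) = \frac{1}{R_Q} \mathbf{A}_{Q}(\mathbf{x}_t)^\top \widetilde{\mathbf{B}}_{Q}(\mathbf{x}_t)$, where $\widetilde{\mathbf{B}}_{Q}(\mathbf{x}_t) = \operatorname{RoPE}_t\bigl(\mathbf{B}_{Q}(\mathbf{x}_t)\bigr)$ (RoPE applied row-wise to $\mathbf{B}_Q(\mathbf{x}_t)$). Furthermore, let $\mathbf{Q}_t$ and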

Figures (11)

  • Figure 1: Tensor Product Attention (TPA) within the Tensor ProducT ATTenTion Transformer (T6). In each TPA layer, the input hidden state $\mathbf{x}_t$ is processed by linear layers to produce latent factor matrices for query (e.g., $\mathbf{A}_Q(\mathbf{x}_t), \mathbf{B}_Q(\mathbf{x}_t)$), key (e.g., $\mathbf{A}_K(\mathbf{x}_t), \mathbf{B}_K(\mathbf{x}_t)$), and value (e.g., $\mathbf{A}_V(\mathbf{x}_t), \mathbf{B}_V(\mathbf{x}_t)$). Rotary Position Embedding (RoPE) is applied to the $\mathbf{B}_Q(\mathbf{x}_t)$ and $\mathbf{B}_K(\mathbf{x}_t)$ factors. The query, key, and value tensors for each attention head are then formed by the tensor product of these factor matrices (e.g., $\mathbf{Q}_t = \frac{1}{R_Q} \mathbf{A}_Q(\mathbf{x}_t)^\top \mathbf{B}_Q(\mathbf{x}_t)$). Finally, the TPA output is computed using scaled dot-product attention, followed by a linear projection of the concatenated results from all heads.
  • Figure 2: Data flow diagram for FlashTPA Decoding. Rectangles represent tensors (blue for inputs, yellow for intermediates, red for final output), circles with $\sum$ or $\odot$ denote Einstein summation contractions or element-wise products respectively, and the green rounded rectangle is the softmax operation. Shapes are shown for a single query ($N=1$) interacting with $M$ cached items. $H$ is the number of heads, $R_Q$ is the query rank, and $D, E$ are respective feature dimensions for the $\mathbf{B}_Q/\bm{b}^K_{\text{cache}}$ and $\bm{b}^V_{\text{cache}}$ factors. Scaling factors in softmax are omitted for visual clarity.
  • Figure 3: The training loss of medium-size (353M), large-size (773M) as well as XL-size (1.5B) models, with different attention mechanisms on the FineWeb-Edu 100B dataset.
  • Figure 4: The validation loss of medium-size (353M), large-size (773M) as well as XL-size (1.5B) models, with different attention mechanisms on the FineWeb-Edu 100B dataset.
  • Figure 5: Decoding time comparison of different attention mechanisms with an embedding dimension of 2048 and $d_h=64$. The y-axis represents $\log_2(\text{time})$ in seconds, and the x-axis represents $\log_2(\text{sequence length})$. Each subfigure corresponds to a different batch size.
  • ...and 6 more figures
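
The factorization in Figure 1 and the factor-level contractions in Figure 2 can be illustrated with a small NumPy sketch. This is not the authors' implementation or the FlashTPA kernel: the weight matrices, sizes, ranks, and helper names are hypothetical, the $1/\sqrt{d_h}$ softmax scaling is omitted (as in Figure 2), and RoPE is left out. It only shows that attention logits computed directly from the factors agree with logits computed from materialized queries and keys.

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, h, d_h = 128, 4, 16   # hypothetical sizes
R_Q, R_K = 3, 2                # hypothetical ranks
M = 5                          # number of cached tokens

# Hypothetical linear layers mapping a hidden state to the factor matrices.
scale = 1.0 / np.sqrt(d_model)
W_aq = rng.normal(size=(d_model, R_Q * h)) * scale
W_bq = rng.normal(size=(d_model, R_Q * d_h)) * scale
W_ak = rng.normal(size=(d_model, R_K * h)) * scale
W_bk = rng.normal(size=(d_model, R_K * d_h)) * scale

# Query factors for the current token: A_Q is (R_Q, h), B_Q is (R_Q, d_h).
x_t = rng.normal(size=d_model)
A_q = (x_t @ W_aq).reshape(R_Q, h)
B_q = (x_t @ W_bq).reshape(R_Q, d_h)

# Q_t = (1 / R_Q) * A_Q^T B_Q, one row per head: shape (h, d_h).
Q_t = A_q.T @ B_q / R_Q

# Cached key factors for M earlier tokens (this is all the cache holds).
X = rng.normal(size=(M, d_model))
A_k = (X @ W_ak).reshape(M, R_K, h)
B_k = (X @ W_bk).reshape(M, R_K, d_h)

# Reference path: materialize every key, then take per-head dot products.
K = np.einsum("msh,msd->mhd", A_k, B_k) / R_K        # (M, h, d_h)
logits_ref = np.einsum("hd,mhd->hm", Q_t, K)         # (h, M)

# Factored path: contract the B factors first, then weight by the A factors.
bb = np.einsum("rd,msd->mrs", B_q, B_k)              # (M, R_Q, R_K)
logits_fac = np.einsum("rh,msh,mrs->hm", A_q, A_k, bb) / (R_Q * R_K)

print(np.allclose(logits_ref, logits_fac))
```

The factored path never builds the $(M, h, d_h)$ key tensor, which is the general idea behind decoding from cached factors; a real kernel would also handle values, softmax scaling, and batching.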

Theorems & Definitions (2)

  • Theorem 3.1: RoPE's Compatibility with TPA
  • Theorem C.1: RoPE Compatibility in Higher-Order TPA
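
The first part of Theorem 3.1 can be checked numerically. The sketch below assumes the common interleaved-pair form of RoPE with base 10000 (the `rope` helper, sizes, and position are hypothetical choices, not taken from the paper) and uses the factorization $\mathbf{Q}_t = \frac{1}{R_Q} \mathbf{A}_Q(\mathbf{x}_t)^\top \mathbf{B}_Q(\mathbf{x}_t)$ from Figure 1. Because RoPE acts linearly on each row, rotating the rows of $\mathbf{B}_Q$ before forming the product should give the same result as rotating $\mathbf{Q}_t$ afterwards.

```python
import numpy as np


def rope(X, t, base=10000.0):
    """Apply RoPE at position t to each row of X (last dim must be even).

    Assumes the interleaved convention: dimensions (2i, 2i+1) form a pair
    that is rotated by the angle t * base ** (-2i / d).
    """
    d = X.shape[-1]
    theta = base ** (-np.arange(0, d, 2) / d)
    cos, sin = np.cos(t * theta), np.sin(t * theta)
    x1, x2 = X[..., 0::2], X[..., 1::2]
    out = np.empty_like(X)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out


rng = np.random.default_rng(0)
h, d_h, R_Q, t = 4, 8, 3, 7          # hypothetical sizes and position
A = rng.normal(size=(R_Q, h))        # stands in for A_Q(x_t)
B = rng.normal(size=(R_Q, d_h))      # stands in for B_Q(x_t)

Q = A.T @ B / R_Q                    # materialized query, shape (h, d_h)
rotate_after = rope(Q, t)            # RoPE applied to the full query
rotate_before = A.T @ rope(B, t) / R_Q   # RoPE applied to the B factor only

print(np.allclose(rotate_after, rotate_before))
```

This is why pre-rotated factors can be stored and reused: the rotation only ever needs to touch the $R_Q \times d_h$ factor, not the full $h \times d_h$ query.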