
Long Sequence Modeling with Attention Tensorization: From Sequence to Tensor Learning

Aosong Feng, Rex Ying, Leandros Tassiulas

TL;DR

Long-context processing in LLMs is challenged by the quadratic cost of attention and by fixed training context lengths. The authors introduce Tensorized Attention, which folds a sequence into a higher-order tensor and performs attention along each dimension, enabling exponential length extrapolation and efficient computation with a Triton kernel. They formalize a tensor-space attention framework, analyze its low-rank properties via CP decomposition, and demonstrate substantial empirical gains, including up to an $11\times$ speedup over full attention and improved perplexity on long-context benchmarks, by continuing the pretraining of OpenLlama-3B, Mistral-7B, and Llama-8B at context lengths of $32{,}768$ and beyond, up to 128k. The approach performs strongly on downstream tasks that benefit from longer context and provides a practical path to scalable long-sequence modeling for pretrained LLMs and beyond. Overall, tensorized attention offers a principled, efficient alternative to full attention for long sequences, with potential applicability to time-series and other sequence-modeling domains.
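To make the folding step concrete, here is a minimal NumPy sketch of attention along each tensor dimension, following the description in Figure 2: queries, keys, and values are reshaped into a higher-order tensor, each dimension gets one attention matrix computed with all other axes flattened into the features, and that matrix sequentially updates the value tensor. It is an illustration under stated assumptions, not the authors' Triton kernel. The helper names `dim_attention` and `tensorized_attention` are hypothetical, the sketch omits causal masking, multiple heads, positional encoding, and learned projections, and scaling scores by the square root of the flattened width is our own assumption.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def dim_attention(q, k, v, axis):
    """Attend along one tensor axis (hypothetical helper).

    q, k, v have shape (*shape, d). Every axis except `axis` (features
    included) is flattened, so one (n_axis x n_axis) attention matrix is
    shared by all positions along the other axes.
    """
    n = q.shape[axis]
    qf = np.moveaxis(q, axis, 0).reshape(n, -1)
    kf = np.moveaxis(k, axis, 0).reshape(n, -1)
    scores = qf @ kf.T / np.sqrt(qf.shape[1])   # scaling is an assumption
    a = softmax(scores, axis=-1)                # (n, n), rows sum to 1
    vf = np.moveaxis(v, axis, 0)                # (n, ..., d)
    out = np.tensordot(a, vf, axes=([1], [0]))  # mix values along `axis` only
    return np.moveaxis(out, 0, axis), a


def tensorized_attention(q, k, v, shape):
    """Fold (n, d) inputs into (*shape, d) and attend one axis at a time."""
    n, d = q.shape
    assert int(np.prod(shape)) == n, "shape must factor the sequence length"
    Q, K, V = (x.reshape(*shape, d) for x in (q, k, v))
    for axis in range(len(shape)):
        V, _ = dim_attention(Q, K, V, axis)     # Q, K fixed; V updated in turn
    return V.reshape(n, d)


rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((64, 8)) for _ in range(3))
out = tensorized_attention(q, k, v, shape=(4, 4, 4))  # 64 tokens as a 4x4x4 tensor
print(out.shape)  # (64, 8)
```

In this sketch the work for dimension $j$ grows roughly like $n_j \cdot n \cdot d$ rather than the $n^2 d$ of full attention, and with `shape=(n,)` it reduces to ordinary full attention.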

Abstract

As the demand for processing extended textual data grows, the ability to handle long-range dependencies while maintaining computational efficiency is more critical than ever. One of the key issues for long-sequence modeling with attention-based models is the mismatch between the limited-range modeling power of full attention and the long-range token dependencies in the input sequence. In this work, we propose to scale up the attention receptive field by tensorizing long input sequences into compact tensor representations and then applying attention along each transformed dimension. The resulting Tensorized Attention can serve as an efficient transformer backbone that extends the input context length with improved memory and time efficiency. We show that attention tensorization encodes token dependencies as a multi-hop attention process and is equivalent to a Kronecker decomposition of full attention. Extensive experiments show that tensorized attention can be used to adapt pretrained LLMs with improved efficiency. Notably, Llama-8B with tensorization is trained with a 32,768 context length and can steadily extrapolate to 128k length during inference with an $11\times$ speedup compared to full attention with FlashAttention-2.
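The Kronecker claim can be checked numerically in a small case. Assuming, as Figure 2 describes, that each dimension contributes a single attention matrix shared across the flattened remaining axes, attending along dimension 0 with $\mathbf{A}_1$ and then along dimension 1 with $\mathbf{A}_2$ is the same as applying $\mathbf{A}_1 \otimes \mathbf{A}_2$ to the row-major flattened sequence. The random row-stochastic matrices below are stand-ins for $\mathrm{softmax}(\mathbf{Q}\mathbf{K}^\top)$, not attention patterns from the paper.

```python
import numpy as np


def row_softmax(x):
    # Softmax over each row, so every row sums to 1.
    x = x - x.max(axis=1, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=1, keepdims=True)


rng = np.random.default_rng(1)
n1, n2, d = 4, 5, 3
# Stand-in per-dimension attention matrices (random scores, not from the paper).
A1 = row_softmax(rng.standard_normal((n1, n1)))
A2 = row_softmax(rng.standard_normal((n2, n2)))
V = rng.standard_normal((n1, n2, d))  # n = n1 * n2 tokens folded into a grid

# Two hops: mix along dimension 0, then along dimension 1.
hop1 = np.einsum("ip,pqd->iqd", A1, V)
hop2 = np.einsum("jq,iqd->ijd", A2, hop1)

# One equivalent attention matrix over the row-major flattened sequence.
A_full = np.kron(A1, A2)  # shape (n1 * n2, n1 * n2)
direct = A_full @ V.reshape(n1 * n2, d)

print(np.allclose(hop2.reshape(n1 * n2, d), direct))  # True
print(np.allclose(A_full.sum(axis=1), 1.0))           # True: still row-stochastic
print(bool((A_full > 0).all()))                       # True: every token reaches every token
```

Every entry of $\mathbf{A}_1 \otimes \mathbf{A}_2$ here is a product of two positive softmax weights, so each token can draw on every other token after two hops even though each hop only mixes tokens along one axis. This is one way to read the multi-hop interpretation above.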

Paper Structure

This paper contains 26 sections, 2 theorems, 9 equations, 13 figures, 6 tables, 1 algorithm.

Key Result

Theorem 3.1

For attention matrix $\mathbf{A}\in\mathbb{R}^{n\times n}$ and any column $\mathbf{y}\in\mathbb{R}^{n}$ of value vector along the feature dimension, there exists a low-rank matrix $\tilde{\mathbf{A}}\in\mathbb{R}^{n\times n}$ with rank $\mathcal{O}(3^{m}\log_{2m} n)$ defined in tensor-$m$ space

Figures (13)

  • Figure 1: (a) The interaction distance from A to B is decreased from 12 in the sequence format to 5 in the tensor format and fits into context length. (b) Token interactions along each dimension are equivalent to multi-scale interaction in the original sequence.
  • Figure 2: (a) Input sequences $\mathbf{q}, \mathbf{k}, \mathbf{v}$ are first tensorized into $\bm{\mathcal{Q}}, \bm{\mathcal{K}}, \bm{\mathcal{V}}$. Each row in the middle represents the attention along one matching dimension of tensors, and all dimensions except the matching dimension of $\bm{\mathcal{Q}}$ and $\bm{\mathcal{K}}$ are flattened. The result from each row is used to sequentially update the value tensor $\bm{\mathcal{V}}$. (b) Different types of attention processes can be visualized using a tensor diagram, where each circle represents data content and each edge represents a dimension.
  • Figure 3: Comparison of token position in 1-D sequence and 3-D tensor under polar coordinates.
  • Figure 4: (a) The attention matrix or mask can be decomposed into a set of columns or 2-D blocks, which can be used to span a vector or tensor space. (b) Reconstructing the example pattern needs $R_t=1$ singular vector in tensor space versus $R_v=12$ in vector space, showing the advantage of tensor space for diagonal and structured patterns. (c, d) CP decomposition of real attention patterns (layer 3, head 2) from (c) the dataset average or (d) a single sample. (e, f) Computed spectra as sorted normalized singular values $\tilde{\lambda}_i$. (g, h) Total percentage of information captured vs. number of parameters needed for the approximation.
  • Figure 5: (a) Perplexity on the Proof-pile test set after continued pretraining. Comparison of GPU memory usage (b) and running time (c) for full and tensorized attention. Time and memory are averaged over 50 forward passes with batch size 5.
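The gap between tensor-space and vector-space rank in Figure 4(b) can be illustrated with a diagonal pattern. The sketch assumes the example is a $12 \times 12$ diagonal (the caption does not say so) and measures tensor-space rank with the Van Loan-Pitsianis rearrangement, which counts Kronecker terms. The paper itself uses CP decomposition, so this is a stand-in rather than its procedure, and `kron_rearrange` is a hypothetical helper. Viewing 12 positions as a $3 \times 4$ grid, the identity equals $\mathbf{I}_3 \otimes \mathbf{I}_4$, a single Kronecker term, while its ordinary matrix rank is 12.

```python
import numpy as np


def kron_rearrange(A, n1, n2):
    """Van Loan-Pitsianis rearrangement of an (n1*n2) x (n1*n2) matrix.

    Entry A[i1*n2 + i2, j1*n2 + j2] moves to row (i1, j1), column (i2, j2).
    If A = B kron C, the result is vec(B) vec(C)^T, so its rank counts the
    Kronecker terms needed to write A exactly.
    """
    T = A.reshape(n1, n2, n1, n2)  # axes: i1, i2, j1, j2
    return T.transpose(0, 2, 1, 3).reshape(n1 * n1, n2 * n2)


n1, n2 = 3, 4
A = np.eye(n1 * n2)  # assumed example: a 12 x 12 diagonal pattern

vector_rank = np.linalg.matrix_rank(A)                          # columns as vectors
tensor_rank = np.linalg.matrix_rank(kron_rearrange(A, n1, n2))  # Kronecker terms
print(vector_rank, tensor_rank)  # 12 1
```

By contrast, a generic dense $12 \times 12$ matrix would need up to $\min(n_1^2, n_2^2) = 9$ Kronecker terms under this grid, so the saving comes from the structure of the pattern rather than from the folding alone.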

Theorems & Definitions (2)

  • Theorem 3.1
  • Lemma B.1