
Training Tensor Attention Efficiently: From Cubic to Almost Linear Time

Yang Cao, Yingyu Liang, Zhenmei Shi, Zhao Song

TL;DR

This work addresses the cubic-time barrier in training Tensor Attention by deriving a closed-form gradient and proving that the backward gradient can be computed in almost linear time $n^{1+o(1)}$ under a bounded-entry assumption. It introduces a fast-gradient algorithm based on polynomial approximation and low-rank tensor techniques, matching the forward pass’s efficiency and enabling scalable training of higher-order transformers. A formal hardness analysis under the Strong Exponential Time Hypothesis (SETH) shows the bounded-entry assumption is tight in the sense that slight relaxation would preclude truly subcubic algorithms for both forward and backward computations. Collectively, the results establish the feasibility of efficient higher-order transformer training and broaden the practical potential of tensor-attention architectures for multi-modal modeling and other domains.

Abstract

Tensor Attention, a multi-view attention that is able to capture high-order correlations among multiple modalities, can overcome the representational limitations of classical matrix attention. However, the $O(n^3)$ time complexity of tensor attention poses a significant obstacle to its utilization in transformers, where $n$ is the input sequence length. In this work, we prove that the backward gradient of tensor attention training can be computed in almost linear time $n^{1+o(1)}$, the same complexity as its forward computation under the bounded entries assumption. We provide a closed-form solution for the gradient and propose a fast computation method utilizing polynomial approximation methods and tensor algebraic techniques. Furthermore, we prove the necessity and tightness of our assumption through hardness analysis, showing that slightly weakening it renders the gradient problem unsolvable in truly subcubic time. Our theoretical results establish the feasibility of efficient higher-order transformer training and may facilitate practical applications of tensor attention architectures.
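
The two ingredients named in the abstract, polynomial approximation and tensor algebra, can be illustrated in a few lines. This is a minimal sketch under stated assumptions, not the authors' algorithm. It assumes a Softmax tensor attention of the form $D^{-1}\exp(Q(K_1 \oslash K_2)^\top)(V_1 \oslash V_2)$ with no $1/d$ scaling. It uses a truncated Taylor series of $\exp$ as a simple stand-in for the paper's polynomial approximation, which may differ. The names `taylor_features`, `exact_tensor_attention` and `approx_tensor_attention` are hypothetical. Bounded entries keep the Taylor error small. The factorization $\exp(\cdot) \approx U_1 (U_2 \oslash U_3)^\top$, combined with the identity $(U_2 \oslash U_3)^\top (V_1 \oslash V_2) = (U_2^\top V_1) \circ (U_3^\top V_2)$, lets the approximate version avoid forming any $n \times n^2$ matrix.

```python
import math
import numpy as np

def face_split(A, B):
    """Row-wise Kronecker product: row i is kron(A[i], B[i])."""
    return (A[:, :, None] * B[:, None, :]).reshape(A.shape[0], -1)

def khatri_rao(A, B):
    """Column-wise Kronecker product: column l is kron(A[:, l], B[:, l])."""
    return (A[:, None, :] * B[None, :, :]).reshape(-1, A.shape[1])

def taylor_features(M, degree):
    """Hypothetical helper: rows phi(m) = [m^{(x)j} / (j!)^(1/3) for j = 0..degree].

    For rows q, k1, k2, sum_r phi(q)_r * phi(k1)_r * phi(k2)_r equals the
    degree-`degree` Taylor polynomial of exp(sum_l q_l * k1_l * k2_l).
    """
    power = np.ones((M.shape[0], 1))               # m^{(x)0}
    blocks = []
    for j in range(degree + 1):
        blocks.append(power / math.factorial(j) ** (1.0 / 3.0))
        power = face_split(power, M)               # m^{(x)(j+1)}
    return np.concatenate(blocks, axis=1)

def exact_tensor_attention(Q, K1, K2, V1, V2):
    """Materializes the n x n^2 matrix exp(Q (K1 col-kron K2)^T): O(n^3 d) time."""
    n = Q.shape[0]
    A = np.exp(np.einsum("il,al,bl->iab", Q, K1, K2)).reshape(n, n * n)
    return (A @ khatri_rao(V1, V2)) / A.sum(axis=1, keepdims=True)

def approx_tensor_attention(Q, K1, K2, V1, V2, degree=5):
    """Uses exp(...) ~ U1 (U2 col-kron U3)^T and never forms an n^2-sized array."""
    U1, U2, U3 = (taylor_features(M, degree) for M in (Q, K1, K2))
    ones = np.ones((Q.shape[0], 1))
    # (U2 col-kron U3)^T (V1 col-kron V2) = (U2^T V1) * (U3^T V2), entrywise
    numer = U1 @ ((U2.T @ V1) * (U3.T @ V2))       # approximates A V
    denom = U1 @ ((U2.T @ ones) * (U3.T @ ones))   # approximates A 1_{n^2}
    return numer / denom

rng = np.random.default_rng(0)
n, d = 40, 3
# Bounded entries keep |<q, k1 * k2>| <= 3 * 0.5**3, where a low-degree Taylor series is accurate
Q, K1, K2, V1, V2 = (rng.uniform(-0.5, 0.5, size=(n, d)) for _ in range(5))
exact = exact_tensor_attention(Q, K1, K2, V1, V2)
approx = approx_tensor_attention(Q, K1, K2, V1, V2)
max_err = np.abs(exact - approx).max()
print(f"max abs error: {max_err:.2e}")
```

At these toy sizes the feature rank $\sum_{j \le 5} 3^j = 364$ exceeds $n$. The example therefore demonstrates that the factorization is correct, not that it is faster.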

Paper Structure (51 sections, 24 theorems, 114 equations, 4 figures, 1 table, 1 algorithm)

Key Result

Lemma 3.1

Define the function $\mathsf{F}(x) \in \mathbb{R}^{n \times n^2}$ as in Definition \ref{def:p} (see Fig. \ref{fig:tat_backward} for an illustration). Suppose that $A_1, A_2, A_3 \in \mathbb{R}^{n \times d}$ are three given matrices. Suppose that $\mathsf{Loss}(x)$ is defined as in Definition def:attention_optimiza…
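
The statement of Lemma 3.1 is cut off above, so it is not reproduced here. As a sanity aid, the following sketch derives the gradient of the loss illustrated in Figure 3 by the chain rule and checks it against central finite differences. The derivation is our own and does not follow the notation of Lemma 3.1. It assumes the loss is $0.5$ times the squared Frobenius norm and applies no $1/d$ scaling inside $\exp$. The helper names are hypothetical. Because it materializes $n \times n^2$ matrices, this naive version takes time at least proportional to $n^3$, which is the bottleneck the almost linear time algorithm is designed to remove.

```python
import numpy as np

def tensor_attention_loss(X, A1, A2, A3, A4, A5, Y, E):
    """0.5 * ||D(X)^{-1} exp(A1 X (A2 kron A3)^T) (A4 kron A5) Y - E||_F^2 (squared norm assumed)."""
    A = np.exp(A1 @ X @ np.kron(A2, A3).T)          # n x n^2
    F = A / A.sum(axis=1, keepdims=True)            # D(X)^{-1} A, a row-wise softmax
    C = F @ (np.kron(A4, A5) @ Y) - E               # n x d residual
    return 0.5 * np.sum(C ** 2)

def tensor_attention_grad(X, A1, A2, A3, A4, A5, Y, E):
    """Chain-rule gradient of the loss above with respect to X (shape d x d^2)."""
    K = np.kron(A2, A3)                             # n^2 x d^2
    H = np.kron(A4, A5) @ Y                         # n^2 x d
    A = np.exp(A1 @ X @ K.T)                        # n x n^2, the cubic-size object
    F = A / A.sum(axis=1, keepdims=True)
    C = F @ H - E                                   # gradient w.r.t. the output F H
    G = C @ H.T                                     # gradient w.r.t. F, n x n^2
    FG = F * G
    P = FG - FG.sum(axis=1, keepdims=True) * F      # softmax Jacobian applied row by row
    return A1.T @ P @ K                             # since the scores are A1 X K^T

rng = np.random.default_rng(1)
n, d = 4, 2
A1, A2, A3, A4, A5, E = (0.5 * rng.standard_normal((n, d)) for _ in range(6))
Y = 0.5 * rng.standard_normal((d * d, d))
X = 0.5 * rng.standard_normal((d, d * d))
args = (A1, A2, A3, A4, A5, Y, E)

G = tensor_attention_grad(X, *args)
G_fd = np.zeros_like(X)
eps = 1e-6
for idx in np.ndindex(*X.shape):                    # central finite differences
    Xp, Xm = X.copy(), X.copy()
    Xp[idx] += eps
    Xm[idx] -= eps
    G_fd[idx] = (tensor_attention_loss(Xp, *args) - tensor_attention_loss(Xm, *args)) / (2 * eps)
print("max |analytic - finite difference|:", np.abs(G - G_fd).max())
```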

Figures (4)

  • Figure 1: Visualization of tensor attention with the $\mathsf{Softmax}$ activation function (Definition \ref{def:tensor_att}). The example uses input token length $n=8$ and feature dimension $d=4$.
  • Figure 2: Visualization of the vectorization operator $\mathop{\mathrm{vec}}\nolimits(\cdot)$, which stacks the rows of a matrix $A\in \mathbb{R}^{n \times d}$ into a column vector $a \in \mathbb{R}^{nd}$. The example uses $n=3$, $d=4$ and shows the entries of $A$ and $a$ explicitly.
  • Figure 3: Visualization of the loss function defined in Definition \ref{def:attention_optimization_loss}. Let $A_1, A_2, A_3, A_4, A_5$ and $E$ be $n \times d$ input matrices, and let $Y$ be a given $d^2 \times d$ matrix. The Kronecker product operator $\otimes$ is defined in Definition \ref{def:tensor_otimes}. The optimization variable is the matrix $X\in \mathbb{R}^{d \times d^2}$. We first compute $\exp(A_1 X (A_2 \otimes A_3)^\top)$. Then, we compute $D(X) := \mathop{\mathrm{diag}}\nolimits(\exp(A_1 X (A_2 \otimes A_3)^\top) {\bf 1}_{n^2})$. Afterwards, we compute $D(X)^{-1} \exp(A_1 X (A_2 \otimes A_3)^\top) (A_4 \otimes A_5)Y - E$. Finally, we minimize the Frobenius norm of this matrix, scaled by a factor of $0.5$, over $X$.
  • Figure 4: The computational graph for tensor attention backward. The blue boxes are input matrices, the gray boxes are intermediate matrices, and the orange box is the final gradient matrix. Here, $A_1,A_2,A_3,A_4,A_5$ denote the previous inputs, $E$ denotes the target matrix, and $X,Y$ denote the attention weights. More detailed definitions of each variable can be found in Sections \ref{sec:gradient}, \ref{app:app_time}, and \ref{app:app_fast_time}.
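
In a row-major array library, the row-stacking $\mathrm{vec}(\cdot)$ of Figure 2 is just a reshape. A useful consequence is $\mathrm{vec}(L X R) = (L \otimes R^\top)\,\mathrm{vec}(X)$, which holds for row-stacking. For column-stacking it becomes $(R^\top \otimes L)$. With $L = A_1$ and $R = (A_2 \otimes A_3)^\top$ this gives $\mathrm{vec}(A_1 X (A_2 \otimes A_3)^\top) = (A_1 \otimes A_2 \otimes A_3)\,\mathrm{vec}(X)$. This is presumably how a vectorized variable $x$ enters functions such as $\mathsf{F}(x)$. The sketch below checks the identity numerically. The helper `vec` and the matrices `L`, `X`, `R` are hypothetical.

```python
import numpy as np

def vec(M):
    """Row-stacking vectorization as in Figure 2: an n x d matrix becomes a length-nd vector."""
    return M.reshape(-1)  # NumPy arrays are row-major (C order) by default

# The n = 3, d = 4 example of Figure 2, with entries 0..11 for readability
A = np.arange(12.0).reshape(3, 4)
a = vec(A)  # rows of A laid end to end

# Row-major identity: vec(L X R) = (L kron R^T) vec(X)
rng = np.random.default_rng(2)
L, X, R = (rng.standard_normal(s) for s in [(5, 3), (3, 4), (4, 2)])
lhs = vec(L @ X @ R)
rhs = np.kron(L, R.T) @ vec(X)
print(a, np.allclose(lhs, rhs))
```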

Theorems & Definitions (86)

  • Definition 2.1: $\otimes$ Kronecker product
  • Definition 2.2: $\oslash$ column-wise Kronecker product, also known as the Khatri-Rao product
  • Definition 2.3: $\ominus$ row-wise Kronecker product, also referred to as the face-splitting product
  • Definition 2.4: Input and weight matrix
  • Definition 2.5: Tensor attention, Definition 7 in sht24, Definition 1.1 in as24_iclr
  • Remark 2.6
  • Definition 2.7: Approximate Tensor Attention Computation ($\mathsf{ATAttC}(n,d,B,\epsilon)$), Definition 1.2 in as24_iclr
  • Definition 2.8: Tensor attention optimization
  • Definition 2.9: Approximate Tensor Attention Loss Gradient Computation ($\mathsf{ATAttLGC}(n,d,B,\epsilon)$)
  • Lemma 3.1: Closed form of gradient, informal version of Lemma \ref{lem:compute_gradient}
  • ...and 76 more
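
The three products in Definitions 2.1 to 2.3 can be sketched in NumPy as follows. The helper names are hypothetical. Index ordering follows `np.kron`, where the second factor's index varies fastest, and this may differ from the paper's convention. The checks verify three standard identities:

- the mixed-product rule $(A \otimes B)(C \oslash D) = (AC) \oslash (BD)$;
- the rule $(A \oslash B)^\top (G_1 \oslash G_2) = (A^\top G_1) \circ (B^\top G_2)$, where $\circ$ is the entrywise product, which reduces a product over $n^2$ rows to two products over $n$ rows;
- the fact that the face-splitting product is the transpose of the Khatri-Rao product of the transposes.

```python
import numpy as np

def kronecker(A, B):
    """Definition 2.1: A kron B, shape (n1*n2, d1*d2)."""
    return np.kron(A, B)

def khatri_rao(A, B):
    """Definition 2.2 (column-wise Kronecker): column l is kron(A[:, l], B[:, l])."""
    assert A.shape[1] == B.shape[1], "need the same number of columns"
    return (A[:, None, :] * B[None, :, :]).reshape(-1, A.shape[1])

def face_splitting(A, B):
    """Definition 2.3 (row-wise Kronecker): row i is kron(A[i], B[i])."""
    assert A.shape[0] == B.shape[0], "need the same number of rows"
    return (A[:, :, None] * B[:, None, :]).reshape(A.shape[0], -1)

rng = np.random.default_rng(3)
n, d, k = 5, 3, 2
A, B = rng.standard_normal((n, d)), rng.standard_normal((n, d))
C, D = rng.standard_normal((d, k)), rng.standard_normal((d, k))
G1, G2 = rng.standard_normal((n, k)), rng.standard_normal((n, k))

# Mixed-product rule: both sides are n^2 x k
mixed_ok = np.allclose(kronecker(A, B) @ khatri_rao(C, D), khatri_rao(A @ C, B @ D))

# Right-hand side is d x k and never touches an n^2-row matrix
hadamard_ok = np.allclose(khatri_rao(A, B).T @ khatri_rao(G1, G2), (A.T @ G1) * (B.T @ G2))

# Face-splitting equals the transpose of the Khatri-Rao product of the transposes
face_ok = np.allclose(face_splitting(A, B), khatri_rao(A.T, B.T).T)
print(mixed_ok, hadamard_ok, face_ok)  # prints: True True True
```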