
Token Statistics Transformer: Linear-Time Attention via Variational Rate Reduction

Ziyang Wu, Tianjiao Ding, Yifu Lu, Druv Pai, Jingyuan Zhang, Weida Wang, Yaodong Yu, Yi Ma, Benjamin D. Haeffele

TL;DR

This paper addresses the quadratic complexity of standard transformer attention by introducing the Token Statistics Transformer (ToST), built on a linear-time attention mechanism derived from a novel variational reformulation of the MCR$^2$ objective. Unrolling the optimization of a variational compression term yields Token Statistics Self-Attention (TSSA), which relies on a data-driven second-moment statistic rather than pairwise token similarities, enabling efficient, scalable attention with linear memory. The approach achieves competitive performance on vision and language benchmarks, offers interpretable token clustering via the membership matrix $\mathbf{\Pi}$, and supports a causal variant for autoregressive modeling. This work demonstrates that principled, white-box design can yield efficient, interpretable transformers without a major loss in accuracy, with broad potential for long sequences and large-scale applications.
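
To make the "second-moment statistic" concrete, here is a minimal NumPy sketch of a TSSA-style update. It is not the authors' implementation: the membership rule (a softmax over heads of each token's projected energy), the concave function $f(x) = \log(1 + x)$, the step size, and the names `tssa` and `softmax` are assumptions for illustration. The update is one gradient step on a $\mathbf{\Pi}$-weighted compression term with $\mathbf{\Pi}$ held fixed: each head computes one weighted second moment per projected coordinate, and every coordinate of the projected tokens is then rescaled by a scalar, so the cost is $\mathcal{O}(Kpn)$ rather than quadratic in $n$. Passing `causal=True` swaps the global statistic for prefix statistics, which is one assumed way to obtain a causal variant.

```python
import numpy as np


def softmax(x, axis):
    # Numerically stable softmax along the given axis.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def tssa(Z, U, step=0.1, causal=False):
    """Sketch of a TSSA-style update (assumed form, not the reference code).

    Z: (d, n) tokens stored as columns; U: (K, d, p) per-head projections U_k.
    Returns the updated tokens (d, n) and the memberships Pi (K, n).
    """
    n = Z.shape[1]
    P = np.einsum("kdp,dn->kpn", U, Z)            # projected tokens U_k^T Z, (K, p, n)
    # Assumed membership rule: softmax over heads of each token's projected energy.
    Pi = softmax((P ** 2).sum(axis=1), axis=0)   # (K, n); each column sums to 1
    W = P ** 2 * Pi[:, None, :]                  # Pi-weighted squared coordinates
    if causal:
        # Prefix statistics: token t only sees tokens 0..t.
        num = np.cumsum(W, axis=2)               # (K, p, n)
        den = np.cumsum(Pi, axis=1)[:, None, :]  # (K, 1, n)
    else:
        num = W.sum(axis=2, keepdims=True)       # (K, p, 1)
        den = Pi.sum(axis=1)[:, None, None]      # (K, 1, 1)
    stats = num / den                            # weighted second moment per coordinate
    D = 1.0 / (1.0 + stats)                      # f'(x) for the assumed f(x) = log(1 + x)
    # Each projected coordinate is scaled by a scalar, then weighted by Pi.
    scaled = D * P * Pi[:, None, :]              # (K, p, n)
    grad = np.einsum("kdp,kpn->dn", U, scaled) / n
    return Z - step * grad, Pi
```

Because each token's update depends on the other tokens only through the $K \times p$ statistics, permuting the tokens simply permutes the output, and in causal mode editing a later token leaves all earlier outputs unchanged.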

Abstract

The attention operator is arguably the key distinguishing factor of transformer architectures, which have demonstrated state-of-the-art performance on a variety of tasks. However, transformer attention operators often impose a significant computational burden, with the computational complexity scaling quadratically with the number of tokens. In this work, we propose a novel transformer attention operator whose computational complexity scales linearly with the number of tokens. We derive our network architecture by extending prior work which has shown that a transformer-style architecture arises naturally from "white-box" architecture design, where each layer of the network is designed to implement an incremental optimization step of a maximal coding rate reduction objective (MCR$^2$). Specifically, we derive a novel variational form of the MCR$^2$ objective and show that the architecture that results from unrolled gradient descent of this variational objective leads to a new attention module called Token Statistics Self-Attention (TSSA). TSSA has linear computational and memory complexity and radically departs from the typical attention architecture that computes pairwise similarities between tokens. Experiments on vision, language, and long-sequence tasks show that simply swapping TSSA for standard self-attention, which we refer to as the Token Statistics Transformer (ToST), achieves competitive performance with conventional transformers while being significantly more computationally efficient and interpretable. Our results also somewhat call into question the conventional wisdom that pairwise-similarity-style attention mechanisms are critical to the success of transformer architectures. Code will be available at https://github.com/RobinWu218/ToST.
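
The linear-memory claim can be illustrated with a small single-head sketch (NumPy; the helper names `second_moment` and `streaming_second_moment` are hypothetical). The statistic TSSA needs is a $\mathbf{\Pi}$-weighted sum over tokens, so it can be accumulated chunk by chunk with only $p + 1$ numbers of state, whereas pairwise-similarity attention scores every pair of tokens, i.e. $n^2$ scores per head.

```python
import numpy as np


def second_moment(P, pi):
    """Pi-weighted second moment of each projected coordinate (one head).

    P: (p, n) projected tokens; pi: (n,) nonnegative memberships.
    Returns a length-p vector; no n x n matrix is ever formed.
    """
    return (P ** 2 @ pi) / pi.sum()


def streaming_second_moment(chunks):
    """Accumulate the same statistic over token chunks with O(p) state.

    chunks: iterable of (P_chunk, pi_chunk) pairs, P_chunk of shape (p, m).
    """
    num, den = None, 0.0
    for P_c, pi_c in chunks:
        contrib = P_c ** 2 @ pi_c  # (p,) partial numerator for this chunk
        num = contrib if num is None else num + contrib
        den += pi_c.sum()          # running total membership
    return num / den
```

Since the state does not grow with the number of tokens, the same code handles any sequence length; only the per-chunk work grows, and it grows linearly.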

Paper Structure

This paper contains 47 sections, 3 theorems, 28 equations, 11 figures, 8 tables, 4 algorithms.

Key Result

Theorem 1

Let $f \colon [0, \infty) \to \mathbb{R}$ be non-decreasing, concave, and obey $f(0) = 0$, and let $F \colon \mathrm{PSD}(d) \to \mathbb{R}$ have the form $F(\mathbf{M}) = \sum_{i = 1}^{d}f(\lambda_{i}(\mathbf{M}))$. Then for each $\mathbf{M} \in \mathrm{PSD}(d)$ and $\mathbf{Q} \in \mathrm{O}(d)$,

$$F(\mathbf{M}) \le \sum_{i=1}^{d} f\big((\mathbf{Q}^{\top}\mathbf{M}\mathbf{Q})_{ii}\big).$$

Further, this inequality is achieved with equality for any $\mathbf{Q}$ which diagonalizes $\mathbf{M}$, i.e., for which $\mathbf{Q}^{\top}\mathbf{M}\mathbf{Q}$ is diagonal.
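
Theorem 1 bounds the spectral function $F(\mathbf{M})$ by $\sum_i f((\mathbf{Q}^{\top}\mathbf{M}\mathbf{Q})_{ii})$ for any orthogonal $\mathbf{Q}$, with equality when $\mathbf{Q}$ diagonalizes $\mathbf{M}$. The short NumPy check below is a numerical illustration, not a proof. The choice $f(x) = \log(1 + x)$ is an assumption made because it mirrors the log-det form of the MCR$^2$ coding rate, $\log\det(\mathbf{I} + \mathbf{M}) = \sum_i \log(1 + \lambda_i(\mathbf{M}))$; the helper names `F` and `diag_bound` are hypothetical.

```python
import numpy as np


def F(M, f):
    # Spectral function: sum of f over the eigenvalues of a symmetric PSD matrix.
    return np.sum(f(np.linalg.eigvalsh(M)))


def diag_bound(M, Q, f):
    # Right-hand side of the bound: f applied to the diagonal of Q^T M Q.
    return np.sum(f(np.diag(Q.T @ M @ Q)))


f = np.log1p  # concave, non-decreasing, f(0) = 0 (illustrative choice)

rng = np.random.default_rng(0)
d = 5
A = rng.standard_normal((d, d))
M = A @ A.T                                        # random PSD matrix
Q, _ = np.linalg.qr(rng.standard_normal((d, d)))   # random orthogonal matrix
_, V = np.linalg.eigh(M)                           # eigenvectors diagonalize M

print(F(M, f) <= diag_bound(M, Q, f) + 1e-12)      # True: the bound holds
print(np.isclose(F(M, f), diag_bound(M, V, f)))    # True: equality at the eigenbasis
```

This is what makes the variational form useful: the costly eigenvalue computation can be replaced by a bound that only needs the diagonal of $\mathbf{Q}^{\top}\mathbf{M}\mathbf{Q}$.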

Figures (11)

  • Figure 1: Our ToST architecture, from unrolling a novel variational form of MCR$^2$, is faster and uses less memory than standard transformer architectures such as ViT (note the log-scale y-axes) and is based on a dramatically different notion of attention.
  • Figure 2: Tokenization and representation for image data. Data (images) are split up into tokens (patches) $\mathbf{X}$, which share semantics with other tokens from the same or different samples. Tokens with similar semantics may belong to the same geometric structures in the original space and be grouped together by $\mathbf{\Pi}$. A learned mapping $\phi$ converts these tokens into features which are compressed, linearized, and discriminative.
  • Figure 3: One layer $\ell$ of the proposed Token Statistics Transformer (ToST). Notably, the self-attention of ToST transforms tokens $\mathbf{Z}^{\ell}$ efficiently to $\mathbf{Z}^{\ell+1}$ by multiplying each row of the projected tokens by only a scalar. This leads to reduced complexity of the attention (cf. the complexity comparison table): it has $\mathcal{O}(p)$ space and $\mathcal{O}(pn)$ time complexity, where $p$ is the dimension of the projected tokens of each head, and $n$ is the number of tokens.
  • Figure 4: (Left) The variational compression term $R_{c,f}^{\mathrm{var}} (\mathbf{Z}^{\ell},\mathbf{\Pi}^{\ell})$ of the TSSA outputs $\mathbf{Z}^{\ell}$ and estimated memberships $\mathbf{\Pi}^{\ell}$ at different layers $\ell$ of the ToST-S model. (Right) Visualization of estimated $\mathbf{\Pi}$ for several images. For an image with $N$ tokens (patches), we visualize each row of the membership matrix $\mathbf{\Pi}$ in the TSSA layer by reshaping it into a $\sqrt{N} \times \sqrt{N}$ matrix. Here we show the membership matrices $\mathbf{\Pi}$ estimated in layer 9 of ToST-S for each input image.
  • Figure 5: Comparison of [CLS] token attention map visualizations. For each of ToST-S, XCiT-S, and ViT-S, we visualize the last head of the penultimate global class attention layer.
  • ...and 6 more figures
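
The token bookkeeping behind Figures 2 and 4 (images split into patch tokens, then a row of $\mathbf{\Pi}$ reshaped back onto the $\sqrt{N} \times \sqrt{N}$ patch grid) can be sketched in NumPy as follows. Assumptions not taken from the paper: row-major patch order, a square patch grid, a membership matrix stored as $K \times N$ so that each row is one map (following the Figure 4 caption), and the hypothetical helper names `patchify` and `membership_map`.

```python
import numpy as np


def patchify(img, patch):
    """Split an (H, W, C) image into N = (H // patch) * (W // patch) patches.

    Patches are ordered row-major over the patch grid (assumed convention).
    Returns tokens as columns, shape (patch * patch * C, N).
    """
    H, W, C = img.shape
    gh, gw = H // patch, W // patch
    x = img[: gh * patch, : gw * patch].reshape(gh, patch, gw, patch, C)
    x = x.transpose(0, 2, 1, 3, 4).reshape(gh * gw, patch * patch * C)
    return x.T


def membership_map(Pi, k):
    """Reshape row k of a (K, N) membership matrix onto the sqrt(N) x sqrt(N) grid."""
    N = Pi.shape[1]
    side = int(round(np.sqrt(N)))
    assert side * side == N, "expects a square patch grid"
    return Pi[k].reshape(side, side)
```

Because both helpers use the same row-major order, entry $(i, j)$ of a membership map corresponds to the patch in grid row $i$ and column $j$, which is what lets the maps in Figure 4 be overlaid on the input images.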

Theorems & Definitions (4)

  • Theorem 1
  • Corollary 1
  • Theorem 2
  • proof