LOTFormer: Doubly-Stochastic Linear Attention via Low-Rank Optimal Transport
Ashkan Shahbazi, Chayne Thrash, Yikun Bai, Keaton Hamm, Navid NaderiAlizadeh, Soheil Kolouri
TL;DR
LOTFormer tackles the quadratic complexity of standard attention by framing attention as a transport problem and introducing a learnable pivot measure supported on $r$ points, which enforces a low-rank, doubly stochastic coupling. By solving two entropic optimal transport (OT) problems (queries to pivot and pivot to keys) and composing them into a glued coupling, it runs in $O(n d_k r)$ time without forming the full $n \times n$ attention matrix, while remaining end-to-end trainable. Empirically, LOTFormer delivers competitive or superior results on ImageNet-1K across multiple backbones, matches or surpasses state-of-the-art linear and doubly stochastic baselines on Long Range Arena, and can be plugged into pretrained checkpoints for text benchmarks with modest tuning. This approach offers a practical path to robust, scalable attention that improves information flow and interpretability through a structured, pivot-mediated view of transport.
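The two-stage transport can be illustrated with a minimal NumPy sketch. This is not the authors' implementation. Every name (`sinkhorn`, `lot_attention`), the cost (a negative scaled dot product), the regularization `eps`, the fixed Sinkhorn iteration count, and the uniform query and key measures are our own assumptions. The pivot locations `Z` and mass logits stand in for learned parameters. We also assume the glued coupling takes the standard gluing form $\Gamma_1 \operatorname{diag}(\mu)^{-1} \Gamma_2$.

```python
import numpy as np


def sinkhorn(cost, a, b, eps=1.0, n_iters=1000):
    """Entropic OT plan between histograms a (rows) and b (columns).

    Plain Sinkhorn scaling: after the last update the column sums equal b
    exactly and the row sums match a up to convergence error.
    """
    gibbs = np.exp(-cost / eps)            # Gibbs kernel exp(-C / eps)
    u, v = np.ones_like(a), np.ones_like(b)
    for _ in range(n_iters):
        u = a / (gibbs @ v)                # rescale rows toward a
        v = b / (gibbs.T @ u)              # rescale columns to b
    return u[:, None] * gibbs * v[None, :]


def lot_attention(Q, K, V, Z, pivot_logits, eps=1.0):
    """Hypothetical pivot-mediated, doubly stochastic self-attention.

    Q, K: (n, d_k); V: (n, d_v); Z: (r, d_k) pivot locations;
    pivot_logits: (r,) unnormalized pivot masses. Z and pivot_logits are
    stand-ins for learned parameters.
    """
    n, d_k = Q.shape
    assert K.shape[0] == n, "sketch assumes as many keys as queries"
    a = np.full(n, 1.0 / n)                            # uniform query/key mass
    mu = np.exp(pivot_logits - pivot_logits.max())
    mu = mu / mu.sum()                                 # pivot masses on the simplex
    scale = np.sqrt(d_k)
    # Assumed cost: negative scaled dot product.
    G1 = sinkhorn(-(Q @ Z.T) / scale, a, mu, eps)      # (n, r): queries -> pivot
    G2 = sinkhorn(-(Z @ K.T) / scale, mu, a, eps)      # (r, n): pivot -> keys
    # Glued coupling P = G1 diag(1/mu) G2, rescaled by n so rows sum to 1.
    # Multiplying right to left means no n x n matrix is ever formed.
    out = n * (G1 @ ((G2 @ V) / mu[:, None]))
    return out, G1, G2, mu


# Usage on random data.
rng = np.random.default_rng(0)
n, r, d = 64, 4, 16
Q, K, V = (rng.standard_normal((n, d)) for _ in range(3))
Z = rng.standard_normal((r, d))
out, G1, G2, mu = lot_attention(Q, K, V, Z, np.zeros(r))
# For checking only: materialize the n x n matrix that lot_attention avoids.
A = n * G1 @ np.diag(1.0 / mu) @ G2
print(out.shape, np.abs(A.sum(axis=1) - 1).max(), np.abs(A.sum(axis=0) - 1).max())
print("rank:", np.linalg.matrix_rank(A))
```

Applying the factors right to left keeps the value product at $O(n r d_v)$. In this sketch, building the two cost matrices takes $O(n r d_k)$ and each Sinkhorn iteration takes $O(nr)$. The explicit matrix `A` exists only so the row sums, column sums, and rank bound can be inspected.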
Abstract
Transformers have proven highly effective across modalities, but standard softmax attention scales quadratically with sequence length, limiting long-context modeling. Linear attention mitigates this by approximating attention with kernel feature maps, yet most attention mechanisms remain row-normalized and can over-concentrate mass on a few tokens, harming robustness and information flow. Doubly stochastic attention counteracts this by balancing token participation across both rows and columns, but existing approaches often add significant overhead. We propose LOTFormer, a linear-time doubly stochastic attention mechanism derived from an optimal transport view of attention as a coupling between query and key measures. LOTFormer enforces a low-rank transport plan by conditioning on a learnable pivot measure supported on a small number $r$ of points. We solve two entropic transport problems, queries to pivot and pivot to keys, and compose them into a conditional coupling that is provably doubly stochastic, has rank at most $r \ll n$, and applies to values in $O(nr)$ time without forming the full $n \times n$ matrix. The pivot locations and masses are learned end-to-end. Across vision and text benchmarks, LOTFormer delivers strong accuracy-efficiency trade-offs when plugged into standard backbones including Swin, DeiT, and BERT.
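The three properties claimed for the conditional coupling can be checked directly under a few assumptions. We assume the coupling is the standard gluing of the two plans, that $n$ queries and $n$ keys each carry mass $1/n$, and that the pivot masses $\mu$ are strictly positive and sum to 1. The notation $\Gamma^{(1)}$, $\Gamma^{(2)}$, $A$ is ours, not the paper's. The marginal constraints hold exactly for the optimal entropic plans. A finite number of Sinkhorn iterations satisfies them only approximately.

```latex
\begin{aligned}
&\Gamma^{(1)} \in \mathbb{R}_{\ge 0}^{n \times r}, \quad
  \Gamma^{(1)}\mathbf{1}_r = \tfrac{1}{n}\mathbf{1}_n, \quad
  (\Gamma^{(1)})^{\top}\mathbf{1}_n = \mu
  && \text{(queries to pivot)}\\
&\Gamma^{(2)} \in \mathbb{R}_{\ge 0}^{r \times n}, \quad
  \Gamma^{(2)}\mathbf{1}_n = \mu, \quad
  (\Gamma^{(2)})^{\top}\mathbf{1}_r = \tfrac{1}{n}\mathbf{1}_n
  && \text{(pivot to keys)}\\
&A = n\,\Gamma^{(1)}\operatorname{diag}(\mu)^{-1}\Gamma^{(2)} \ge 0
  && \text{(glued coupling, rescaled; all factors nonnegative)}\\
&A\mathbf{1}_n = n\,\Gamma^{(1)}\operatorname{diag}(\mu)^{-1}\Gamma^{(2)}\mathbf{1}_n
  = n\,\Gamma^{(1)}\operatorname{diag}(\mu)^{-1}\mu
  = n\,\Gamma^{(1)}\mathbf{1}_r = \mathbf{1}_n
  && \text{(rows sum to 1)}\\
&A^{\top}\mathbf{1}_n = n\,(\Gamma^{(2)})^{\top}\operatorname{diag}(\mu)^{-1}(\Gamma^{(1)})^{\top}\mathbf{1}_n
  = n\,(\Gamma^{(2)})^{\top}\operatorname{diag}(\mu)^{-1}\mu
  = n\,(\Gamma^{(2)})^{\top}\mathbf{1}_r = \mathbf{1}_n
  && \text{(columns sum to 1)}\\
&\operatorname{rank}(A) \le \operatorname{rank}(\Gamma^{(1)}) \le r
  && \text{(inner dimension is } r\text{)}\\
&AV = n\,\Gamma^{(1)}\bigl(\operatorname{diag}(\mu)^{-1}(\Gamma^{(2)}V)\bigr)
  && O(n r d_v)\ \text{time, no } n \times n \text{ matrix formed}
\end{aligned}
```

Each step uses only the marginal constraints and the identity $\operatorname{diag}(\mu)^{-1}\mu = \mathbf{1}_r$. This is why the pivot masses must be strictly positive for the composition to be well defined.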
