Scaling Stick-Breaking Attention: An Efficient Implementation and In-depth Study
Shawn Tan, Songlin Yang, Aaron Courville, Rameswar Panda, Yikang Shen
TL;DR
The paper introduces stick-breaking attention as a stable alternative to softmax-based self-attention for long-context modelling that requires no position embeddings. It formalises the attention weights via a stick-breaking construction, implements a numerically stable log-space forward pass, and adapts a Flash Attention-style Triton kernel to scale to long sequences. Empirical results show competitive length generalisation, improved retrieval on long-context benchmarks, and favourable perplexity and few-shot performance for 1B–3B models. The work highlights practical gains for long-sequence language modelling and outlines directions for further efficiency and architectural improvements.
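For a query at position $i$ and an earlier token at position $j < i$, the stick-breaking construction and its log-space form can be sketched as below. The notation is assumed here rather than taken from this section: $z_{i,j}$ is a break-point logit (for instance a scaled query–key dot product, as in standard attention), $\sigma$ is the logistic sigmoid, and $v_j$ are value vectors.

$$\beta_{i,j} = \sigma(z_{i,j}), \qquad A_{i,j} = \beta_{i,j} \prod_{k=j+1}^{i-1} \left(1 - \beta_{i,k}\right), \qquad o_i = \sum_{j=1}^{i-1} A_{i,j}\, v_j$$

Using the identities $\log\sigma(z) = -\operatorname{softplus}(-z)$ and $\log(1 - \sigma(z)) = -\operatorname{softplus}(z)$, where $\operatorname{softplus}(z) = \log(1 + e^{z})$:

$$\log A_{i,j} = -\operatorname{softplus}(-z_{i,j}) - \sum_{k=j+1}^{i-1} \operatorname{softplus}(z_{i,k})$$

Summing softplus terms avoids the underflow, and the rounding of $1 - \beta_{i,k}$ to zero, that a long product of probabilities would suffer in floating point; this is one plausible reading of why a log-space forward pass helps, not a description of the authors' kernel. Under this formulation the weights need not sum to one: the stick left over after the earliest token, $\prod_{k=1}^{i-1}(1 - \beta_{i,k})$, is simply not allocated.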
Abstract
The self-attention mechanism traditionally relies on the softmax operator, necessitating positional embeddings such as RoPE, or position biases, to account for token order. However, current methods that use them still face length generalisation challenges. We investigate an alternative attention mechanism based on the stick-breaking process in larger-scale settings. The method works as follows: for each token before the current one, working backwards from the most recent, we determine a break point, which represents the proportion of the stick (the attention weight) to allocate to that token. We repeat this on the remaining stick until all tokens are allocated a weight, resulting in a sequence of attention weights. Because nearer tokens claim their share of the stick first, this process naturally incorporates a recency bias, which has linguistic motivations for grammar parsing. We study the implications of replacing the conventional softmax-based attention mechanism with stick-breaking attention. We then discuss a numerically stable implementation of stick-breaking attention and adapt Flash Attention to accommodate this mechanism. When used as a drop-in replacement for current softmax+RoPE attention systems, stick-breaking attention performs competitively with current methods on length generalisation and downstream tasks. In particular, a model trained with a $2^{11}$ context window performs well at $2^{14}$, with perplexity improvements.
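The allocation loop described in the abstract can be illustrated for a single query with the minimal pure-Python sketch below. It is not the authors' Triton kernel: the function names are hypothetical, the break-point logits are taken as given inputs (how they are computed is left out), and the break-point fraction is assumed to be the logistic sigmoid of each logit. A second function computes the same weights in log space using softplus identities.

```python
import math


def sigmoid(z: float) -> float:
    # Logistic function, written to avoid overflow in exp for large |z|.
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def softplus(z: float) -> float:
    # log(1 + exp(z)) without overflow for large |z|.
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def stick_breaking_naive(logits):
    """logits[j] is the (hypothetical) break-point logit for earlier token j,
    ordered oldest first. Walk backwards from the most recent token, breaking
    off a fraction sigmoid(z) of whatever stick remains."""
    weights = [0.0] * len(logits)
    remaining = 1.0
    for j in reversed(range(len(logits))):
        beta = sigmoid(logits[j])
        weights[j] = beta * remaining
        remaining *= 1.0 - beta
    return weights


def stick_breaking_log(logits):
    """Same weights via log A_j = -softplus(-z_j) - sum_{k>j} softplus(z_k)."""
    log_weights = [0.0] * len(logits)
    acc = 0.0  # running sum of softplus(z_k) over tokens more recent than j
    for j in reversed(range(len(logits))):
        log_weights[j] = -softplus(-logits[j]) - acc
        acc += softplus(logits[j])
    return [math.exp(lw) for lw in log_weights]


if __name__ == "__main__":
    z = [0.0, 0.0, 0.0]  # equal logits for three earlier tokens
    print(stick_breaking_naive(z))  # → [0.125, 0.25, 0.5]
    print(stick_breaking_log(z))
```

With equal logits, each earlier token receives half of what the next more recent token received, which is the recency bias the abstract mentions. With logits `[0.0, 40.0]`, the naive loop returns exactly `0.0` for the older token because `1 - sigmoid(40)` rounds to zero, while the log-space version returns a tiny positive weight (about 2e-18).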
