
Scaling Stick-Breaking Attention: An Efficient Implementation and In-depth Study

Shawn Tan, Songlin Yang, Aaron Courville, Rameswar Panda, Yikang Shen

TL;DR

The paper introduces stick-breaking attention as a stable, no-position-embedding alternative to softmax-based self-attention for long-context modelling. It formalises the attention weights via a stick-breaking construction, implements a numerically stable, log-space forward pass, and adapts a Triton/Flash Attention-like kernel to scale to large sequences. Empirical results show competitive length generalisation, improved retrieval in long-context benchmarks, and favourable perplexity and few-shot performance for 1B–3B models. The work highlights practical gains for long-sequence language modelling and outlines pathways for further efficiency and architectural enhancements.
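To illustrate why a log-space forward pass helps, here is a minimal NumPy sketch for a single query. It is not the paper's kernel. It assumes the weight of the i-th preceding token (oldest first) is A_i = σ(z_i) ∏_{k>i} (1 − σ(z_k)), with σ the logistic sigmoid and z a given vector of logits; the function names are hypothetical. It relies on the identities log σ(z) = z − softplus(z) and log(1 − σ(z)) = −softplus(z).

```python
import numpy as np


def softplus(x: np.ndarray) -> np.ndarray:
    """Numerically stable log(1 + exp(x))."""
    return np.logaddexp(0.0, x)


def stick_breaking_log_weights(logits: np.ndarray) -> np.ndarray:
    """Log stick-breaking weights for one query over its preceding tokens.

    logits[i] is z_i for the i-th preceding token, ordered oldest (index 0)
    to newest (index -1). Hypothetical helper for illustration:
        log A_i = log sigmoid(z_i) + sum_{k > i} log(1 - sigmoid(z_k))
                = (z_i - softplus(z_i)) - sum_{k > i} softplus(z_k)
    """
    sp = softplus(logits)
    # Suffix sums over strictly more recent tokens: sum_{k > i} sp[k].
    later = np.cumsum(sp[::-1])[::-1] - sp
    return logits - sp - later


logits = np.array([2.0, -1.0, 0.5, 3.0])
weights = np.exp(stick_breaking_log_weights(logits))
```

Turning the product of (1 − σ) factors into a sum of softplus terms keeps very small weights finite in log space: for logits `[-1000.0, 1000.0]` the older token gets a log weight of about −2000 rather than the `log(0)` a direct product would produce.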

Abstract

The self-attention mechanism traditionally relies on the softmax operator, necessitating positional embeddings like RoPE or position biases to account for token order. But current methods using these still face length-generalisation challenges. We investigate an alternative attention mechanism based on the stick-breaking process in larger-scale settings. The method works as follows: for each token before the current one, we determine a break point, which represents the proportion of the stick (the attention weight) to allocate to that token. We repeat this on the remaining stick until all tokens are allocated a weight, resulting in a sequence of attention weights. This process naturally incorporates recency bias, which has linguistic motivations for grammar parsing. We study the implications of replacing the conventional softmax-based attention mechanism with stick-breaking attention. We then discuss the implementation of numerically stable stick-breaking attention and adapt Flash Attention to accommodate this mechanism. When used as a drop-in replacement for current softmax+RoPE attention systems, we find that stick-breaking attention performs competitively with current methods on length generalisation and downstream tasks. Stick-breaking also performs well at length generalisation, allowing a model trained with a $2^{11}$ context window to perform well at $2^{14}$ with perplexity improvements.
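The process described in the abstract can be sketched directly for one query. This is an illustrative sketch, not the authors' implementation: it assumes the logits for the preceding tokens are already given (oldest first), uses the logistic sigmoid as the break proportion (as in Figure 1), and the function names are hypothetical. In this sketch any unallocated stick is simply left over, so the weights can sum to less than 1. Comparing against softmax on equal logits shows the recency bias.

```python
import numpy as np


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + np.exp(-x))


def stick_breaking_weights(logits: np.ndarray) -> np.ndarray:
    """Attention weights for one query via the stick-breaking process.

    logits[i] scores the i-th preceding token (oldest first). Starting from
    the most recent token, a fraction sigmoid(z_i) of the remaining stick is
    broken off and assigned to token i; the rest passes to older tokens.
    """
    weights = np.zeros_like(logits, dtype=float)
    remaining = 1.0  # length of the stick not yet allocated
    for i in range(len(logits) - 1, -1, -1):  # most recent token first
        beta = sigmoid(logits[i])
        weights[i] = beta * remaining
        remaining *= 1.0 - beta
    return weights


def softmax(logits: np.ndarray) -> np.ndarray:
    e = np.exp(logits - logits.max())
    return e / e.sum()


equal = np.array([1.0, 1.0, 1.0])
print(softmax(equal))                 # equal weights of 1/3 each
print(stick_breaking_weights(equal))  # largest on the most recent token, decaying toward older ones
```

Because each break happens before older tokens are considered, the most recent token with a high logit claims most of the stick, which is the behaviour contrasted with softmax in Figure 1.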

Paper Structure

This paper contains 26 sections, 11 equations, 8 figures, 4 tables, 2 algorithms.

Figures (8)

  • Figure 1: Differences in formulation between stick-breaking and softmax. Stick-breaking assigns high weight to the most recent high logit, while softmax assigns equal weight to equal logits. $\sigma(\cdot)$ can be any function $\mathbb{R} \rightarrow (0,1)$; in this paper we use $\sigma(x) = \frac{1}{1+\exp(-x)}$.
  • Figure 2: Thread tile assignments for a given attention head and sequence. Tiles of the same colour are processed by the same thread. For stick-breaking, the forward pass has to be computed from right to left, while the backward pass is computed in the reverse order. Uncoloured tiles are not computed: upper-right tiles are not used in causal language modelling, and, with block skipping, some blocks can be skipped if all entries have already summed to 1.
  • Figure 3: MQRAR performance on increasing key-value pairs.
  • Figure 4: Attention visualisation of the models trained on the MQRAR task. The figure shows the attention for each token for the 2-layer, 100-dimension Transformer. Note that in the standard Softmax+RoPE setting (above), the attention head is "distracted" at the third retrieval of 'E', attending to the first instance of 'E' rather than the more recent one. In the stick-breaking setting (below), each attention head attends to the prior assignment of the variable.
  • Figure 5: Comparisons against different methods of sequence length extension. $L$ represents the training context length, $f$ is the RoPE scaling factor used, and $W$ is the window size in sliding window attention. We compare against different position embeddings and biases with $L=8192$, training with $L=8192$, and various sliding window sizes with $L=2048$. Note that the $y$-axis scales differ across the three plots.
  • ...and 3 more figures
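Figure 2 describes a forward pass that visits tiles from right to left and skips blocks once the weights have summed to 1. The NumPy sketch below mimics that ordering for a single query on the CPU; it is not the Triton kernel. It assumes scaled dot-product logits z = q·k/√d with the logistic sigmoid, and the function name, `block_size` and the `eps` skipping threshold are hypothetical choices for illustration. Carrying the log of the remaining stick from block to block is what makes the right-to-left order necessary.

```python
import numpy as np


def stick_breaking_attend_blocked(q, K, V, block_size=4, eps=1e-6):
    """Stick-breaking attention output for one query over its prefix keys.

    q: (d,), K: (n, d), V: (n, d_v); rows of K and V are preceding tokens,
    oldest first. Blocks are visited right to left so that the log of the
    remaining stick, -sum softplus(z_k) over already-visited (more recent)
    tokens, can be carried forward. Once the remaining stick drops below
    eps, older blocks are skipped (hypothetical threshold).
    """
    n = K.shape[0]
    scale = 1.0 / np.sqrt(q.shape[0])
    log_remaining = 0.0          # log of the unallocated stick length
    out = np.zeros(V.shape[1])
    blocks_visited = 0
    for end in range(n, 0, -block_size):      # right-to-left over blocks
        start = max(0, end - block_size)
        z = K[start:end] @ q * scale          # logits for this block
        sp = np.logaddexp(0.0, z)             # softplus(z) = -log(1 - sigmoid(z))
        # Within the block, more recent tokens break the stick first.
        later = np.cumsum(sp[::-1])[::-1] - sp
        log_w = (z - sp) - later + log_remaining
        out += np.exp(log_w) @ V[start:end]
        log_remaining -= sp.sum()
        blocks_visited += 1
        if log_remaining < np.log(eps):       # stick (almost) fully allocated
            break
    return out, blocks_visited
```

When a block is skipped, every older token's weight is bounded by the remaining stick, so the output changes by at most about `eps` times the largest value-vector magnitude; with uniformly large logits, only the most recent block is visited.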