On the Power of Convolution Augmented Transformer

Mingchen Li, Xuechen Zhang, Yixiao Huang, Samet Oymak

TL;DR

This work investigates Convolution-Augmented Attention (CAT), a hybrid transformer that injects convolutional filters into the K/Q/V embeddings to blend local and global information. The authors prove that a single CAT layer can solve N-gram associative recall (NAR) and selective copying (SC), and that such CAT solutions exhibit length generalization to arbitrary context lengths. They further show that long convolutions enable efficient sparse-attention regimes like Landmark CAT (LCAT), by summarizing the context into landmarks and attending to them, with concrete computational tradeoffs analyzed under random-context models. Empirically, CAT improves language modeling performance and length generalization on real data, while maintaining strong performance on mechanistic tasks in synthetic setups. Overall, CAT provides design principles for robust, hybrid architectures that leverage both local filtering and global attention.

Abstract

The transformer architecture has catalyzed revolutionary advances in language modeling. However, recent architectural recipes, such as state-space models, have bridged the performance gap. Motivated by this, we examine the benefits of the Convolution-Augmented Transformer (CAT) for recall, copying, and length generalization tasks. CAT incorporates convolutional filters in the K/Q/V embeddings of an attention layer. Through CAT, we show that the locality of the convolution synergizes with the global view of the attention. Unlike comparable architectures, such as Mamba or the transformer, CAT can provably solve the associative recall (AR) and copying tasks using a single layer while also enjoying guaranteed length generalization. We also establish computational tradeoffs between convolution and attention by characterizing how convolution can mitigate the need for full attention by summarizing the context window and creating salient summary tokens to attend to. Evaluations on real datasets corroborate our findings and demonstrate that CAT and its variations indeed enhance language modeling performance.
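
To make the idea of "convolutional filters in the K/Q/V embeddings" concrete, the following is a minimal NumPy sketch of a single-head CAT-style layer. It is an illustration under stated assumptions, not the authors' implementation: the names `causal_conv1d` and `cat_layer` are hypothetical, each filter is shared across all embedding channels, the filters are applied to the input before the linear projections, and the scores use the usual $1/\sqrt{d}$ scaling with a causal mask.

```python
import numpy as np

def causal_conv1d(X, F):
    """Causal 1-D convolution along the sequence axis (assumed form).

    X: (L, d) token embeddings. F: (N,) filter taps shared by all channels.
    Output row i equals sum_j F[j] * X[i - j], with zero padding for i - j < 0.
    """
    L = X.shape[0]
    out = np.zeros_like(X, dtype=float)
    for j, tap in enumerate(F):
        if j >= L:
            break
        out[j:] += tap * X[:L - j]
    return out

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def cat_layer(X, Fq, Fk, Fv, Wq, Wk, Wv):
    """Single-head attention with separate causal filters on Q, K and V."""
    Q = causal_conv1d(X, Fq) @ Wq
    K = causal_conv1d(X, Fk) @ Wk
    V = causal_conv1d(X, Fv) @ Wv
    L = X.shape[0]
    scores = Q @ K.T / np.sqrt(K.shape[1])
    causal = np.tril(np.ones((L, L), dtype=bool))  # position i sees positions <= i
    scores = np.where(causal, scores, -np.inf)
    return softmax(scores) @ V

# Usage: identity filters on Q and V, a one-step delay on K.
rng = np.random.default_rng(0)
X = rng.normal(size=(8, 4))
I = np.eye(4)
out = cat_layer(X, np.array([1.0]), np.array([0.0, 1.0]), np.array([1.0]), I, I, I)
```

With all three filters set to `[1.0]` the sketch reduces to ordinary causal softmax attention, so the filters are the only added ingredient.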

Paper Structure (26 sections, 14 theorems, 96 equations, 11 figures, 2 tables)

Key Result

Theorem 1

Let $\bm{F}\in\mathbb{R}^N$ be a causal 1-D convolutional filter of length $N$, and let $\texttt{norm}({\bm{X}})$ normalize the rows of a matrix to unit $\ell_2$ norm. Consider a single CAT layer $f({\bm{X}})=({\bm{X}}_v\bm{W}_v)^{\top}\mathbb{S}({\bm{X}}_k\bm{W}_k\bm{W}_q^{\top}{\bm{q}})$ where ${\bm{q}}$ … Let $\varepsilon>0$ be the minimum $\ell_2$ distance between two distinct token embeddings. For al…
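
The layer formula above can be exercised on a toy associative recall instance. The sketch below is one hedged instantiation, not the construction from the (truncated) theorem statement: it assumes the key filter is a one-step delay, the value filter is the identity, ${\bm{q}}$ is the embedding of the final token, $\mathbb{S}$ is the softmax, and $\bm{W}_k\bm{W}_q^{\top}=\beta\bm{I}$ with $\bm{W}_v=\bm{I}$. The helper names `norm_rows`, `delay`, and `cat_recall` and the one-hot embeddings are illustrative choices.

```python
import numpy as np

def norm_rows(X, eps=1e-12):
    """Scale each row to unit l2 norm; all-zero rows stay zero."""
    n = np.linalg.norm(X, axis=1, keepdims=True)
    return X / np.maximum(n, eps)

def delay(X, k=1):
    """Causal convolution whose only nonzero filter tap sits at lag k."""
    out = np.zeros_like(X)
    out[k:] = X[:-k]
    return out

def cat_recall(X, beta=50.0):
    """Evaluate f(X) = X_v^T softmax(beta * X_k q) for the last-token query."""
    q = X[-1]
    X_k = norm_rows(delay(X, 1))  # key at position i encodes token i-1
    X_v = X                       # value at position i is token i itself
    scores = beta * (X_k @ q)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return X_v.T @ w

# Context "0 3 1 4 2 5" read as pairs (0->3, 1->4, 2->5), followed by query 1.
E = np.eye(6)                     # one-hot token embeddings (assumption)
tokens = [0, 3, 1, 4, 2, 5, 1]
out = cat_recall(E[tokens])
predicted = int(np.argmax(E @ out))  # decodes to token 4
```

Because the delayed key at a position matches the query exactly when the previous token equals the query, a large `beta` concentrates the attention on the token that followed the earlier occurrence of the query. Nothing in the computation depends on the sequence length, which is the intuition behind the length generalization claim.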

Figures (11)

  • Figure 1: Evaluations on synthetic and real data. The models are trained with context lengths of 128 (synthetic) and 2,048 (real data), marked by vertical dashed lines, and tested on varying context lengths. Left figure: We conduct synthetic experiments on the Associative Recall task and contrast 1-layer CAT with 2-layer alternative architectures. The embedding dimension is 128. We find that CAT is the only model that solves AR with length generalization, in line with our theory (also see Fig. \ref{fig:lengen}). Right figure: Evaluations on language modeling, where we train CAT models by equipping Pythia with short convolutions (window size 21). Convolution allows the model to pretrain without positional encoding and further improves perplexity when combined with RoPE. Importantly, it also generalizes to longer context lengths more robustly, with or without RoPE. For length generalization, we used YaRN [peng2023yarn], which incorporates position interpolation [chen2023extending] (for RoPE only) and temperature scaling (see Sec. \ref{sec:nlp_exp}).
  • Figure 2: Left figure: Illustration of the Convolution-Augmented Attention (CAT) block, where separate filters are applied to the K/Q/V embeddings before self-attention (see Sec. \ref{sec:methodology} for details). Right figure: Performance of 1-layer CAT models trained on multi-query AR (MQAR, see Sec. \ref{sec:MQAR} for details) tasks with model embedding dimension 64 and varying sequence length. LinCAT replaces the standard attention in CAT with linear attention. We observe that the CAT model outperforms the baseline models across all sequence lengths with only 1 layer, whereas the baselines use 2 layers.
  • Figure 3: Illustration of the Landmark CAT. We first apply long convolution on the input sequence and subsample it to obtain landmark tokens representing individual blocks. Hard Attention computes the similarity between the query and landmarks to retrieve the most relevant block. Local Attention concatenates the retrieved block with the final block containing the query and computes the output token.
  • Figure 4: Behavior of the embedding dimension as a function of block size for context length $L=2^{20}\approx 1$ million (noise level $\sigma^2=1$). The shaded region highlights the range of $d$ that exhibits 10%-50% empirical success. Proposition \ref{prop long conv1} accurately captures the empirical behavior. For the success of uniform AR, we need larger $d$ as the dimension of the query space $S$ grows.
  • Figure 5: Evaluation of models on MQAR and MQNAR tasks with varying model dimensions and sequence lengths. Model dimensions are 32, 64, 128 for each column of the figures, from left to right. Top: Models trained on the MQAR setup. Bottom: Models trained on the MQNAR setup. Note that CAT models employ a single-layer architecture, whereas all other models utilize two layers. Refer to Section \ref{sec:syn_exp} for detailed setup descriptions.
  • ...and 6 more figures
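
The Landmark CAT pipeline described in the Figure 3 caption (long convolution, subsampling into landmarks, hard attention over landmarks, local attention over the retrieved block plus the final block) can be sketched as follows. This is a simplified illustration under assumptions that are not taken from the paper: the long filter is an all-ones filter of length equal to the block size, keys are one-step delayed tokens, embeddings are one-hot, and the function name `landmark_cat_recall` is hypothetical.

```python
import numpy as np

def delay(X, k=1):
    """Causal convolution whose only nonzero filter tap sits at lag k."""
    out = np.zeros_like(X)
    out[k:] = X[:-k]
    return out

def landmark_cat_recall(X, block, beta=50.0):
    """Retrieve one block via landmarks, then attend locally inside it."""
    L, d = X.shape
    assert L % block == 0, "sketch assumes the length is a multiple of the block size"
    n_blocks = L // block
    q = X[-1]            # query: embedding of the final token
    K = delay(X, 1)      # short convolution: key at i encodes token i-1
    V = X

    # A causal all-ones filter of length `block`, subsampled at the last
    # position of every block, gives the sum of the keys in that block.
    landmarks = K.reshape(n_blocks, block, d).sum(axis=1)

    # Hard attention: pick the earlier block whose landmark best matches q.
    last = np.arange(L - block, L)
    if n_blocks > 1:
        best = int(np.argmax(landmarks[:-1] @ q))
        idx = np.concatenate([np.arange(best * block, (best + 1) * block), last])
    else:
        best, idx = 0, last

    # Local attention over the retrieved block plus the final block.
    scores = beta * (K[idx] @ q)
    w = np.exp(scores - scores.max())
    w /= w.sum()
    return V[idx].T @ w, best

# 16 tokens, blocks of 4; the query token 6 first appears at position 5.
E = np.eye(16)
tokens = list(range(1, 16)) + [6]
out, best = landmark_cat_recall(E[tokens], block=4)
predicted = int(np.argmax(E @ out))  # token 7, which followed the earlier 6
```

In this sketch the query is compared against `L // block` landmarks and then against `2 * block` local keys, rather than against all `L` keys, which is the kind of tradeoff between convolution and attention that the summary refers to.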

Theorems & Definitions (30)

  • Definition 1: Convolution-Augmented Attention (CAT)
  • Definition 2: Associative Recall Problem
  • Definition 3: N-gram AR Problem
  • Definition 4: Selective Copying
  • Definition 5: Multi-Query Associative Recall (MQAR)
  • Theorem 1: Solving NAR
  • Corollary 1: 1-D CAT solves AR
  • Theorem 2: Length generalization
  • Theorem 3: Selective Copy
  • Definition 6: Random Context Model
  • ...and 20 more
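
Since only the names of the task definitions appear here, the following generator is a hedged sketch of how an N-gram associative recall instance is commonly set up, not a transcription of Definitions 2 and 3. It assumes the context is a list of (key, value) pairs in which each key is an n-gram, the sequence ends with one of the keys as the query, and the target is the value that followed that key. Drawing keys and values from disjoint halves of the vocabulary is an extra assumption that keeps the answer unambiguous. Both function names are hypothetical.

```python
import random

def make_ngram_ar_example(n_pairs, n=1, vocab_size=64, seed=0):
    """Return (tokens, target) for one N-gram associative recall instance."""
    rng = random.Random(seed)
    half = vocab_size // 2
    if half ** n < n_pairs:
        raise ValueError("not enough distinct n-gram keys for this vocabulary")
    keys = set()
    while len(keys) < n_pairs:
        keys.add(tuple(rng.randrange(half) for _ in range(n)))
    keys = sorted(keys)
    rng.shuffle(keys)
    values = [rng.randrange(half, vocab_size) for _ in keys]

    tokens = []
    for key, value in zip(keys, values):
        tokens.extend(key)     # n key tokens from the lower half of the vocabulary
        tokens.append(value)   # one value token from the upper half
    i = rng.randrange(n_pairs)
    tokens.extend(keys[i])     # the query repeats one of the keys
    return tokens, values[i]

def recall_by_lookup(tokens, n=1):
    """Reference answer: the token after the earlier occurrence of the final n-gram."""
    query = tokens[-n:]
    for i in range(len(tokens) - n):
        if tokens[i:i + n] == query:
            return tokens[i + n]
    return None

tokens, target = make_ngram_ar_example(n_pairs=5, n=2, seed=1)
answer = recall_by_lookup(tokens, n=2)
```

Setting `n=1` gives the plain associative recall setup, and `recall_by_lookup` serves as a ground-truth oracle against which a trained model's prediction for the final position can be compared.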