
LAWCAT: Efficient Distillation from Quadratic to Linear Attention with Convolution across Tokens for Long Context Modeling

Zeyu Liu, Souvik Kundu, Lianghao Jiang, Anni Li, Srikanth Ronanki, Sravan Bodapati, Gourav Datta, Peter A. Beerel

TL;DR

This work tackles the quadratic time/space cost of transformer self-attention for long contexts by distilling pretrained transformers into a linear-attention framework augmented with a causal Conv1D layer for local context and a normalized gated linear attention mechanism. The proposed LAWCAT framework enables efficient long-context modeling and edge deployment, achieving over 90% passkey retrieval accuracy up to $22{,}000$ tokens with minimal pretraining data, and competitive performance on challenging S-NIAH and BABILong benchmarks. Empirical results show LAWCAT outperforms LoLCATs and several recurrent baselines on long-context tasks, while offering lower prefill latency than FlashAttention-2 for long sequences, revealing a practical path to high-performance, resource-efficient long-context linear models. Limitations include task-dependent effects of rotary position embeddings (RoPE) and sliding-window attention (SWA), suggesting future work on adaptive hybrid attention and more advanced linear attention variants to further close the gap with full softmax attention across diverse benchmarks.
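The linear cost mentioned above comes from the standard linear attention idea: replace the softmax kernel exp(q·k) with a feature-map product φ(q)·φ(k), so causal attention can be computed from a running state whose size depends on the head dimensions rather than on the sequence length. The minimal pure-Python sketch below illustrates generic linear attention, not LAWCAT's exact layer; it assumes the elu(x)+1 feature map, and all function names are hypothetical. It checks that the quadratic "parallel" form and the linear "recurrent" form produce the same outputs.

```python
import math


def phi(vec):
    # Assumed feature map: elu(x) + 1, which keeps every feature strictly positive.
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def parallel_linear_attention(q, k, v):
    """Quadratic form: each step t re-weights every earlier step s <= t (O(T^2))."""
    out = []
    for t in range(len(q)):
        fq = phi(q[t])
        weights = [dot(fq, phi(k[s])) for s in range(t + 1)]
        norm = sum(weights)
        out.append([sum(w * v[s][j] for s, w in enumerate(weights)) / norm
                    for j in range(len(v[0]))])
    return out


def recurrent_linear_attention(q, k, v):
    """Linear form: carry a d_k x d_v state S and a normalizer z (O(T) steps)."""
    d_k, d_v = len(k[0]), len(v[0])
    S = [[0.0] * d_v for _ in range(d_k)]  # running sum of phi(k_s) v_s^T
    z = [0.0] * d_k                        # running sum of phi(k_s)
    out = []
    for t in range(len(q)):
        fq, fk = phi(q[t]), phi(k[t])
        for i in range(d_k):
            z[i] += fk[i]
            for j in range(d_v):
                S[i][j] += fk[i] * v[t][j]
        denom = dot(fq, z)
        out.append([sum(fq[i] * S[i][j] for i in range(d_k)) / denom
                    for j in range(d_v)])
    return out


# Usage: both forms agree on a small deterministic example.
T, d = 16, 3
q = [[math.sin(3 * t + i) for i in range(d)] for t in range(T)]
k = [[math.cos(5 * t - i) for i in range(d)] for t in range(T)]
v = [[math.sin(t * i + 1) for i in range(d)] for t in range(T)]
a = parallel_linear_attention(q, k, v)
b = recurrent_linear_attention(q, k, v)
print(max(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb)))
```

The recurrent form is what keeps the per-token cost constant during decoding: S and z have fixed sizes (d_k × d_v and d_k) no matter how many tokens have already been processed.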

Abstract

Although transformer architectures have achieved state-of-the-art performance across diverse domains, their quadratic computational complexity with respect to sequence length remains a significant bottleneck, particularly for latency-sensitive long-context applications. While recent linear-complexity alternatives are increasingly powerful, training them effectively from scratch is still resource-intensive. To overcome these limitations, we propose LAWCAT (Linear Attention with Convolution Across Time), a novel linearization framework designed to efficiently transfer the capabilities of pre-trained transformers into a performant linear attention architecture. LAWCAT integrates causal Conv1D layers to enhance local dependency modeling and employs normalized gated linear attention to improve generalization across varying context lengths. Our comprehensive evaluations demonstrate that distilling Mistral-7B with only 1K-length sequences yields over 90% passkey retrieval accuracy up to 22K tokens, significantly extending its effective context window. Similarly, the Llama3.2-1B LAWCAT variant achieves competitive performance on the S-NIAH 1&2&3 tasks (1K-8K context length) and the BABILong benchmark (QA2&QA3, 0K-16K context length), while requiring less than 0.1% of the pre-training tokens used by pre-trained models. Furthermore, LAWCAT achieves faster prefill than FlashAttention-2 for sequences longer than 8K tokens. LAWCAT thus provides an efficient pathway to high-performance, long-context linear models suitable for edge deployment, reducing reliance on extensive long-sequence training data and computational resources. Code is released at: https://github.com/zeyuliu1037/LAWCAT
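The abstract names normalized gated linear attention but does not give its equations, so the sketch below assumes one common form: a scalar forget gate g_t = sigmoid(logit) that decays both the key-value state and the normalizer before each update. The formulation, the elu(x)+1 feature map, and the function names are assumptions for illustration, not LAWCAT's published layer. Under these assumptions each output is a positively weighted average of past values, so its scale cannot drift as the context grows, which is one intuition for why normalization can help length generalization.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def phi(vec):
    # Assumed positive feature map (elu(x) + 1); LAWCAT's actual map may differ.
    return [x + 1.0 if x > 0 else math.exp(x) for x in vec]


def gated_normalized_linear_attention(q, k, v, gate_logits):
    """Hypothetical gated linear attention with a gated running normalizer.

        g_t = sigmoid(gate_logits[t])            (scalar forget gate in (0, 1))
        S_t = g_t * S_{t-1} + phi(k_t) v_t^T     (decayed key-value memory)
        z_t = g_t * z_{t-1} + phi(k_t)           (decayed normalizer)
        o_t = phi(q_t)^T S_t / (phi(q_t)^T z_t)
    """
    d_k, d_v = len(k[0]), len(v[0])
    S = [[0.0] * d_v for _ in range(d_k)]
    z = [0.0] * d_k
    out = []
    for t in range(len(q)):
        g = sigmoid(gate_logits[t])
        fq, fk = phi(q[t]), phi(k[t])
        for i in range(d_k):
            z[i] = g * z[i] + fk[i]
            for j in range(d_v):
                S[i][j] = g * S[i][j] + fk[i] * v[t][j]
        denom = sum(fq[i] * z[i] for i in range(d_k))
        out.append([sum(fq[i] * S[i][j] for i in range(d_k)) / denom
                    for j in range(d_v)])
    return out


# Usage: a long sequence with a constant gate, g = sigmoid(2.0), roughly 0.88.
T, d = 500, 4
q = [[math.sin(0.7 * t + i) for i in range(d)] for t in range(T)]
k = [[math.cos(1.3 * t - i) for i in range(d)] for t in range(T)]
v = [[math.sin(0.1 * t * (i + 1)) for i in range(d)] for t in range(T)]
out = gated_normalized_linear_attention(q, k, v, [2.0] * T)
# Each output is a convex combination of values, so it is bounded by max |v|
# (up to floating-point rounding) regardless of T.
print(max(abs(x) for row in out for x in row))
```

A very negative gate logit makes g_t close to 0, so each output collapses to the current value v_t; a very positive logit makes g_t close to 1 and recovers ungated linear attention.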

Paper Structure

This paper contains 20 sections, 8 equations, 8 figures, 10 tables.

Figures (8)

  • Figure 1: The overall structure of LAWCAT. We use Causal Conv1D with a kernel size of $r+1$ for visualization, and use $\mathbf{x}_t$ to represent the $t$-th token of input $\mathbf{x}$.
  • Figure 2: The comparison between the pre-trained transformer model and our converted linear attention model.
  • Figure 3: Comparison of accuracy on QA2&3 from BABILong benchmark between Llama 3.2 1B (top), LoLCATs (middle), and LAWCAT (bottom).
  • Figure 4: Comparison of prefill-stage latency among 5 different models. Note that LoLCATs runs out of memory for sequence lengths exceeding 8K tokens.
  • Figure 5: Visualization of attention scores from layer 15, head 5 across three models: Transformer (top), LoLCATs (middle), and LAWCAT (bottom). Each row presents three attention maps: needle-to-needle (left), answer-to-needle (center), and answer-to-answer (right).
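Figure 1 specifies the causal Conv1D only by its kernel size $r+1$. The sketch below assumes a depthwise (per-channel) convolution along the token axis with $r$ zeros of left padding, a common way to build such a layer; the function name and weight layout are hypothetical. It shows the property that matters in a causal model: output $y_t$ mixes only $x_{t-r}, \dots, x_t$ and never sees future tokens.

```python
def causal_depthwise_conv1d(x, weights):
    """Causal depthwise 1D convolution over the token axis (illustrative sketch).

    x:       list of T tokens, each a list of C channel values.
    weights: list of C kernels, each of length r + 1; weights[c][r] multiplies
             the current token x_t and weights[c][0] multiplies x_{t-r}.
    Positions before the first token count as zeros (left padding), so y_t
    depends only on x_{t-r}, ..., x_t.
    """
    T, C = len(x), len(x[0])
    r = len(weights[0]) - 1
    y = []
    for t in range(T):
        row = []
        for c in range(C):
            acc = 0.0
            for i in range(r + 1):
                s = t - r + i  # source token index for kernel tap i
                if s >= 0:
                    acc += weights[c][i] * x[s][c]
            row.append(acc)
        y.append(row)
    return y


# Usage: kernel size r + 1 = 3 on a 4-token, 2-channel input.
x = [[1.0, 10.0], [2.0, 20.0], [3.0, 30.0], [4.0, 40.0]]
w = [[0.25, 0.25, 0.5],  # channel 0: weights x_{t-2}, x_{t-1}, x_t by 0.25, 0.25, 0.5
     [0.0, 0.0, 1.0]]    # channel 1: identity kernel, passes x_t through
y = causal_depthwise_conv1d(x, w)
print(y)  # [[0.5, 10.0], [1.25, 20.0], [2.25, 30.0], [3.25, 40.0]]
```

In PyTorch the same layer is typically built with `torch.nn.Conv1d(C, C, kernel_size=r + 1, groups=C, padding=r)` and keeping only the first T output positions, which discards the right-side padding and preserves causality.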