
Interleaved Head Attention

Sai Surya Duvvuri, Chanakya Ekbote, Rachit Bansal, Rishabh Tiwari, Devvrit Khatri, David Brandfonbrener, Paul Liang, Inderjit Dhillon, Manzil Zaheer

TL;DR

IHA enables cross-head mixing by constructing $P$ pseudo-heads per head (typically $P=H$), where each pseudo query/key/value is a learned linear combination of all $H$ original queries, keys and values respectively.
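The pseudo-head construction in the TL;DR can be sketched in a few lines of NumPy. This is a minimal sketch under stated assumptions, not the authors' code. Per-head vectors for $N$ tokens are assumed to be stored as an `(N, H, d)` array. The mixing weights are a hypothetical tensor `alpha` of shape `(H, P, H)`, where `alpha[h, p, g]` weights original head `g` in pseudo-head `p` of head `h`. The layout and the name `make_pseudo_heads` are illustrative. In this reading the same function is applied separately to queries, keys and values, each with its own weights.

```python
import numpy as np

def make_pseudo_heads(x: np.ndarray, alpha: np.ndarray) -> np.ndarray:
    """Mix the head axis of x into P pseudo-heads per head (illustrative sketch).

    x:     (N, H, d) per-head queries, keys, or values for N tokens.
    alpha: (H, P, H) hypothetical learned weights; alpha[h, p, g] is the weight
           of original head g in pseudo-head p of head h.
    Returns an (N, H, P, d) array of pseudo-heads.
    """
    # out[n, h, p, :] = sum_g alpha[h, p, g] * x[n, g, :]
    return np.einsum("hpg,ngd->nhpd", alpha, x)

# Usage: 5 tokens, 4 heads, P = H = 4 pseudo-heads per head, head dim 8.
rng = np.random.default_rng(0)
N, H, P, d = 5, 4, 4, 8
q = rng.standard_normal((N, H, d))
alpha_q = rng.standard_normal((H, P, H))   # H * P * H = H^2 P mixing weights
q_pseudo = make_pseudo_heads(q, alpha_q)
print(q_pseudo.shape)  # (5, 4, 4, 8)
```

In this sketch, setting `alpha[h, p, g] = 1` when `g == h` (and 0 otherwise) copies each head into all of its pseudo-heads, so plain MHA is a special case. Each mixing tensor holds $H \cdot P \cdot H = H^2P$ weights, in line with the $\mathcal{O}(H^2P)$ overhead quoted in the abstract.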

Abstract

Multi-Head Attention (MHA) is the core computational primitive underlying modern Large Language Models (LLMs). However, MHA suffers from a fundamental linear scaling limitation: $H$ attention heads produce exactly $H$ independent attention matrices, with no communication between heads during attention computation. This becomes problematic for multi-step reasoning, where correct answers depend on aggregating evidence from multiple parts of the context and composing latent token-to-token relations over a chain of intermediate inferences. To address this, we propose Interleaved Head Attention (IHA), which enables cross-head mixing by constructing $P$ pseudo-heads per head (typically $P=H$), where each pseudo query/key/value is a learned linear combination of all $H$ original queries, keys and values, respectively. Interactions between pseudo-query and pseudo-key heads induce up to $P^2$ attention patterns per head with a modest parameter overhead of $\mathcal{O}(H^2P)$. We provide theory showing improved parameter efficiency on the synthetic Polynomial task (IHA uses $\Theta(\sqrt{k}\,n^2)$ parameters vs. $\Theta(kn^2)$ for MHA) and on the synthetic order-sensitive CPM-3 task (IHA uses $\lceil\sqrt{N_{\max}}\rceil$ heads vs. $N_{\max}$ for MHA). On real-world benchmarks, IHA improves Multi-Key retrieval on RULER by 10-20% at 4k-16k context lengths and, after fine-tuning for reasoning on OpenThoughts, improves GSM8K by 5.8% and MATH-500 by 2.8% (majority vote) over full attention.
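The headline counts in the abstract reduce to simple arithmetic, sketched below as a back-of-envelope helper. It assumes one `(H, P, H)` mixing tensor each for queries, keys and values, since Figure 1 notes that separate transforms are used for each. It also reads the CPM-3 bound literally as a ceiling square root. All function names are hypothetical. A counting intuition, not the paper's proof: with $P = H$, one head can express up to $H^2$ pseudo-query/pseudo-key pairings, and $\lceil\sqrt{N_{\max}}\rceil$ is the smallest $H$ with $H^2 \ge N_{\max}$.

```python
import math

def ceil_sqrt(n: int) -> int:
    """Smallest integer r with r * r >= n, using exact integer arithmetic."""
    if n < 1:
        raise ValueError("n must be >= 1")
    return math.isqrt(n - 1) + 1

def iha_counts(H: int, P: int) -> dict:
    """Rough per-layer counts under this sketch's assumptions."""
    return {
        "patterns_per_head": P * P,      # each (pseudo-query p, pseudo-key p') pair
        "mixing_params": 3 * H * H * P,  # one (H, P, H) tensor each for Q, K, V
    }

def cpm3_heads(n_max: int) -> tuple:
    """Head counts quoted in the abstract for CPM-3: (IHA, MHA)."""
    return ceil_sqrt(n_max), n_max

print(iha_counts(8, 8))  # {'patterns_per_head': 64, 'mixing_params': 1536}
print(cpm3_heads(50))    # (8, 50)
```

Using `math.isqrt` avoids floating-point rounding in the ceiling square root for large `n`.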

Paper Structure

This paper contains 62 sections, 124 equations, 11 figures, and 2 tables.

Figures (11)

  • Figure 1: Overview of Interleaved Head Attention (IHA). First, the model generates $P$ pseudo-tokens for each of the $H$ original heads via a learned linear transformation ($\times \boldsymbol{\alpha}_Q$) applied along the heads axis (green). These pseudo-tokens are then interleaved to form an expanded sequence of length $P \cdot N$. Finally, standard causal self-attention is computed on this expanded sequence, using a sliding window (e.g., $N/2P$) to control computational cost while enabling cross-head interaction. Separate linear transforms are used for the queries, keys, and values.
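The interleave-then-attend step described in the Figure 1 caption can be sketched in NumPy. This is a minimal sketch with several assumptions:

- Pseudo-queries, keys and values arrive as `(N, H, P, d)` arrays.
- The expanded sequence places token $n$'s $P$ pseudo-tokens at positions $nP$ through $nP + P - 1$.
- `window` is an integer counted in expanded positions.
- Attention is plain scaled dot-product softmax, run per head.

The function names are hypothetical. The caption does not say how the $P \cdot N$ outputs are merged back into $N$ tokens and $H$ heads, so the sketch returns the pseudo-head outputs unmerged.

```python
import numpy as np

def interleave(x: np.ndarray) -> np.ndarray:
    """(N, H, P, d) -> (H, N*P, d): per head, token n's P pseudo-tokens
    occupy expanded positions n*P .. n*P + P - 1."""
    N, H, P, d = x.shape
    return x.transpose(1, 0, 2, 3).reshape(H, N * P, d)

def sliding_window_causal_mask(L: int, window: int) -> np.ndarray:
    """True where query i may attend to key j: j <= i and i - j < window."""
    i = np.arange(L)[:, None]
    j = np.arange(L)[None, :]
    return (j <= i) & (i - j < window)

def interleaved_attention(q_p, k_p, v_p, window: int) -> np.ndarray:
    """Causal sliding-window softmax attention over the interleaved sequence.

    q_p, k_p, v_p: (N, H, P, d) pseudo-queries, keys, values.
    Returns (N, H, P, d) pseudo-head outputs (merging is not modeled here).
    """
    N, H, P, d = q_p.shape
    q, k, v = interleave(q_p), interleave(k_p), interleave(v_p)  # (H, L, d)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(d)               # (H, L, L)
    mask = sliding_window_causal_mask(N * P, window)
    scores = np.where(mask, scores, -np.inf)
    scores = scores - scores.max(axis=-1, keepdims=True)         # stable softmax
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)
    out = weights @ v                                            # (H, L, d)
    return out.reshape(H, N, P, d).transpose(1, 0, 2, 3)

rng = np.random.default_rng(1)
N, H, P, d = 6, 2, 2, 4
q_p, k_p, v_p = (rng.standard_normal((N, H, P, d)) for _ in range(3))
print(interleaved_attention(q_p, k_p, v_p, window=4).shape)  # (6, 2, 2, 4)
```

With `window=1`, each expanded position attends only to itself, so the output equals the values. A larger window lets pseudo-tokens of the same token, and of recent tokens, exchange information. The full $(P N) \times (P N)$ score matrix is materialized here only for clarity. A dedicated sliding-window kernel would typically avoid it.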
  • Figure 2: RULER long-context results after 64k fine-tuning. (a) Multi-Key Retrieval accuracy at 4k/8k/16k context lengths (orange: IHA improvement over Sliding Window). (b) Overall RULER Exact Match (EM), showing strong improvements with IHA.
  • Figure : (a) Binary composition: final test accuracy across learning rates.
  • Figure : (a) Binary composition, $\eta=10^{-3}$, $L=1$, $H=8$.
  • Figure : (a) Ternary composition, $\eta=10^{-3}$, $L=1$, $H=8$.
  • ...and 6 more figures

Theorems & Definitions (3)

  • proof
  • proof
  • proof