Improving Transformers with Dynamically Composable Multi-Head Attention
Da Xiao, Qingye Meng, Shengping Li, Xingyuan Yuan
TL;DR
This work addresses shortcomings of standard multi-head attention (MHA), in which heads operate independently, leading to a low-rank bottleneck in the attention score matrices and to redundant heads. It introduces Dynamically Composable Multi-Head Attention (DCMHA), which dynamically composes attention heads through a Compose function applied to the attention score and weight matrices. Compose is an input-dependent transformation based on tensor decomposition, and DCMHA acts as a drop-in replacement for MHA: the resulting DCFormer yields significant gains in language modeling and downstream tasks, matching the performance of models trained with ~1.7×–2× more compute. The approach connects theoretically to projection-based head composition and stays efficient in practice through low-rank plus diagonal decompositions and grouped tensor-parallel training. Empirically, DCFormer scales well, outperforms baseline Transformers across model sizes, and transfers to vision transformers with improved image classification performance, while offering interpretable insights into head diversity and the dynamic interaction of QK and OV circuits.
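To see why a low-rank plus diagonal decomposition keeps dynamic composition cheap, the arithmetic below compares it with a full input-dependent H×H composition map. It is a back-of-the-envelope sketch under a hypothetical parameterization: every dynamic weight is assumed to come from a linear projection of a d_model-dimensional query vector, and the function name and example sizes are illustrative rather than taken from the paper.

```python
def composition_costs(d_model: int, n_heads: int, rank: int) -> dict:
    """Compare a full dynamic H x H composition map with a low-rank-plus-diagonal one.

    Hypothetical parameterization: each dynamic weight is produced by a linear
    projection of a d_model-dimensional query vector.
    """
    full_h = n_heads * n_heads                 # entries in a full H x H map
    factored_h = 2 * n_heads * rank + n_heads  # two H x R factors plus an H-diagonal
    return {
        # projection weights needed to generate the map from one query vector
        "full_params": d_model * full_h,
        "factored_params": d_model * factored_h,
        # multiply-adds to apply the map to one length-H vector of head scores
        "full_macs": full_h,
        "factored_macs": factored_h,
    }


print(composition_costs(4096, 32, 2))
```

With the illustrative sizes d_model = 4096, H = 32 and R = 2, the full map needs 4,194,304 projection weights and 1,024 multiply-adds per (query, key) pair, while the factored form needs 655,360 weights and 160 multiply-adds, roughly 6.4× less in both cases. The saving grows with H, since the full map scales as H² and the factored one only linearly in H.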
Abstract
Multi-Head Attention (MHA) is a key component of the Transformer. In MHA, attention heads work independently, causing problems such as the low-rank bottleneck of attention score matrices and head redundancy. We propose Dynamically Composable Multi-Head Attention (DCMHA), a parameter- and computation-efficient attention architecture that tackles the shortcomings of MHA and increases the expressive power of the model by dynamically composing attention heads. At the core of DCMHA is a Compose function that transforms the attention score and weight matrices in an input-dependent way. DCMHA can be used as a drop-in replacement for MHA in any Transformer architecture to obtain the corresponding DCFormer. DCFormer significantly outperforms the Transformer across different architectures and model scales in language modeling, matching the performance of models trained with ~1.7x to 2.0x the compute. For example, DCPythia-6.9B outperforms the open-source Pythia-12B on both pretraining perplexity and downstream task evaluation. The code and models are available at https://github.com/Caiyun-AI/DCFormer.
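The abstract describes Compose only at a high level. As a rough illustration, here is a minimal pure-Python sketch that assumes the composition for each (query i, key j) pair mixes the length-H vector of head scores through a static H×H matrix plus a query-dependent low-rank term and a query-dependent diagonal gate. The names compose, linear, w_static, p1, p2, p_gate and rank are hypothetical. Key-dependent terms, nonlinearities, normalization, batching and the softmax are omitted, so this is not the DCFormer implementation; the repository linked above is authoritative. Per the abstract, DCMHA applies this kind of transformation to both the attention scores and the attention weights; the sketch shows a single application.

```python
from typing import List

Vec = List[float]
Mat = List[List[float]]


def linear(x: Vec, w: Mat) -> Vec:
    """Return x @ w for a vector x (length D) and a matrix w (D x K)."""
    return [sum(x[d] * w[d][k] for d in range(len(x))) for k in range(len(w[0]))]


def compose(scores: List[Mat], queries: List[Vec], w_static: Mat,
            p1: Mat, p2: Mat, p_gate: Mat, rank: int) -> List[Mat]:
    """Hypothetical sketch: mix attention scores across heads for every (i, j).

    scores[h][i][j] : score of head h for query i and key j (H x T x S)
    queries[i]      : D-dim query-side input vector for position i
    w_static        : H x H static composition matrix
    p1, p2          : D x (H*rank) projections giving query-dependent low-rank factors
    p_gate          : D x H projection giving a query-dependent diagonal gate
    """
    n_heads = len(scores)
    n_q, n_k = len(scores[0]), len(scores[0][0])
    out = [[[0.0] * n_k for _ in range(n_q)] for _ in range(n_heads)]
    for i in range(n_q):
        # Dynamic weights depend only on query position i in this simplified sketch
        f1, f2 = linear(queries[i], p1), linear(queries[i], p2)
        w1 = [f1[h * rank:(h + 1) * rank] for h in range(n_heads)]  # H x R
        w2 = [f2[h * rank:(h + 1) * rank] for h in range(n_heads)]  # H x R
        gate = linear(queries[i], p_gate)                            # length H
        for j in range(n_k):
            a = [scores[h][i][j] for h in range(n_heads)]  # head vector at (i, j)
            # Project the head vector down to rank R: a @ w1
            low = [sum(a[h] * w1[h][r] for h in range(n_heads)) for r in range(rank)]
            for g in range(n_heads):
                static = sum(a[h] * w_static[h][g] for h in range(n_heads))
                low_rank = sum(low[r] * w2[g][r] for r in range(rank))  # (a @ w1) @ w2^T
                out[g][i][j] = static + low_rank + a[g] * gate[g]       # + diagonal gate
    return out
```

With w_static set to the identity and all projections set to zero, compose returns the scores unchanged, so independent-head MHA is the special case with no cross-head mixing. Replacing the identity with a permutation simply swaps heads, while nonzero projections let the mixing vary with each query.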
