Mixture of Attention Heads: Selecting Attention Heads Per Token
Xiaofeng Zhang, Yikang Shen, Zeyu Huang, Jie Zhou, Wenge Rong, Zhang Xiong
TL;DR
The paper addresses the tension between efficiency and scale in Transformer models by introducing Mixture of Attention Heads (MoA), which combines mixture-of-experts (MoE) routing with multi-head attention. A routing network selects a sparse subset of attention experts for each token. Each expert has its own query and output projections, while the key and value projections are shared across experts to reduce cost. Auxiliary losses balance expert usage and stabilize routing. On WMT14 machine translation and WikiText-103 masked language modeling (MLM), MoA performs strongly with favorable compute and parameter efficiency, and analyses show observable head specialization and balanced expert loads. The work demonstrates a scalable, interpretable way to expand attention capacity without a proportional increase in computation, charting a path toward larger Transformer variants that remain cheap to run.
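The per-token routing with shared keys and values can be sketched for a single query token. This is a minimal, dependency-free sketch, not the authors' implementation. It assumes that the router takes a softmax over one logit per expert, keeps the top k, and renormalizes the kept weights to sum to 1. It also assumes per-expert query and output matrices and key/value matrices shared by all experts. The class name `MoASketch`, its random initialization, and all dimensions are hypothetical.

```python
import math
import random


def matvec(v, m):
    """Multiply a row vector v (length n) by a matrix m (n x p), giving length p."""
    return [sum(v[i] * m[i][j] for i in range(len(v))) for j in range(len(m[0]))]


def softmax(xs):
    mx = max(xs)
    exps = [math.exp(x - mx) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def rand_matrix(rows, cols, rng):
    # Hypothetical small Gaussian initialization, for illustration only.
    return [[rng.gauss(0.0, 0.1) for _ in range(cols)] for _ in range(rows)]


class MoASketch:
    """Hypothetical single-layer MoA sketch operating on one query token."""

    def __init__(self, d_model, d_head, num_experts, k, seed=0):
        rng = random.Random(seed)
        self.k = k
        self.d_head = d_head
        # Per-expert query (d_model x d_head) and output (d_head x d_model) projections.
        self.w_q = [rand_matrix(d_model, d_head, rng) for _ in range(num_experts)]
        self.w_o = [rand_matrix(d_head, d_model, rng) for _ in range(num_experts)]
        # Key and value projections shared by every expert.
        self.w_k = rand_matrix(d_model, d_head, rng)
        self.w_v = rand_matrix(d_model, d_head, rng)
        # Linear router producing one logit per expert.
        self.w_g = rand_matrix(d_model, num_experts, rng)

    def route(self, q):
        """Return [(expert_index, weight), ...] for the top-k experts of token q.

        Assumption: top-k over softmax probabilities, renormalized over the kept experts.
        """
        probs = softmax(matvec(q, self.w_g))
        top = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)[: self.k]
        kept = sum(probs[i] for i in top)
        return [(i, probs[i] / kept) for i in top]

    def forward_token(self, q, xs):
        """Weighted sum of the selected experts' attention outputs for token q over xs."""
        # Keys and values do not depend on the routing, so all experts reuse them.
        keys = [matvec(x, self.w_k) for x in xs]
        values = [matvec(x, self.w_v) for x in xs]
        out = [0.0] * len(q)
        for i, weight in self.route(q):
            qi = matvec(q, self.w_q[i])  # expert-specific query projection
            scale = math.sqrt(self.d_head)
            scores = [sum(a * b for a, b in zip(qi, key)) / scale for key in keys]
            attn = softmax(scores)
            ctx = [sum(attn[t] * values[t][j] for t in range(len(xs))) for j in range(self.d_head)]
            head_out = matvec(ctx, self.w_o[i])  # expert-specific output projection
            out = [o + weight * h for o, h in zip(out, head_out)]
        return out


if __name__ == "__main__":
    layer = MoASketch(d_model=8, d_head=4, num_experts=6, k=2)
    rng = random.Random(1)
    seq = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(5)]
    print("selected experts:", layer.route(seq[0]))
    print("output:", layer.forward_token(seq[0], seq))
```

In this sketch, the keys and values are computed once and reused by whichever experts the router picks. Only the k selected experts pay for their own query and output projections. In a full implementation, the key/value projection would be computed once per sequence rather than once per query token.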
Abstract
Mixture-of-Experts (MoE) networks have been proposed as an efficient way to scale up model capacity and implement conditional computation. However, the study of MoE components has mostly focused on the feedforward layer of the Transformer architecture. This paper proposes the Mixture of Attention Heads (MoA), a new architecture that combines multi-head attention with the MoE mechanism. MoA includes a set of attention heads, each with its own set of parameters. Given an input, a router dynamically selects a subset of $k$ attention heads per token. This conditional computation scheme allows MoA to achieve stronger performance than the standard multi-head attention layer. Furthermore, the sparsely gated MoA can easily scale up the number of attention heads and the number of parameters while preserving computational efficiency. In addition to the performance improvements, MoA also automatically differentiates heads' utilities, providing a new perspective on the model's interpretability. We conducted experiments on several important tasks, including Machine Translation and Masked Language Modeling. The experiments show promising results on several tasks against strong baselines, including large and very deep models.
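A back-of-the-envelope count shows why the number of heads can grow without a matching growth in computation. This sketch rests on several assumptions. Each expert has its own d_model x d_head query and output matrices. The key and value matrices are shared, and the router is linear. The count ignores biases and the attention-score computation, so it covers only the projections. The helper `moa_costs` is hypothetical, and the sizes (d_model = 512, d_head = 64, k = 4) are illustrative rather than taken from the paper's configurations.

```python
def moa_costs(d_model, d_head, num_experts, k):
    """Rough parameter count and per-token projection multiply-accumulates (MACs)
    for a hypothetical MoA layer. Biases and attention-score cost are ignored."""
    proj = d_model * d_head  # size of one d_model x d_head projection matrix
    params = (
        num_experts * 2 * proj   # per-expert query and output projections
        + 2 * proj               # key and value projections shared by all experts
        + d_model * num_experts  # linear router: one logit per expert
    )
    macs_per_token = (
        k * 2 * proj             # only the k selected experts run their query/output projections
        + 2 * proj               # this token's shared key/value projection
        + d_model * num_experts  # router logits
    )
    return params, macs_per_token


for e in (8, 32):
    p, c = moa_costs(d_model=512, d_head=64, num_experts=e, k=4)
    print(f"E={e:2d}  params={p:,}  MACs/token={c:,}")
```

With these illustrative sizes, going from 8 to 32 experts raises the parameter count from 593,920 to 2,179,072 (about 3.7x). Over the same change, the per-token projection cost rises only from 331,776 to 344,064 MACs (under 4%), because k stays fixed and only the router grows with the number of experts.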
