SwitchHead: Accelerating Transformers with Mixture-of-Experts Attention
Róbert Csordás, Piotr Piękos, Kazuki Irie, Jürgen Schmidhuber
TL;DR
SwitchHead applies a Mixture-of-Experts (MoE) approach to the Transformer attention layer. Each head shares a single key and query projection but selects among multiple expert projections for values and outputs, so the model needs far fewer active heads and computes far fewer attention matrices. Experts are chosen with a non-competitive sigmoid selection, and SwitchHead can be combined with MoE-based MLP layers to form SwitchAll. Empirical results across multiple datasets and model sizes show that SwitchHead achieves perplexity comparable to parameter-matched dense Transformers with much lower compute and memory usage, and that SwitchAll often surpasses baselines under the same parameter budget. The work also reports stable training without extra regularizers and offers insights into attention map redundancy and interpretable expert selections, with practical implications for resource-constrained deployment and scalable language modeling.
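The gated expert projection described above can be illustrated with a minimal sketch in plain Python. It covers only the selection step for one token and one head: a router scores the experts, a sigmoid turns each score into an independent gate, the top-k experts are kept, and their projections are summed with the gates as weights. The attention computation itself is omitted. All names (`select_experts`, `moe_project`, `router`, `experts`) and the toy weights are assumptions made for illustration, not the authors' implementation.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def matvec(W, x):
    """Multiply matrix W (a list of rows) by vector x."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in W]


def select_experts(x, router, k):
    """Return [(expert_index, gate)] for the k experts with the largest gate.

    Each gate is an independent sigmoid of a router score, so the gates are
    not normalized against each other (no softmax over experts).
    """
    gates = [sigmoid(s) for s in matvec(router, x)]
    top = sorted(range(len(gates)), key=lambda e: gates[e], reverse=True)[:k]
    return [(e, gates[e]) for e in top]


def moe_project(x, experts, router, k):
    """Gate-weighted sum of the selected experts' projections of x."""
    out = [0.0] * len(experts[0])  # one entry per output dimension
    for e, g in select_experts(x, router, k):
        proj = matvec(experts[e], x)
        out = [o + g * p for o, p in zip(out, proj)]
    return out


# Toy setup (hypothetical values): 3 experts, 2-dimensional input and output.
x = [1.0, 0.0]
router = [[2.0, 0.0], [-2.0, 0.0], [0.0, 0.0]]  # one row of scores per expert
experts = [
    [[1.0, 0.0], [0.0, 1.0]],  # expert 0: identity
    [[0.0, 0.0], [0.0, 0.0]],  # expert 1: zero map
    [[2.0, 0.0], [0.0, 2.0]],  # expert 2: scale by 2
]

picked = select_experts(x, router, k=2)
out = moe_project(x, experts, router, k=2)
print(picked)
print(out)
```

In this toy run the two selected gates sum to more than 1, which is the sense in which the selection is non-competitive: raising one expert's gate does not lower another's. In a full layer, the same kind of gated projection would be applied to both the value and the output projections, while keys and queries use a single shared projection per head.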
Abstract
Despite many recent works on Mixture of Experts (MoEs) for resource-efficient Transformer language models, existing methods mostly focus on MoEs for feedforward layers. Previous attempts at extending MoE to the self-attention layer fail to match the performance of the parameter-matched baseline. Our novel SwitchHead is an effective MoE method for the attention layer that successfully reduces both the compute and memory requirements, achieving wall-clock speedup, while matching the language modeling performance of the baseline Transformer. Our novel MoE mechanism allows SwitchHead to compute up to 8 times fewer attention matrices than the standard Transformer. SwitchHead can also be combined with MoE feedforward layers, resulting in fully-MoE "SwitchAll" Transformers. For our 262M parameter model trained on C4, SwitchHead matches the perplexity of standard models with only 44% of the compute and 27% of the memory usage. Zero-shot experiments on downstream tasks confirm the performance of SwitchHead, e.g., achieving more than 3.5% absolute improvement on BLiMP compared to the baseline under an equal compute budget.
