Leaner Transformers: More Heads, Less Depth
Hemanth Saratchandran, Damien Teney, Simon Lucey
TL;DR
The paper tackles overparameterization in transformers by arguing that multi-head attention inherently improves the conditioning of the attention block, which allows deep models to be replaced with shallower ones that use more heads. It introduces a theoretical framework showing that, with enough heads, the attention matrix becomes well-conditioned (it has a low condition number), which facilitates optimization. Empirically, the authors show across vision and language tasks that increasing the number of heads while reducing depth yields substantial parameter and memory savings (up to 30-50%) with comparable or improved accuracy on ImageNet-1k, GLUE, TinyStories, and the Long-Range Arena (LRA). The findings suggest a practical design principle for efficient transformers. They also raise open questions about the fundamental limits of lean architectures and how well they scale to large models.
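To make "well-conditioned" concrete: the condition number of a matrix M is the ratio of its largest to smallest singular value, κ(M) = σ_max / σ_min. Values near 1 mean the map does not strongly amplify some directions relative to others. The sketch below is a toy probe only. It assumes NumPy, random Gaussian projections, and a multi-head output formed by concatenating h heads of width d/h with no output projection. The function names are hypothetical, and the setup is not the authors' construction or theorem. It shows how one might measure conditioning as the head count varies, without asserting what the numbers will come out to.

```python
import numpy as np


def softmax(z, axis=-1):
    # Numerically stable softmax along the given axis.
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def condition_number(m):
    # kappa(M) = sigma_max / sigma_min; NumPy returns singular values in descending order.
    s = np.linalg.svd(m, compute_uv=False)
    return s[0] / s[-1]


def multi_head_output(x, num_heads, rng):
    # Hypothetical toy attention block: random Gaussian Q/K/V projections,
    # per-head width d // num_heads, heads concatenated along the feature axis.
    n, d = x.shape
    if d % num_heads != 0:
        raise ValueError("d must be divisible by num_heads")
    dh = d // num_heads
    heads = []
    for _ in range(num_heads):
        wq, wk, wv = (rng.standard_normal((d, dh)) / np.sqrt(d) for _ in range(3))
        scores = (x @ wq) @ (x @ wk).T / np.sqrt(dh)  # (n, n) scaled dot-product scores
        heads.append(softmax(scores) @ (x @ wv))     # (n, dh) output of one head
    return np.concatenate(heads, axis=1)             # (n, d) concatenated heads


if __name__ == "__main__":
    x = np.random.default_rng(0).standard_normal((128, 64))  # n = 128 tokens, d = 64
    for h in (1, 2, 4, 8, 16):
        out = multi_head_output(x, h, np.random.default_rng(1))
        print(f"heads={h:2d}  kappa={condition_number(out):.3e}")
```

Because the weights are random and untrained, any trend this toy shows is at most suggestive. The paper's analysis may concern a different matrix or regime.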
Abstract
Transformers have reshaped machine learning by utilizing attention mechanisms to capture complex patterns in large datasets, leading to significant improvements in performance. This success has contributed to the belief that "bigger means better", leading to ever-increasing model sizes. This paper challenges this ideology by showing that many existing transformers might be unnecessarily oversized. We discover a theoretical principle that redefines the role of multi-head attention: an important benefit of multiple heads is that they improve the conditioning of the attention block. We exploit this theoretical insight and redesign popular architectures with an increased number of heads. The improvement in conditioning is so significant in practice that model depth can be decreased, reducing the parameter count by up to 30-50% while maintaining accuracy. We obtain consistent benefits across a variety of transformer-based architectures of various scales, on tasks in computer vision (ImageNet-1k) as well as language and sequence modeling (GLUE benchmark, TinyStories, and the Long-Range Arena benchmark).
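Why does trading depth for heads save parameters? Consider standard multi-head attention, where each head has width d_model / num_heads. Splitting into more heads only reshapes the Q, K, V and output projections, so the head count adds no weights, while each removed layer drops its full share of weights. The sketch below counts encoder-layer parameters under that assumption. It uses ViT-B-like dimensions (d_model = 768, d_ff = 3072, 12 layers) and hypothetical reduced depths of 8 and 6. These are illustrative choices, not the configurations used in the paper, and embeddings and task heads are excluded.

```python
def encoder_layer_params(d_model: int, d_ff: int) -> int:
    # Q, K, V and output projections: four d_model x d_model weights plus biases.
    # num_heads does not appear: with per-head width d_model // num_heads,
    # more heads partition the same matrices instead of adding new weights.
    attention = 4 * d_model * d_model + 4 * d_model
    # Two-layer MLP d_model -> d_ff -> d_model, with biases.
    mlp = 2 * d_model * d_ff + d_ff + d_model
    # Two LayerNorms, each with a scale and a shift vector of size d_model.
    norms = 2 * 2 * d_model
    return attention + mlp + norms


def encoder_params(d_model: int, d_ff: int, depth: int) -> int:
    # Identical layers stacked `depth` times (embeddings and task head excluded).
    return depth * encoder_layer_params(d_model, d_ff)


if __name__ == "__main__":
    base = encoder_params(768, 3072, 12)
    for depth in (12, 8, 6):
        p = encoder_params(768, 3072, depth)
        print(f"depth={depth:2d}  params={p:,}  saving={1 - p / base:.0%}")
```

Each layer here has 7,087,872 parameters, so the 12-layer stack has 85,054,464. Going to 8 layers removes one third of the encoder parameters, and going to 6 removes half. Savings for the whole model are somewhat smaller once the unchanged embedding and output layers are counted.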
