Weighted Grouped Query Attention in Transformers
Sai Sena Chinnakonduru, Astarag Mohapatra
TL;DR
WGQA targets memory-bound inference in large autoregressive transformers. It introduces learnable weights for each key and value head in the decoder attention blocks, so that heads are combined by a weighted aggregation during finetuning rather than a plain mean. The method extends GQA with a small number of extra parameters and several weighting variants. It yields an average improvement of about 0.53% over GQA across several datasets, with performance approaching standard MHA and no additional overhead at inference. Scaling experiments on T5-small and T5-base indicate that larger models see larger gains, which supports the approach as a memory-efficient path for larger language models. A statistical analysis confirms that the learned weights differ meaningfully from mean-pooled GQA, and initializing the weights to the head means consistently helps. Overall, WGQA offers a practical, parameter-efficient training-time modification that improves memory-efficient attention in decoder blocks.
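The weighted aggregation of key/value heads can be sketched in a few lines. This is a minimal, framework-free illustration under assumptions not stated in the text: each head's key (or value) projection is represented as a small matrix of floats, heads are assigned to groups contiguously, and there is one scalar weight per head. The paper describes several weighting variants; only this simplest one is shown. The helper names `mean_init_weights` and `weighted_group_heads` are hypothetical. Setting every weight to 1/group_size reproduces mean-pooled GQA. In real finetuning the weights would be learned by backpropagation, which this sketch does not do.

```python
from typing import List

Matrix = List[List[float]]  # one head's projection matrix (rows x cols)


def mean_init_weights(num_heads: int, num_groups: int) -> List[float]:
    """Hypothetical helper: set every per-head weight to 1/group_size (mean pooling)."""
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    group_size = num_heads // num_groups
    return [1.0 / group_size] * num_heads


def weighted_group_heads(heads: List[Matrix], weights: List[float],
                         num_groups: int) -> List[Matrix]:
    """Hypothetical helper: collapse per-head projections into num_groups shared ones.

    Assumes contiguous grouping: group g holds heads
    g*group_size .. (g+1)*group_size - 1, and each grouped matrix is the
    weighted sum of the matrices in its group.
    """
    num_heads = len(heads)
    if num_heads % num_groups != 0:
        raise ValueError("num_heads must be divisible by num_groups")
    if len(weights) != num_heads:
        raise ValueError("need exactly one weight per head")
    group_size = num_heads // num_groups
    rows, cols = len(heads[0]), len(heads[0][0])
    grouped: List[Matrix] = []
    for g in range(num_groups):
        acc = [[0.0] * cols for _ in range(rows)]
        for h in range(g * group_size, (g + 1) * group_size):
            w = weights[h]
            for r in range(rows):
                for c in range(cols):
                    acc[r][c] += w * heads[h][r][c]
        grouped.append(acc)
    return grouped


# Four toy 1x1 "heads" collapsed into two groups.
heads = [[[1.0]], [[3.0]], [[5.0]], [[7.0]]]
print(weighted_group_heads(heads, mean_init_weights(4, 2), 2))
print(weighted_group_heads(heads, [0.75, 0.25, 0.5, 0.5], 2))
```

With mean initialization the two groups are `[[2.0]]` and `[[6.0]]`, exactly what mean-pooled GQA would give; shifting weight toward the first head changes the first group to `[[1.5]]`. Because the combination is linear, learned weights of this kind can be folded into a single grouped projection once finetuning ends, which is one way to see why such a scheme need not add inference cost beyond standard GQA.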
Abstract
The attention mechanism is the foundational building block of transformer language models. Recent approaches show that scaling these models can achieve human-level performance. However, with increasing demands for scaling and constraints on hardware memory, the inference costs of these models remain high. To reduce inference time, Multi-Query Attention (MQA) and Grouped-Query Attention (GQA) were proposed in (Shazeer, 2019) and (Ainslie et al., 2023), respectively. In this paper, we propose a variation of Grouped-Query Attention, termed Weighted Grouped-Query Attention (WGQA). We introduce new learnable parameters for each key and value head in the T5 decoder attention blocks, enabling the model to take a weighted average during finetuning. Our model achieves an average improvement of 0.53% over GQA, and its performance converges to that of traditional Multi-Head Attention (MHA) with no additional overhead during inference. Our evaluation shows that introducing these parameters, followed by finetuning, informs the model about the grouping mechanism during training, thereby enhancing performance. Additionally, we study how these gains scale by comparing results between the T5-small and T5-base architectures.
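Why fewer key/value heads help memory-bound decoding can be made concrete with a back-of-the-envelope KV-cache estimate. The sketch below assumes a hypothetical decoder with 12 layers, 12 query heads of dimension 64, a 2048-token context, batch size 1 and 16-bit storage. It counts only the self-attention key/value cache, and the function name `kv_cache_bytes` is illustrative rather than taken from the paper.

```python
def kv_cache_bytes(num_layers: int, num_kv_heads: int, head_dim: int,
                   seq_len: int, batch: int = 1, bytes_per_elem: int = 2) -> int:
    """Bytes needed to cache keys and values during autoregressive decoding.

    The leading factor 2 counts one tensor for keys and one for values.
    bytes_per_elem=2 assumes fp16/bf16 storage.
    """
    return 2 * num_layers * num_kv_heads * head_dim * seq_len * batch * bytes_per_elem


# Hypothetical decoder dimensions, chosen only for illustration.
LAYERS, QUERY_HEADS, HEAD_DIM, SEQ = 12, 12, 64, 2048

for name, kv_heads in [("MHA", QUERY_HEADS), ("GQA (4 groups)", 4), ("MQA", 1)]:
    mib = kv_cache_bytes(LAYERS, kv_heads, HEAD_DIM, SEQ) / 2**20
    print(f"{name:15s} {kv_heads:2d} KV heads -> {mib:6.1f} MiB")
```

With these numbers the cache is 72 MiB for MHA, 24 MiB for GQA with 4 key/value groups and 6 MiB for MQA. The cache shrinks in proportion to the number of key/value heads; that is the memory saving GQA trades against quality, and the quality gap is what WGQA aims to narrow.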
