Beyond Uniform Query Distribution: Key-Driven Grouped Query Attention

Zohaib Khan, Muhammad Khaquan, Omer Tafveez, Burhanuddin Samiwala, Agha Ali Raza

TL;DR

This work introduces enhancements to GQA, focusing on two novel approaches that deviate from the static nature of grouping: Key-Distributed GQA (KDGQA) and Dynamic Key-Distributed GQA (DGQA), which leverage information from the norms of the key heads to inform query allocation.

Abstract

The Transformer architecture has revolutionized deep learning through its Self-Attention mechanism, which effectively captures contextual information. However, the memory footprint of Self-Attention presents significant challenges for long-sequence tasks. Grouped Query Attention (GQA) addresses this issue by grouping queries and mean-pooling the corresponding key-value heads - reducing the number of overall parameters and memory requirements in a flexible manner without adversely compromising model accuracy. In this work, we introduce enhancements to GQA, focusing on two novel approaches that deviate from the static nature of grouping: Key-Distributed GQA (KDGQA) and Dynamic Key-Distributed GQA (DGQA), which leverage information from the norms of the key heads to inform query allocation. Specifically, KDGQA looks at the ratios of the norms of the key heads during each forward pass, while DGQA examines the ratios of the norms as they evolve through training. Additionally, we present Perturbed GQA (PGQA) as a case-study, which introduces variability in (static) group formation via subtracting noise from the attention maps. Our experiments with up-trained Vision Transformers, for Image Classification on datasets such as CIFAR-10, CIFAR-100, Food101, and Tiny ImageNet, demonstrate the promise of these variants in improving upon the original GQA through more informed and adaptive grouping mechanisms: specifically ViT-L experiences accuracy gains of up to 8% when utilizing DGQA in comparison to GQA and other variants. We further analyze the impact of the number of Key-Value Heads on performance, underscoring the importance of utilizing query-key affinities. Code is available on GitHub.
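
The contrast between uniform grouping (GQA) and key-norm-driven grouping (KDGQA) can be illustrated with a small sketch. It assumes the number of query heads given to each key head is proportional to that key head's norm. The abstract does not state the exact norm or rounding rule, so the Frobenius norm, the largest-remainder rounding, and the helper names `frobenius_norm`, `uniform_allocation`, and `allocate_queries` are hypothetical choices made for illustration, not the authors' implementation.

```python
import math


def frobenius_norm(matrix):
    """Frobenius norm of a matrix given as a list of rows."""
    return math.sqrt(sum(x * x for row in matrix for x in row))


def uniform_allocation(num_kv_heads, num_query_heads):
    """GQA-style static grouping: every key-value head gets the same number of queries."""
    return [num_query_heads // num_kv_heads] * num_kv_heads


def allocate_queries(key_norms, num_query_heads):
    """Split query heads across key heads in proportion to the key-head norms.

    Assumption: largest-remainder rounding, so the counts always sum to
    num_query_heads. The source does not specify the rounding rule.
    """
    total = sum(key_norms)
    shares = [n * num_query_heads / total for n in key_norms]
    counts = [math.floor(s) for s in shares]
    leftover = num_query_heads - sum(counts)
    # Hand the remaining heads to the key heads with the largest fractional parts.
    by_remainder = sorted(
        range(len(shares)), key=lambda i: shares[i] - counts[i], reverse=True
    )
    for i in by_remainder[:leftover]:
        counts[i] += 1
    return counts


# Four key heads, twelve query heads; the first key head has the largest norm.
static_groups = uniform_allocation(4, 12)
norm_groups = allocate_queries([3.0, 1.0, 1.0, 1.0], 12)
```

Here `static_groups` gives every key head three queries, while `norm_groups` assigns more queries to the high-norm key head. Note that under this rounding rule a key head with a very small norm can receive zero queries; whether the paper enforces a minimum group size is not stated in this excerpt.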

Paper Structure

This paper contains 21 sections, 4 equations, 9 figures, and 7 tables.

Figures (9)

  • Figure 1: An overview of the relevant mechanisms. MHA associates one query to one key-value head. GQA associates a single key-value head to subgroups of queries such that the group sizes remain constant and the grouping is performed in a static/uniform manner. Our approaches, KDGQA/DGQA, perform the grouping according to the norms of the key heads in a non-uniform manner - note how the darker keys (indicating higher norms) have more queries assigned to them in comparison to the lighter ones.
  • Figure 2: The standard deviation of the EMA variant exhibits fewer spikes, supporting the idea that it is more robust to transient noise.
  • Figure 3: PGQA mitigates the similarity bias of GQA; however, it wipes out the self-similarity patterns seen in the attention map of MHA.
  • Figure 4: The effect of scaling the number of Key-Value heads on the loss. Each data point represents a variant of ViT-B, with 12 heads, that was fine-tuned on CIFAR-100 for 5 epochs, without any uptraining for the sake of time. The models perform better when they have the maximal number of Key-Value heads.
  • Figure 5: DGQA is an interesting middle ground between GQA and MHA. Unlike PGQA, the heat map of DGQA does not wipe out the patterns observed in MHA.
  • ...and 4 more figures
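
The Figure 2 caption refers to an EMA variant that is more robust to transient noise. A minimal sketch of that idea, assuming an exponential moving average is applied to the per-head key norms before their ratios are taken, is given below. The smoothing factor `alpha`, the update schedule, and the helper names `ema_update` and `norm_ratios` are hypothetical and are not taken from the paper.

```python
def ema_update(previous, current, alpha=0.9):
    """Blend the previous smoothed norms with the newly observed norms.

    A larger alpha keeps more of the history, so a one-off spike in a
    single key head's norm moves the smoothed value less.
    """
    return [alpha * p + (1 - alpha) * c for p, c in zip(previous, current)]


def norm_ratios(norms):
    """Normalize the norms so they sum to 1."""
    total = sum(norms)
    return [n / total for n in norms]


# Two key heads with equal smoothed norms; the first one spikes for one step.
smoothed = [2.0, 2.0]
observed = [4.0, 2.0]

raw_ratio = norm_ratios(observed)[0]
smoothed = ema_update(smoothed, observed, alpha=0.5)
ema_ratio = norm_ratios(smoothed)[0]
```

In this toy case the spike pushes the first head's raw ratio further from 0.5 than its smoothed ratio, so a query allocation driven by the smoothed ratios would change less abruptly.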