MLKV: Multi-Layer Key-Value Heads for Memory Efficient Transformer Decoding

Zayd Muhammad Kawakibi Zuhri, Muhammad Farid Adilazuarda, Ayu Purwarianti, Alham Fikri Aji

TL;DR

Evaluations on various NLP benchmarks and inference metrics using uptrained Pythia-160M variants demonstrate that MLKV significantly reduces memory usage with minimal performance loss, shrinking the KV cache by up to a factor of 6 compared to MQA.

Abstract

Auto-regressive inference of transformers benefits greatly from Key-Value (KV) caching, but the cache can become a major memory bottleneck as model size, batch size, and sequence length grow. We introduce Multi-Layer Key-Value (MLKV) sharing, a novel approach that extends KV sharing across transformer layers to reduce memory usage beyond what is possible with Multi-Query Attention (MQA) and Grouped-Query Attention (GQA). Evaluations on various NLP benchmarks and inference metrics using uptrained Pythia-160M variants demonstrate that MLKV significantly reduces memory usage with minimal performance loss, shrinking the KV cache by up to a factor of 6 compared to MQA. These results highlight MLKV's potential for efficient deployment of transformer models at scale. We provide code at https://github.com/zaydzuhri/pythia-mlkv
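To make the memory argument concrete, the following is a minimal sketch of the usual KV cache size arithmetic: one key vector and one value vector are stored per token for every KV head in the model. The function name `kv_cache_bytes` and the model dimensions (12 layers, 12 query heads, head dimension 64, 2 bytes per element) are assumptions chosen for illustration, as is the MLKV configuration with 2 KV heads in total; the abstract does not state which configuration yields the factor of 6.

```python
def kv_cache_bytes(batch_size, seq_len, total_kv_heads, head_dim, bytes_per_elem=2):
    """Bytes needed to cache keys and values for every token.

    total_kv_heads counts KV heads summed over ALL layers, so sharing
    heads within a layer (GQA/MQA) or across layers (MLKV) lowers it.
    The leading factor of 2 accounts for storing both keys and values.
    """
    return 2 * batch_size * seq_len * total_kv_heads * head_dim * bytes_per_elem


# Assumed dimensions for illustration only (not taken from the text above).
NUM_LAYERS = 12
NUM_QUERY_HEADS = 12
HEAD_DIM = 64

# Total KV heads across the whole model for each sharing scheme.
TOTAL_KV_HEADS = {
    "MHA": NUM_LAYERS * NUM_QUERY_HEADS,  # one KV head per query head
    "MQA": NUM_LAYERS * 1,                # one KV head per layer
    "MLKV-2 (hypothetical)": 2,           # two KV heads shared across layers
}


def cache_sizes(batch_size=1, seq_len=2048):
    """Return the cache size in bytes for each scheme."""
    return {
        name: kv_cache_bytes(batch_size, seq_len, heads, HEAD_DIM)
        for name, heads in TOTAL_KV_HEADS.items()
    }


if __name__ == "__main__":
    sizes = cache_sizes()
    for name, size in sizes.items():
        print(f"{name}: {size / 2**20:.2f} MiB")
    ratio = sizes["MQA"] / sizes["MLKV-2 (hypothetical)"]
    print(f"MQA / MLKV-2 cache ratio: {ratio:.1f}x")
```

Under these assumed dimensions, MQA keeps 12 KV heads in total while the hypothetical MLKV configuration keeps 2, which is one way a sixfold reduction relative to MQA can arise. Because the cache grows linearly with batch size and sequence length, the same ratio holds at any scale.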

Paper Structure

This paper contains 21 sections, 4 equations, 6 figures, and 4 tables.

Figures (6)

  • Figure 1: Simplified overview of current KV sharing methods, vanilla MHA (top left), MQA (bottom left), and GQA (top right). All of them share KV heads within the same layer. Our proposed KV sharing scheme MLKV (bottom right) shares KV heads between layers.
  • Figure 2: Detailed illustration of attention using the different KV sharing mechanisms. Vanilla MHA (left) has a key-value head for each query head. GQA (top middle) is shown here with 2 groups of heads. MQA (bottom middle) has only one key-value head for all query heads. MLKV (right) can share the one key-value head from the bottom layer with the query heads of a layer above it.
  • Figure 3: Line plots of the inference-time memory measurements, in terms of the batch sizes that each model can reach. The red 'X' indicates that an out-of-memory (OOM) error occurs beyond that batch size.
  • Figure 4: Average accuracy vs. lowest recorded memory usage (measured at the minimum batch size; memory scales the same way as batch size increases). Pareto-optimal models lie in the upper-left corner of the plot.
  • Figure 5: Inference-time memory measurements of the 410M parameter models. The red 'X' indicates that an OOM error occurs beyond that batch size.
  • ...and 1 more figure
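The cross-layer sharing described in the captions for Figures 1 and 2 can be sketched as a mapping from each layer to the layer whose KV heads it reads. This is a hedged illustration, not the authors' implementation: the helper names `kv_source_layer` and `kv_layer_map` are hypothetical, and the sketch assumes that layers are split into equal contiguous groups with the lowest layer of each group owning the KV heads.

```python
def kv_source_layer(layer_idx, num_layers, num_kv_layers):
    """Index of the layer whose KV heads `layer_idx` attends with.

    Assumption: layers form equal contiguous groups, and the lowest
    layer in each group computes (and caches) the keys and values.
    """
    if num_layers % num_kv_layers != 0:
        raise ValueError("num_layers must be divisible by num_kv_layers")
    if not 0 <= layer_idx < num_layers:
        raise ValueError("layer_idx out of range")
    group_size = num_layers // num_kv_layers
    return (layer_idx // group_size) * group_size


def kv_layer_map(num_layers, num_kv_layers):
    """KV source layer for every layer in the model."""
    return [kv_source_layer(i, num_layers, num_kv_layers) for i in range(num_layers)]


if __name__ == "__main__":
    # With 12 layers and 2 KV-owning layers, layers 0-5 reuse layer 0's
    # KV heads and layers 6-11 reuse layer 6's.
    print(kv_layer_map(12, 2))
```

When `num_kv_layers` equals `num_layers`, every layer keeps its own KV heads and the scheme reduces to sharing within a layer only (MHA, GQA, or MQA, depending on the number of heads per layer). Only the KV-owning layers contribute entries to the cache, which is where the memory saving comes from.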