HiRE: High Recall Approximate Top-$k$ Estimation for Efficient LLM Inference

Yashas Samaga B L, Varun Yerram, Chong You, Srinadh Bhojanapalli, Sanjiv Kumar, Prateek Jain, Praneeth Netrapalli

TL;DR

HiRE targets the memory-bound bottleneck of autoregressive LLM inference by focusing on efficient top-$k$ estimation for softmax and FFN layers. It introduces high-recall approximate top-$k$ estimators (HiRE-LR, HiRE-Q) and a distributed top-$k$ operator (DA-TOP-$k$), along with group sparse FFN and overlap-based enhancements to preserve accuracy while reducing memory transfers. Empirically, HiRE yields end-to-end latency improvements up to about $1.47\times$ on a $1$B parameter model on TPUv5e and $2.27\times$ with multi-device DA-TOP-$k$, with negligible quality loss. The approach demonstrates that high-recall estimation, coupled with structured sparsity and efficient cross-device gathering, can substantially accelerate LLM inference in practical settings.

Abstract

Autoregressive decoding with generative Large Language Models (LLMs) on accelerators (GPUs/TPUs) is often memory-bound: most of the time is spent transferring model parameters from high-bandwidth memory (HBM) to cache. On the other hand, recent works show that LLMs can maintain quality with significant sparsity/redundancy in the feedforward (FFN) layers when the model is appropriately trained to operate on a top-$k$ fraction of rows/columns (where $k \approx 0.05$), thereby suggesting a way to reduce the transfer of model parameters, and hence latency. However, exploiting this sparsity to improve latency is hindered by the fact that identifying the top rows/columns is data-dependent and is usually performed using full matrix operations, severely limiting potential gains. To address these issues, we introduce HiRE (High Recall Approximate Top-$k$ Estimation). HiRE comprises two novel components: (i) a compression scheme to cheaply predict top-$k$ rows/columns with high recall, followed by full computation restricted to the predicted subset, and (ii) DA-TOP-$k$: an efficient multi-device approximate top-$k$ operator. We demonstrate that on a one billion parameter model, HiRE applied to both the softmax and feedforward layers achieves almost matching pretraining and downstream accuracy, and speeds up inference latency by $1.47\times$ on a single TPUv5e device.
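
The two-stage idea in component (i), as laid out in the Figure 1 schematic, can be illustrated with a small pure-Python sketch. It is not the authors' implementation: the function names (`hire_topk`, `recall`, `make_toy`) are hypothetical, $\phi$ is assumed to be ReLU, and the toy matrix is built as a known low-rank product plus noise so that the factors $\mathbf{Z}_1, \mathbf{Z}_2$ are available by construction. How the paper actually obtains its low-rank or quantized estimator is not specified here.

```python
import random

def relu(v):
    return v if v > 0.0 else 0.0

def top_indices(scores, k):
    """Indices of the k largest entries of scores, largest first."""
    return sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)[:k]

def exact_scores(Z, x, cols, phi=relu):
    """phi(Z[:, j]^T x) for each column j in cols; Z is a list of d rows."""
    return {j: phi(sum(Z[i][j] * x[i] for i in range(len(x)))) for j in cols}

def hire_topk(Z, Z1, Z2, x, k, k_prime, phi=relu):
    """Two-stage top-k: low-rank shortlist of size k_prime, then exact top-k on it.

    Shapes (assumed): Z is d x m, Z1 is d x r, Z2 is m x r, Z_approx = Z1 Z2^T.
    """
    r = len(Z1[0])
    # Stage 1: Z_approx^T x = Z2 (Z1^T x), about (d + m) * r multiplies
    # instead of d * m for the full product.
    h = [sum(Z1[i][t] * x[i] for i in range(len(x))) for t in range(r)]
    approx = [sum(row[t] * h[t] for t in range(r)) for row in Z2]
    shortlist = top_indices(approx, k_prime)  # the index set S'
    # Stage 2: exact computation restricted to the shortlisted columns.
    scores = exact_scores(Z, x, shortlist, phi)
    top = sorted(shortlist, key=lambda j: scores[j], reverse=True)[:k]
    return top, shortlist, scores

def recall(shortlist, Z, x, k, phi=relu):
    """Fraction of the true top-k columns that the shortlist contains."""
    m = len(Z[0])
    full = exact_scores(Z, x, range(m), phi)
    truth = set(top_indices([full[j] for j in range(m)], k))
    return len(truth & set(shortlist)) / k

def make_toy(d=16, m=64, r=4, noise=0.05, seed=0):
    """Toy data (an assumption for illustration): Z = Z1 Z2^T + small noise."""
    rng = random.Random(seed)
    Z1 = [[rng.gauss(0, 1) for _ in range(r)] for _ in range(d)]
    Z2 = [[rng.gauss(0, 1) for _ in range(r)] for _ in range(m)]
    Z = [[sum(Z1[i][t] * Z2[j][t] for t in range(r)) + noise * rng.gauss(0, 1)
          for j in range(m)] for i in range(d)]
    x = [rng.gauss(0, 1) for _ in range(d)]
    return Z, Z1, Z2, x

if __name__ == "__main__":
    Z, Z1, Z2, x = make_toy()
    top, shortlist, _ = hire_topk(Z, Z1, Z2, x, k=4, k_prime=12)
    print("top-k columns:", top)
    print("recall of shortlist:", recall(shortlist, Z, x, k=4))
```

Choosing `k_prime` larger than `k` is what buys recall: the cheap estimator only has to place the true top-$k$ columns somewhere in the shortlist, since the exact second stage fixes their ordering.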

Paper Structure (21 sections, 7 equations, 3 figures, 8 tables, 2 algorithms)

Figures (3)

  • Figure 1: HiRE schematic: To compute the top-$k$ elements of $\phi({\mathbf{Z}}^\top \overrightarrow{\mathbf{x}})$, we first compute an approximate top-$k'$ index set $S'$ using a low-rank approximation ${\mathbf{Z}_{\textrm{approx}}} = {\mathbf{Z}}_1 {\mathbf{Z}}_2^\top$. We then compute $\phi({\mathbf{Z}}\vert_{S'}^\top \overrightarrow{\mathbf{x}})$ for ${\mathbf{Z}}$ restricted to $S'$ and apply the top-$k$ operation to that vector.
  • Figure 2: Dynamic overlap of top-$k$ activations across related responses while generating $4$ parallel samples for the same query. The x-axis shows the size of the union of top-$k$ activations across the $4$ generations divided by $4k$, for $k = 0.05\,m$. A substantial fraction of the mass lies away from $1$, suggesting that the top-$k$ activations of related responses overlap heavily, which can yield further latency improvements with HiRE.
  • Figure 3: Efficiency of memory transfer vs group sizes: For a tensor of dimension $n \times g \times d$, which we consider as $n$ groups, each with $g$ vectors of dimension $d$, we plot the efficiency of transferring a random (non-contiguous) subset of groups from HBM to cache as we vary the group size $g$ on the x-axis. Efficiency is defined as the time taken by the sparse operation divided by the time taken by an equivalent dense operation moving the same number of bytes. The numbers are computed for Cloud TPUv5e. As is clear from the figure, even small group sizes such as $8$ lead to very high efficiency, motivating the group sparse structure in our feedforward layers.
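
The x-axis statistic described for Figure 2 is simple to compute. The snippet below is a minimal sketch, with `overlap_ratio` as a hypothetical helper name: it takes the top-$k$ index sets of several parallel generations and returns the size of their union divided by (number of generations) $\times\, k$. A value of $1$ means the sets are disjoint, and the smallest possible value, $1/4$ for four generations, means they coincide.

```python
def overlap_ratio(topk_sets, k):
    """Size of the union of the top-k index sets divided by (num_sets * k).

    topk_sets: one set of selected indices per parallel generation,
    each assumed to contain exactly k indices.
    """
    union = set().union(*topk_sets)
    return len(union) / (len(topk_sets) * k)

if __name__ == "__main__":
    identical = [{1, 2, 3}] * 4
    disjoint = [{0, 1, 2}, {3, 4, 5}, {6, 7, 8}, {9, 10, 11}]
    print(overlap_ratio(identical, 3))  # all four generations agree
    print(overlap_ratio(disjoint, 3))   # no shared activations
```

A ratio well below $1$ means that the rows fetched for one generation can largely be reused by the others, which is the source of the additional savings the caption points to.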