HiRE: High Recall Approximate Top-$k$ Estimation for Efficient LLM Inference
Yashas Samaga B L, Varun Yerram, Chong You, Srinadh Bhojanapalli, Sanjiv Kumar, Prateek Jain, Praneeth Netrapalli
TL;DR
HiRE targets the memory-bound bottleneck of autoregressive LLM inference through efficient top-$k$ estimation for the softmax and FFN layers. It introduces high-recall approximate top-$k$ estimators (HiRE-LR, based on low-rank approximation, and HiRE-Q, based on quantization) and a distributed approximate top-$k$ operator (DA-TOP-$k$). It also adds a group-sparse FFN and overlap-based enhancements that preserve accuracy while reducing memory transfers. Empirically, HiRE improves end-to-end latency by up to about $1.47\times$ for a $1$B-parameter model on a single TPUv5e device, and by up to $2.27\times$ in a multi-device setting with DA-TOP-$k$, with negligible quality loss. The approach demonstrates that high-recall estimation, combined with structured sparsity and efficient cross-device gathering, can substantially accelerate LLM inference in practical settings.
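The estimate-then-refine idea behind the high-recall estimators can be illustrated with a minimal pure-Python sketch. It is not the authors' implementation. It assumes a simple per-row 4-bit symmetric quantization as the cheap compressed copy of the weight matrix, in the spirit of HiRE-Q. The helpers `quantize_rows` and `approx_then_exact_topk` and the oversampling parameter `k_prime` are hypothetical names. The compressed scores select `k_prime >= k` candidate rows, favoring recall over precision, and exact scores are then computed only for those candidates.

```python
import random


def dot(u, v):
    """Plain dot product of two equal-length lists."""
    return sum(a * b for a, b in zip(u, v))


def top_k_indices(scores, k):
    """Indices of the k largest entries of `scores`, largest first."""
    return sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]


def quantize_rows(W, bits=4):
    """Hypothetical helper: per-row symmetric quantization, row ~= scale * codes.

    Stands in for a compressed copy of W that would be read from HBM instead
    of the full-precision rows; this toy only simulates the arithmetic."""
    qmax = 2 ** (bits - 1) - 1
    compressed = []
    for row in W:
        scale = max(abs(w) for w in row) / qmax or 1.0  # avoid a zero scale
        compressed.append((scale, [round(w / scale) for w in row]))
    return compressed


def approx_then_exact_topk(W, Wq, x, k, k_prime):
    """Hypothetical two-stage top-k over the rows of W for input x."""
    # Stage 1: cheap approximate scores for every row; keep k_prime >= k candidates.
    approx = [scale * dot(codes, x) for scale, codes in Wq]
    candidates = top_k_indices(approx, k_prime)
    # Stage 2: exact scores, computed only for the candidate rows.
    exact = {i: dot(W[i], x) for i in candidates}
    return sorted(candidates, key=lambda i: exact[i], reverse=True)[:k]


def recall(predicted, true):
    """Fraction of the true top-k indices that were recovered."""
    return len(set(predicted) & set(true)) / len(true)


random.seed(0)
n_rows, dim, k = 2000, 64, 16
W = [[random.gauss(0, 1) for _ in range(dim)] for _ in range(n_rows)]
x = [random.gauss(0, 1) for _ in range(dim)]

exact_top = top_k_indices([dot(row, x) for row in W], k)
Wq = quantize_rows(W, bits=4)
for k_prime in (k, 4 * k):
    found = approx_then_exact_topk(W, Wq, x, k, k_prime)
    print(f"k_prime={k_prime}: recall={recall(found, exact_top):.2f}")
```

Stage 2 rescores every candidate exactly, so any true top-$k$ row that survives stage 1 is guaranteed to appear in the output. Recall therefore can only grow as `k_prime` increases. The cost is that a larger `k_prime` means more full-precision rows to fetch. A low-rank factorization, as in HiRE-LR, could take the place of the quantized copy in stage 1.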
Abstract
Autoregressive decoding with generative Large Language Models (LLMs) on accelerators (GPUs/TPUs) is often memory-bound: most of the time is spent transferring model parameters from high-bandwidth memory (HBM) to cache. At the same time, recent works show that LLMs can maintain quality with significant sparsity/redundancy in the feedforward (FFN) layers when the model is appropriately trained to operate on only a top-$k$ fraction of rows/columns (where $k \approx 0.05$). This suggests a way to reduce the transfer of model parameters, and hence latency. However, it is hard to exploit this sparsity to improve latency, because identifying the top rows/columns is data-dependent and is usually done with full matrix operations, which severely limits the potential gains. To address these issues, we introduce HiRE (High Recall Approximate Top-$k$ Estimation). HiRE comprises two novel components: (i) a compression scheme that cheaply predicts the top-$k$ rows/columns with high recall, followed by full computation restricted to the predicted subset, and (ii) DA-TOP-$k$, an efficient multi-device approximate top-$k$ operator. We demonstrate that on a one-billion-parameter model, HiRE applied to both the softmax and the feedforward layers incurs almost no loss in pretraining and downstream accuracy while speeding up inference by $1.47\times$ on a single TPUv5e device.
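One plausible way a distributed approximate top-$k$ can cut cross-device communication is sketched below. The sketch assumes the scores are sharded contiguously across `num_devices` devices and that each device forwards only its local top-`k_local` entries, with `k_local` smaller than $k$. The function `distributed_topk`, the sharding scheme, and the choice of `k_local` are illustrative assumptions, not a description of DA-TOP-$k$'s exact design. The devices are simulated with plain Python lists.

```python
import random


def top_k(pairs, k):
    """The k (score, index) pairs with the largest scores, largest first."""
    return sorted(pairs, reverse=True)[:k]


def distributed_topk(scores, num_devices, k, k_local):
    """Hypothetical simulation of a multi-device top-k over `scores`.

    The vector is split into contiguous shards, one per simulated device.
    Each device keeps only its local top-k_local entries; those
    num_devices * k_local candidates are all that would be gathered across
    devices. A final top-k over the gathered candidates gives the result.
    """
    shard = -(-len(scores) // num_devices)  # ceiling division
    gathered = []
    for d in range(num_devices):
        lo, hi = d * shard, min((d + 1) * shard, len(scores))
        local = [(scores[i], i) for i in range(lo, hi)]
        gathered.extend(top_k(local, k_local))  # per-device work, no communication
    return [i for _, i in top_k(gathered, k)]  # single gather + final selection


random.seed(1)
n, num_devices, k = 4096, 8, 64
scores = [random.gauss(0, 1) for _ in range(n)]

reference = [i for _, i in top_k([(s, i) for i, s in enumerate(scores)], k)]
exact = distributed_topk(scores, num_devices, k, k_local=k)
k_local = k // num_devices + 4  # hypothetical slack above an even split
approx = distributed_topk(scores, num_devices, k, k_local=k_local)

print(f"gathered {num_devices * k_local} instead of {num_devices * k} candidates")
print(f"recall vs. exact top-k: {len(set(approx) & set(reference)) / k:.2f}")
```

With `k_local = k`, every global winner survives its own shard, so the result is exact. Shrinking `k_local` toward $k$ divided by the number of devices trades a small loss in recall for gathering far fewer candidates over the interconnect.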
