KV Prediction for Improved Time to First Token
Maxwell Horton, Qingqing Cao, Chenfan Sun, Yanzi Jin, Sachin Mehta, Mohammad Rastegari, Moin Nabi
TL;DR
This work tackles high time-to-first-token (TTFT) latency in on-device transformer inference by introducing KV Prediction. A small auxiliary transformer processes the prompt to produce its own KV cache, KV_A, and a learned linear predictor maps KV_A to a predicted KV cache for the base model. The base model then generates from the predicted cache without querying the auxiliary model again. The approach yields a Pareto-optimal efficiency-accuracy trade-off, demonstrated on TriviaQA and HumanEval, where it retains more accuracy than baselines at the same TTFT FLOPs budget, and it is validated by on-device timing on an Apple M2 Pro CPU. The authors provide two KV prediction architectures (KVP-C and KVP-LP) and a three-term loss that includes a consistency term, and they release their code for reproducibility.
Abstract
Inference with transformer-based language models begins with a prompt processing step. In this step, the model generates the first output token and stores the key-value (KV) cache needed for future generation steps. This prompt processing step can be computationally expensive, taking tens of seconds or more for billion-parameter models on edge devices when prompt lengths or batch sizes rise. This degrades user experience by introducing significant latency into the model's outputs. To reduce the time spent producing the first output (known as the ``time to first token'', or TTFT) of a pretrained model, we introduce a novel method called KV Prediction. In our method, a small auxiliary model is used to process the prompt and produce an approximation of the KV cache used by a base model. This approximated KV cache is then used with the base model for autoregressive generation without the need to query the auxiliary model again. We demonstrate that our method produces a Pareto-optimal efficiency-accuracy trade-off when compared to baselines. On TriviaQA, we demonstrate relative accuracy improvements in the range of $15\%$--$50\%$ across a range of TTFT FLOPs budgets. We also demonstrate accuracy improvements of up to $30\%$ on HumanEval Python code completion at fixed TTFT FLOPs budgets. Additionally, we benchmark models on an Apple M2 Pro CPU and demonstrate that our improvement in FLOPs translates to a TTFT speedup on hardware. We release our code at https://github.com/apple/corenet/tree/main/projects/kv-prediction .
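
To make the data flow concrete, the sketch below illustrates only the prediction step: an auxiliary KV cache is mapped through per-layer linear predictors to a KV cache with the base model's shape. It is a toy illustration under stated assumptions, not the released implementation. The names `predict_base_kv`, `linear`, and `layer_map`, the tensor sizes, the random weights, and the choice of which auxiliary layer feeds which base layer are all hypothetical. In practice the predictor weights would be learned.

```python
import random


def linear(x, weight, bias):
    """Apply y = W x + b to one vector, with W stored as a list of rows."""
    return [sum(w * xi for w, xi in zip(row, x)) + b
            for row, b in zip(weight, bias)]


def predict_base_kv(aux_kv, predictors, layer_map):
    """Map an auxiliary KV cache to a predicted base-model KV cache.

    aux_kv:     list over auxiliary layers of (keys, values); keys and values
                are lists of per-token vectors of size d_aux.
    predictors: list over base layers of ((W_k, b_k), (W_v, b_v)).
    layer_map:  layer_map[i] is the auxiliary layer that feeds base layer i
                (a hypothetical assignment chosen for this sketch).
    """
    base_kv = []
    for i, ((w_k, b_k), (w_v, b_v)) in enumerate(predictors):
        aux_keys, aux_values = aux_kv[layer_map[i]]
        keys = [linear(k, w_k, b_k) for k in aux_keys]
        values = [linear(v, w_v, b_v) for v in aux_values]
        base_kv.append((keys, values))
    return base_kv


def random_matrix(rows, cols, rng):
    """Random stand-in for cache entries and (untrained) predictor weights."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]


rng = random.Random(0)
num_tokens, d_aux, d_base = 5, 4, 8   # assumed toy sizes
aux_layers, base_layers = 2, 4        # the auxiliary model is shallower

# Stand-in for the auxiliary model's single pass over the prompt.
aux_kv = [
    (random_matrix(num_tokens, d_aux, rng), random_matrix(num_tokens, d_aux, rng))
    for _ in range(aux_layers)
]

# One linear predictor per base layer for keys and one for values.
predictors = [
    ((random_matrix(d_base, d_aux, rng), [0.0] * d_base),
     (random_matrix(d_base, d_aux, rng), [0.0] * d_base))
    for _ in range(base_layers)
]
layer_map = [0, 0, 1, 1]  # assumed: each auxiliary layer feeds two base layers

predicted_kv = predict_base_kv(aux_kv, predictors, layer_map)
print(len(predicted_kv), len(predicted_kv[0][0]), len(predicted_kv[0][0][0]))
```

The predicted cache has one (keys, values) pair per base layer and one vector of base-model width per prompt token, which is the shape the base model would need in order to continue generation from it. The cost of the prompt pass is paid by the smaller auxiliary model plus the linear maps.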
