KV Prediction for Improved Time to First Token

Maxwell Horton, Qingqing Cao, Chenfan Sun, Yanzi Jin, Sachin Mehta, Mohammad Rastegari, Moin Nabi

TL;DR

This work tackles high time-to-first-token (TTFT) latency in on-device transformer inference by introducing KV Prediction. A small auxiliary transformer processes the prompt to produce KV_A, and a learned linear predictor maps KV_A to a predicted KV cache for the base model, so the base model can generate from the predicted cache without querying the auxiliary model again. The approach yields a Pareto-optimal efficiency-accuracy trade-off, demonstrated on TriviaQA and HumanEval with substantial accuracy retention at reduced TTFT FLOPs, and validated by on-device timing on Apple hardware. The authors provide two KV prediction architectures (KVP-C and KVP-LP) and a three-term loss with a consistency term, and they release their code for reproducibility.

Abstract

Inference with transformer-based language models begins with a prompt processing step. In this step, the model generates the first output token and stores the KV cache needed for future generation steps. This prompt processing step can be computationally expensive, taking tens of seconds or more for billion-parameter models on edge devices when prompt lengths or batch sizes rise. This degrades user experience by introducing significant latency into the model's outputs. To reduce the time spent producing the first output (known as the "time to first token", or TTFT) of a pretrained model, we introduce a novel method called KV Prediction. In our method, a small auxiliary model is used to process the prompt and produce an approximation of the KV cache used by a base model. This approximated KV cache is then used with the base model for autoregressive generation without the need to query the auxiliary model again. We demonstrate that our method produces a Pareto-optimal efficiency-accuracy trade-off when compared to baselines. On TriviaQA, we demonstrate relative accuracy improvements in the range of 15% to 50% across a range of TTFT FLOPs budgets. We also demonstrate accuracy improvements of up to 30% on HumanEval Python code completion at fixed TTFT FLOPs budgets. Additionally, we benchmark models on an Apple M2 Pro CPU and demonstrate that our improvement in FLOPs translates to a TTFT speedup on hardware. We release our code at https://github.com/apple/corenet/tree/main/projects/kv-prediction.
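
The prediction step described above can be illustrated with a minimal sketch. It is not the authors' implementation: the function names, the toy dimensions, the random (untrained) weights, and the `layer_map` that assigns an auxiliary layer to each base layer are all assumptions made for illustration, and the released KVP-C and KVP-LP architectures may differ in every one of these respects. The sketch only shows the general idea of mapping each auxiliary-layer key/value tensor through a linear layer to obtain a base-model-sized cache.

```python
import random

def linear(x, w):
    """Apply a linear map to each row of x.

    x is [tokens][d_in], w is [d_in][d_out]; returns [tokens][d_out].
    """
    return [[sum(xi * wij for xi, wij in zip(row, col)) for col in zip(*w)]
            for row in x]

def predict_base_kv(aux_kv, predictors, layer_map):
    """Map an auxiliary KV cache to a predicted base-model KV cache.

    aux_kv:     list over auxiliary layers of (K, V), each [tokens][d_aux]
    predictors: list over base layers of (W_k, W_v), each [d_aux][d_base]
    layer_map:  layer_map[i] is the auxiliary layer that feeds base layer i
                (a hypothetical mapping chosen for this example)
    """
    predicted = []
    for i, (w_k, w_v) in enumerate(predictors):
        k_aux, v_aux = aux_kv[layer_map[i]]
        predicted.append((linear(k_aux, w_k), linear(v_aux, w_v)))
    return predicted

def rand_matrix(rows, cols, rng):
    """Random matrix standing in for real activations or trained weights."""
    return [[rng.uniform(-1.0, 1.0) for _ in range(cols)] for _ in range(rows)]

# Toy sizes (assumed): 5 prompt tokens, 2 auxiliary layers, 4 base layers.
rng = random.Random(0)
tokens, d_aux, d_base = 5, 4, 8
aux_kv = [(rand_matrix(tokens, d_aux, rng), rand_matrix(tokens, d_aux, rng))
          for _ in range(2)]
predictors = [(rand_matrix(d_aux, d_base, rng), rand_matrix(d_aux, d_base, rng))
              for _ in range(4)]
layer_map = [0, 0, 1, 1]

predicted = predict_base_kv(aux_kv, predictors, layer_map)
```

In this sketch the auxiliary model is run once on the prompt, and the resulting `predicted` list has one key/value pair per base layer, each with the base model's width. Generation would then proceed with the base model alone, which is the property the abstract highlights.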

Paper Structure

This paper contains 20 sections, 1 equation, 7 figures, 8 tables.

Figures (7)

  • Figure 1: Time to First Token (TTFT) and ratio of TTFT to generation time for an OpenELM 3B model on an M2 Pro CPU with 32GB of RAM. We evaluate at batch sizes 1, 4, and 8.
  • Figure 2: (a) An overview of our training and inference method. (b) Our inference method. (c) Our training method.
  • Figure 3: Efficiency-accuracy trade-off of our KV Prediction method (KVP-C, KVP-LP) on TriviaQA compared to baselines. The x-axis shows the relative reduction in FLOPs compared to the base network, and the y-axis shows the relative accuracy retention compared to the base network. (Left): Results using a base network of OpenELM 1.1B for KV Prediction. For KV Prediction models (green), points are annotated with the auxiliary network used. For example, the leftmost green "x" corresponds to OE1.1B-KVP-LP-0.75, and the leftmost green "+" corresponds to OE1.1B-KVP-C-450M. For OpenELM baselines (blue), points are annotated with the OpenELM variant used. All other baselines use variations of token pruning with different rates on OpenELM 1.1B. (Right): Results using a base network of OpenELM 3B.
  • Figure 4: Efficiency-accuracy trade-off of our KV Prediction method (KVP-C, KVP-LP) compared to baselines on HumanEval Python code completion. The x-axis shows the relative speedup in FLOPs compared to OpenELM 1.1B, and the y-axis shows the relative accuracy retention compared to OpenELM 1.1B. (Left): HumanEval Pass@1. (Right): HumanEval Pass@10. (Note that results for KVP-C and OE are overlapping.)
  • Figure 5: (a) Accuracy on the TriviaQA dataset compared to benchmarked time to first token on CPU. (b) The time to first token of our KV Prediction model OE1.1B-KVP-C-450M compared to OpenELM 1.1B and OpenELM 450M.
  • ...and 2 more figures