FIRP: Faster LLM inference via future intermediate representation prediction

Pengfei Wu, Jiahao Liu, Zhuocheng Gong, Qifan Wang, Jinpeng Li, Jingang Wang, Xunliang Cai, Dongyan Zhao

TL;DR

FIRP is a novel speculative decoding method that generates multiple tokens per decoding step instead of one. It predicts the intermediate hidden states of future tokens and then uses these pseudo hidden states to decode those tokens.

Abstract

Recent advancements in Large Language Models (LLMs) have shown remarkable performance across a wide range of tasks. Despite this, the auto-regressive nature of LLM decoding, which generates only a single token per forward propagation, fails to fully exploit the parallel computational power of GPUs, leading to considerable latency. To address this, we introduce a novel speculative decoding method named FIRP, which generates multiple tokens instead of one at each decoding step. We achieve this by predicting the intermediate hidden states of future tokens (tokens that have not been decoded yet) and then using these pseudo hidden states to decode future tokens. Specifically, these pseudo hidden states are predicted with a simple linear transformation in the intermediate layers of LLMs. Once predicted, they participate in the computation of all the following layers, thereby assimilating richer semantic information. As the layers go deeper, the semantic gap between pseudo and real hidden states narrows, and it becomes feasible to decode future tokens with high accuracy. To validate the effectiveness of FIRP, we conduct extensive experiments, which show a speedup ratio of 1.9x-3x across several models and datasets. Analytical experiments also support our motivations.
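The mechanism described in the abstract can be illustrated with a minimal, hedged sketch. It is not the authors' implementation: `toy_layer` is a hypothetical stand-in for a transformer layer (a causal running mean followed by a residual tanh update), all weights are random rather than trained, and the names `firp_style_step`, `projections`, and `predict_layer` are assumptions made for illustration. The sketch only shows the data flow: at one intermediate layer, a linear map per future position turns the last real hidden state into a pseudo hidden state; the pseudo states then pass through all remaining layers and are decoded with the same output head. The verification of draft tokens that speculative decoding requires is omitted.

```python
import math
import random


def matvec(W, x):
    """Multiply matrix W (list of rows) by vector x."""
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def rand_matrix(rng, rows, cols, scale=0.1):
    """Random weights; a real model would use trained parameters."""
    return [[rng.gauss(0.0, scale) for _ in range(cols)] for _ in range(rows)]


def toy_layer(W, hidden):
    """Hypothetical stand-in for a transformer layer: each position mixes
    the mean of positions up to and including itself (causal), then applies
    a residual tanh update."""
    out = []
    for t in range(len(hidden)):
        d = len(hidden[t])
        ctx = [sum(h[i] for h in hidden[: t + 1]) / (t + 1) for i in range(d)]
        mixed = matvec(W, ctx)
        out.append([h + math.tanh(m) for h, m in zip(hidden[t], mixed)])
    return out


def firp_style_step(hidden, layers, predict_layer, projections, lm_head):
    """One forward pass that decodes the next token plus len(projections)
    draft tokens from pseudo hidden states."""
    n_real = len(hidden)
    for idx, W in enumerate(layers):
        if idx == predict_layer:
            last = hidden[-1]
            # One linear map per future position yields a pseudo hidden state.
            hidden = hidden + [matvec(P, last) for P in projections]
        # Pseudo states take part in this layer and every later one.
        hidden = toy_layer(W, hidden)
    tokens = []
    # Last real position gives the next token; pseudo positions give drafts.
    for h in hidden[n_real - 1:]:
        logits = matvec(lm_head, h)
        tokens.append(max(range(len(logits)), key=logits.__getitem__))
    return tokens


rng = random.Random(0)
d_model, vocab, n_layers, k = 8, 16, 4, 3
layers = [rand_matrix(rng, d_model, d_model) for _ in range(n_layers)]
projections = [rand_matrix(rng, d_model, d_model) for _ in range(k)]
lm_head = rand_matrix(rng, vocab, d_model)
prompt_hidden = [[rng.gauss(0.0, 1.0) for _ in range(d_model)] for _ in range(5)]

draft = firp_style_step(prompt_hidden, layers, 2, projections, lm_head)
print(draft)  # k + 1 token ids: the next token followed by k draft tokens
```

Because the toy layer is causal, appending pseudo states does not change the token decoded at the last real position; only the extra positions carry the drafts. In this sketch `predict_layer` controls how many remaining layers the pseudo states pass through before decoding.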

Paper Structure

This paper contains 18 sections, 4 equations, 7 figures, 3 tables.

Figures (7)

  • Figure 1: (a) shows the end-to-end speedup ratio and draft size for different decoding methods, and (b) shows the average number of accepted tokens per forward propagation. We search for the best tree structure for Medusa and our FIRP using the search algorithm from Medusa. All results are obtained on LLaMA2-Chat-13B and the XSum dataset with $k$=3.
  • Figure 2: Overview of our method compared with auto-regressive decoding and Medusa. Our method differs from Medusa in that it predicts the intermediate hidden states of future tokens, which achieves better prediction accuracy.
  • Figure 3: Top-K token prediction accuracy of three prediction methods on the LLaMA-2-Chat-13B model: directly training separate lm-heads on some intermediate layers (denoted as Early exit in the figure), the Medusa method, and FIRP. $N$ is the prediction step ($N$=2 means we predict the first draft token). Our method achieves the best prediction accuracy.
  • Figure 4: Similarity between the predicted pseudo hidden states and the original hidden states.
  • Figure 5: Prediction accuracy of the first prediction step at different layers for Vicuna-7B and LLaMA-2-Chat-7B.
  • ...and 2 more figures