Tandem Transformers for Inference Efficient LLMs

Aishwarya P S, Pranav Ajit Nair, Yashas Samaga, Toby Boyd, Sanjiv Kumar, Prateek Jain, Praneeth Netrapalli

TL;DR

This work introduces Tandem transformers, a novel architecture that combines a small autoregressive model with a large model operating in block mode (processing multiple tokens simultaneously). It also incorporates the tandem model into the speculative decoding (SPEED) framework, where the large model verifies tokens drafted by the small model.

Abstract

The autoregressive nature of conventional large language models (LLMs) inherently limits inference speed, as tokens are generated sequentially. While speculative and parallel decoding techniques attempt to mitigate this, they face limitations: they either rely on less accurate smaller models for generation or fail to fully leverage the base LLM's representations. We introduce a novel architecture, Tandem transformers, to address these issues. This architecture uniquely combines (1) a small autoregressive model and (2) a large model operating in block mode (processing multiple tokens simultaneously). The small model's predictive accuracy is substantially enhanced by granting it attention to the large model's richer representations. On the PaLM2 pretraining dataset, a tandem of PaLM2-Bison and PaLM2-Gecko demonstrates a 3.3% improvement in next-token prediction accuracy over a standalone PaLM2-Gecko, offering a 1.16x speedup compared to a PaLM2-Otter model with comparable downstream performance. We further incorporate the tandem model within the speculative decoding (SPEED) framework, where the large model validates tokens from the small model. In this setting, the tandem of PaLM2-Bison and PaLM2-Gecko achieves a substantial speedup (around 1.14x faster than using vanilla PaLM2-Gecko in SPEED) while maintaining identical downstream task accuracy.
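The SPEED setup described above follows the usual draft-then-verify pattern. The sketch below is a generic greedy speculative-decoding round, not the paper's implementation: `draft_next` stands in for the drafter (the tandem's small model, or vanilla PaLM2-Gecko in the baseline) and `target_next` for the large verifier. Both are hypothetical callables that map a token list to the next greedy token, and sampling-based acceptance is not covered.

```python
from typing import Callable, List

# Hypothetical model interface: context tokens -> next greedy token.
NextToken = Callable[[List[str]], str]


def speculative_round(context: List[str], draft_next: NextToken,
                      target_next: NextToken, gamma: int) -> List[str]:
    """One round of greedy speculative decoding; returns the tokens to append."""
    # 1. The drafter proposes gamma tokens autoregressively.
    draft: List[str] = []
    for _ in range(gamma):
        draft.append(draft_next(context + draft))
    # 2. The verifier predicts the next token after every draft prefix
    #    (one parallel block-mode pass in a real system; a plain loop here).
    verified = [target_next(context + draft[:k]) for k in range(gamma + 1)]
    # 3. Keep the longest draft prefix that agrees with the verifier.
    accepted: List[str] = []
    for proposed, expected in zip(draft, verified):
        if proposed != expected:
            break
        accepted.append(proposed)
    # 4. Append the verifier's own token: a correction at the first mismatch,
    #    or a bonus token when the whole draft was accepted.
    accepted.append(verified[len(accepted)])
    return accepted


if __name__ == "__main__":
    target_text = "a b c d e".split()
    target = lambda ctx: target_text[len(ctx)]
    # Hypothetical drafter that is wrong only at the third position.
    drafter = lambda ctx: "x" if len(ctx) == 2 else target_text[len(ctx)]
    print(speculative_round([], drafter, target, gamma=3))  # ['a', 'b', 'c']
```

Every emitted token equals what `target_next` would produce greedily from the same context, so the output matches the verifier's own greedy decoding; the speedup comes only from accepting several drafted tokens per verifier pass. This is consistent with the abstract's point that a more accurate drafter can raise speed while downstream accuracy stays identical.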

Paper Structure (15 sections, 2 equations, 3 figures, 10 tables)

Figures (3)

  • Figure 1: Training of Tandem Transformers with a block length $\gamma = 2$. $\mathrm{Atn}_L^{(\ell(j)+1)}$ and $\mathrm{FF}_L^{(\ell(j)+1)}$ denote the attention and feedforward blocks in the $(\ell(j)+1)^\textrm{th}$ layer of $\mathcal{M}_L$, while $\mathrm{Atn}_S^{(j+1)}$ and $\mathrm{FF}_S^{(j+1)}$ denote those of the $(j+1)^\textrm{th}$ layer of $\mathcal{M}_S$. $\mathcal{M}_L$ processes the tokens as a standard decoder Transformer. $\mathcal{M}_S$, on the other hand, processes the tokens in the $\left(\frac{i}{\gamma}\right)^\textrm{th}$ block using its own representations $y_{i}^{(j)}$ and $y_{i+1}^{(j)}$, while attending to the representations of all tokens from the previous block, taken from the $(\ell(j)+1)^\textrm{th}$ layer of $\mathcal{M}_L$ and passed through a feedforward layer $\mathrm{FF}_{\textrm{Tandem}}^{(j)}$.
  • Figure 2: Inference of Tandem Transformers with a free token from the primary model $\mathcal{M}_L$. (left) First block prediction. (right) Second block prediction. Given the query "The Himalayas are a mountain range separating the", $\mathcal{M}_L$ first processes the query and produces the first response token, "plains". This prediction from $\mathcal{M}_L$ is fed directly as input to the secondary model $\mathcal{M}_S$, which autoregressively produces "of India" to complete the first block with $\gamma = 2$. In the second block, the entire response from the first block, "plains of India", is fed to the primary model $\mathcal{M}_L$, which again produces the next response token, "from"; the secondary model $\mathcal{M}_S$ then autoregressively produces the next two tokens of the block, "the Tibetan". The eventual output of the model is "plains of India from the Tibetan ...".
  • Figure 3: Inference of Tandem Transformers without a free token from the primary model $\mathcal{M}_L$. (left) First block prediction. (right) Second block prediction. Given the same query as in Figure 2, "The Himalayas are a mountain range separating the", $\mathcal{M}_L$ here first processes the query except for the last token, "the". That last token is passed as input to the secondary model $\mathcal{M}_S$, which attends to the $\mathcal{M}_L$ representations of all past tokens and autoregressively produces the first block of the response, "Tibetan plateau". In the second block, $\mathcal{M}_L$ processes "the Tibetan" in block mode while "plateau" is passed as input to $\mathcal{M}_S$, which then autoregressively generates the next block of the response, "from India". The eventual response is "Tibetan plateau from India ...".
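The Figure 1 caption describes which $\mathcal{M}_L$ representations a token of $\mathcal{M}_S$ may attend to during training. A minimal sketch of that visibility pattern follows, assuming 0-indexed positions and reading "all tokens from the previous block" as every position in strictly earlier blocks. The function name `tandem_cross_mask` is hypothetical, and the ordinary causal self-attention of $\mathcal{M}_S$ over its own representations is not shown.

```python
from typing import List


def tandem_cross_mask(seq_len: int, gamma: int) -> List[List[bool]]:
    """Return mask[i][k] == True when small-model position i may attend to the
    large-model representation at position k.

    Assumption (one reading of Figure 1): position i belongs to block
    i // gamma and sees large-model states only for positions in strictly
    earlier blocks, never for its own block or later ones.
    """
    mask = []
    for i in range(seq_len):
        block_start = (i // gamma) * gamma  # first position of i's block
        mask.append([k < block_start for k in range(seq_len)])
    return mask


if __name__ == "__main__":
    # gamma = 2 as in Figure 1; one printed row per small-model position.
    for row in tandem_cross_mask(seq_len=6, gamma=2):
        print("".join("x" if visible else "." for visible in row))
```

With `seq_len=6` and `gamma=2` this prints `......`, `......`, `xx....`, `xx....`, `xxxx..`, `xxxx..`. Both positions of a block share the same visible set, which is what lets $\mathcal{M}_L$ handle each block in a single block-mode pass while $\mathcal{M}_S$ fills in the tokens of the block.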
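Figures 2 and 3 differ only in whether $\mathcal{M}_L$ contributes a "free" token in each block. The sketch below simulates that schedule with toy stand-in models that replay a fixed word-level continuation, so it reproduces the token bookkeeping in the captions but none of the actual modeling. `tandem_schedule`, `scripted`, and the word-level tokenization are assumptions for illustration; the function also records which tokens $\mathcal{M}_L$ processes in block mode during each block.

```python
from typing import Callable, List, Tuple

# Hypothetical model interface: context tokens -> next token.
NextToken = Callable[[List[str]], str]


def scripted(prompt: List[str], continuation: List[str]) -> NextToken:
    """Hypothetical toy model: always emits the next word of a fixed continuation."""
    return lambda context: continuation[len(context) - len(prompt)]


def tandem_schedule(prompt: List[str], large_next: NextToken, small_next: NextToken,
                    gamma: int, num_blocks: int,
                    free_token: bool = True) -> Tuple[List[str], List[List[str]]]:
    """Return (response, large_inputs), where large_inputs[b] lists the tokens
    the large model processes in block mode during block b."""
    context = list(prompt)
    response: List[str] = []
    large_inputs: List[List[str]] = []
    processed = 0  # number of context tokens the large model has consumed
    for _ in range(num_blocks):
        if free_token:
            # M_L consumes every pending token and predicts the next one (the free token).
            large_inputs.append(context[processed:])
            processed = len(context)
            token = large_next(context)
            context.append(token)
            response.append(token)
        else:
            # M_L lags one token behind: the newest token goes only to M_S.
            large_inputs.append(context[processed:-1])
            processed = len(context) - 1
        for _ in range(gamma):
            # In the real model, M_S would attend to M_L states for context[:processed].
            token = small_next(context)
            context.append(token)
            response.append(token)
    return response, large_inputs


if __name__ == "__main__":
    prompt = "The Himalayas are a mountain range separating the".split()

    with_free = scripted(prompt, "plains of India from the Tibetan".split())
    resp, seen = tandem_schedule(prompt, with_free, with_free, gamma=2, num_blocks=2)
    print(" ".join(resp), "|", seen[1])

    no_free = scripted(prompt, "Tibetan plateau from India".split())
    resp, seen = tandem_schedule(prompt, no_free, no_free, gamma=2, num_blocks=2,
                                 free_token=False)
    print(" ".join(resp), "|", seen[1])
```

The first call prints `plains of India from the Tibetan | ['plains', 'of', 'India']`, matching Figure 2, and the second prints `Tibetan plateau from India | ['the', 'Tibetan']`, matching Figure 3. Without the free token every response token comes from $\mathcal{M}_S$, and $\mathcal{M}_L$ always stays one token behind the context.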