
SARATHI: Efficient LLM Inference by Piggybacking Decodes with Chunked Prefills

Amey Agrawal, Ashish Panwar, Jayashree Mohan, Nipun Kwatra, Bhargav S. Gulavani, Ramachandran Ramjee

TL;DR

This work addresses two sources of inefficiency in large language model inference: memory-bound decoding and pipeline bubbles in pipeline-parallel deployments. It introduces SARATHI, which combines chunked-prefills and decode-maximal batching to form uniform, compute-saturated batches in which decodes piggyback on prefills, reusing the model weights already loaded for the prefill and reducing idle GPU time. Empirically, SARATHI improves decode throughput by up to 10x and, in multi-GPU pipeline-parallel configurations, end-to-end throughput by up to 1.91x, alongside a 6.29x reduction in pipeline bubbles. The gains hold across the model sizes and GPUs evaluated, offering a practical path to more scalable and efficient LLM serving.
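A back-of-the-envelope sketch of why decodes are memory-bound and why piggybacking helps: for a single linear layer, arithmetic intensity (FLOPs per byte of memory traffic) grows roughly linearly with the number of tokens in the batch, because the weight matrix is loaded once regardless of how many tokens use it. The function name `linear_arithmetic_intensity`, the fp16 (2-byte) element size, and the hidden size of 5120 (roughly LLaMA-13B) are illustrative assumptions, not values taken from the paper; attention, which the paper notes stays memory-bound, is ignored here.

```python
BYTES_PER_ELEM = 2  # assumption: fp16 weights and activations


def linear_arithmetic_intensity(num_tokens: int, d_in: int, d_out: int) -> float:
    """FLOPs per byte moved for Y = X @ W, with X of shape [num_tokens, d_in]."""
    flops = 2 * num_tokens * d_in * d_out  # one multiply and one add per MAC
    weight_bytes = BYTES_PER_ELEM * d_in * d_out  # W is read once per batch
    act_bytes = BYTES_PER_ELEM * num_tokens * (d_in + d_out)  # read X, write Y
    return flops / (weight_bytes + act_bytes)


d = 5120  # assumed hidden size, roughly that of LLaMA-13B
for tokens in (1, 8, 64, 256, 512):
    # A decode-only batch of size 1 contributes a single token; a prefill
    # chunk contributes hundreds of tokens that share the same weight load.
    print(tokens, round(linear_arithmetic_intensity(tokens, d, d), 1))
```

A single decode token gives about 1 FLOP per byte, far below what a modern GPU needs to be compute-bound, while a chunk of a few hundred prefill tokens reaches intensities in the hundreds. Decodes added to a batch that already contains a prefill chunk reuse its weight loads, which is the intuition behind piggybacking.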

Abstract

Large Language Model (LLM) inference consists of two distinct phases: a prefill phase, which processes the input prompt, and a decode phase, which generates output tokens autoregressively. While the prefill phase effectively saturates GPU compute even at small batch sizes, the decode phase results in low compute utilization because it generates only one token at a time per request. The differing prefill and decode times also lead to imbalance across micro-batches when using pipeline parallelism, resulting in further inefficiency due to pipeline bubbles. We present SARATHI to address these challenges. SARATHI employs chunked-prefills, which split a prefill request into equal-sized chunks, and decode-maximal batching, which constructs a batch using a single prefill chunk and populates the remaining slots with decodes. During inference, the prefill chunk saturates GPU compute, while the decode requests 'piggyback' and cost up to an order of magnitude less than in a decode-only batch. Chunked-prefills allow multiple decode-maximal batches to be constructed from a single prefill request, maximizing the number of decodes that can piggyback. Furthermore, the uniform compute design of these batches ameliorates the imbalance between micro-batches, significantly reducing pipeline bubbles. Our techniques yield significant improvements in inference performance across models and hardware. For LLaMA-13B on an A6000 GPU, SARATHI improves decode throughput by up to 10x and accelerates end-to-end throughput by up to 1.33x. For LLaMA-33B on an A100 GPU, we achieve 1.25x higher end-to-end throughput and up to 4.25x higher decode throughput. When used with pipeline parallelism on GPT-3, SARATHI reduces bubbles by 6.29x, resulting in a 1.91x end-to-end throughput improvement.
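The batching scheme described in the abstract can be sketched as follows: split one prompt into fixed-size chunks, then build one batch per chunk and fill the remaining slots with decodes from already-running requests. This is a minimal illustration under stated assumptions, not the authors' scheduler; the names `Batch`, `chunk_prefill`, and `decode_maximal_batches` are hypothetical, and the last chunk is allowed to be shorter when the prompt length is not a multiple of the chunk size.

```python
from dataclasses import dataclass, field
from typing import List, Tuple


@dataclass
class Batch:
    # (request_id, start, end): span of prompt tokens processed in this batch
    prefill_chunk: Tuple[str, int, int]
    # ids of running requests that each decode one token in this batch
    decodes: List[str] = field(default_factory=list)


def chunk_prefill(prompt_len: int, chunk_size: int) -> List[Tuple[int, int]]:
    """Split prompt positions [0, prompt_len) into consecutive chunks.
    Assumption: the final chunk may be shorter than chunk_size."""
    return [(s, min(s + chunk_size, prompt_len)) for s in range(0, prompt_len, chunk_size)]


def decode_maximal_batches(request_id: str, prompt_len: int, chunk_size: int,
                           max_batch_size: int, running_decodes: List[str]) -> List[Batch]:
    """One batch per prefill chunk; all other slots are filled with decodes."""
    decode_slots = max_batch_size - 1  # one slot is taken by the prefill chunk
    return [
        # Simplification: the first decode_slots running requests piggyback
        # on every chunk, each producing its next token per iteration.
        Batch(prefill_chunk=(request_id, start, end),
              decodes=running_decodes[:decode_slots])
        for start, end in chunk_prefill(prompt_len, chunk_size)
    ]


batches = decode_maximal_batches("A", prompt_len=1024, chunk_size=256,
                                 max_batch_size=8,
                                 running_decodes=[f"r{i}" for i in range(10)])
for b in batches:
    print(b.prefill_chunk, len(b.decodes))
```

With a 1,024-token prompt, 256-token chunks, and a batch size of 8, the sketch yields four batches, each carrying one prefill chunk and 7 decodes, so 28 decode tokens ride along with the prefill instead of running in decode-only batches. Always picking the first `decode_slots` running requests is a deliberate simplification; a real scheduler would also have to handle requests finishing, KV-cache capacity, and fairness.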

Paper Structure (30 sections, 1 equation, 13 figures, 4 tables)


Figures (13)

  • Figure 1: Example two-stage pipeline-parallel schedule. (a) In prior solutions such as Orca [orca], pipeline bubbles are common because prefill and decode compute times vary. Further, decodes are highly inefficient: the per-token cost of decode is an order of magnitude higher than that of prefill. (b) SARATHI significantly reduces pipeline bubbles and enables more efficient piggybacked decodes.
  • Figure 2: High-level architecture of a decoder block.
  • Figure 3: Per-token prefill and decode time at different batch sizes (sequence length = 1024) for LLaMA-13B on an A6000 GPU. Prefill saturates GPU compute even at a batch size of 1, so its per-token time is almost constant across batch sizes. Decode under-utilizes GPU compute and, at batch size 1, costs up to 200x more per token than prefill. As batch size increases, the incremental cost of the linear operators for decode is almost zero, whereas attention cost does not benefit from larger batches because it is memory-bound.
  • Figure 4: Impact of arithmetic intensity (bottom) on the throughput (top) of prefills and decodes for LLaMA-13B on an A6000 GPU.
  • Figure 5: Pipeline bubbles in LLM inference. A 2-way pipeline-parallel (PP) iteration-level schedule, as in Orca [orca], across 4 requests (A, B, C, D) shows pipeline bubbles caused by non-uniform batch execution times.
  • ...and 8 more figures
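The caption of Figure 5 attributes bubbles to non-uniform micro-batch execution times. A toy two-stage simulation reproduces the effect. The function `pipeline_idle_time` is hypothetical and models a much simpler pipeline than the paper's iteration-level schedule: micro-batches run in a fixed order, each takes the same time on every stage, and cross-iteration dependencies and communication costs are ignored.

```python
from typing import List


def pipeline_idle_time(microbatch_times: List[float], num_stages: int) -> float:
    """Idle time summed over stages, counting only gaps between consecutive
    micro-batches on the same stage (pipeline fill and drain are excluded).

    Assumptions: micro-batch i takes microbatch_times[i] on every stage, and
    there are no inter-micro-batch dependencies or communication costs.
    """
    n = len(microbatch_times)
    finish = [[0.0] * n for _ in range(num_stages)]
    idle = 0.0
    for s in range(num_stages):
        for i, t in enumerate(microbatch_times):
            # Input is ready once the previous stage has finished this micro-batch.
            ready = finish[s - 1][i] if s > 0 else 0.0
            # This stage is free once it has finished the previous micro-batch.
            free = finish[s][i - 1] if i > 0 else ready
            start = max(ready, free)
            if i > 0:
                idle += start - free  # stage waits for its input: a bubble
            finish[s][i] = start + t
    return idle


# Same total work (11 time units), different micro-batch shapes.
print(pipeline_idle_time([1.0, 10.0], num_stages=2))  # prints 9.0: short decode-only batch, then long prefill
print(pipeline_idle_time([5.5, 5.5], num_stages=2))   # prints 0.0: uniform batches
```

Both sequences contain the same amount of work, but when a short decode-only micro-batch is followed by a long prefill, the second stage sits idle for 9 units waiting on the first. With equal-sized micro-batches, which SARATHI's uniform decode-maximal batches aim to approximate, the gap disappears in this simplified model.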