
PRESERVE: Prefetching Model Weights and KV-Cache in Distributed LLM Serving

Ahmet Caner Yüzügüler, Jiawei Zhuang, Lukas Cavigelli

TL;DR

PRESERVE tackles the memory bandwidth and inter-device communication bottlenecks in distributed LLM inference by prefetching model weights and KV-cache from off-chip memory into on-chip L2 caches in parallel with inter-device communication. It is implemented as a graph-optimization framework that automatically inserts prefetching operators into the computation graph, enabling overlap with allreduce without requiring changes to user code. The authors demonstrate up to $1.6\times$ end-to-end speedups across modern open-source LLMs and perform a design-space exploration showing the optimal L2 cache size shifts from 8 MB to 104 MB when prefetching is used, delivering about a $1.25\times$ improvement in performance per cost. These results suggest that on-chip memory-aware prefetching can significantly improve the scalability and efficiency of distributed LLM inference on current AI accelerators, and that hardware design can further amplify these gains.
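The graph pass described above can be illustrated with a minimal sketch. It assumes a linear schedule of hypothetical `Op` records, where each record lists the bytes of every HBM-resident input (weight or KV-cache tensor) it reads. `Op` and `insert_prefetch` are illustrative names, not the authors' API. The property the sketch relies on is that the weights and KV-cache of the op after an allreduce do not depend on the allreduce result. Their prefetch can therefore be placed before the allreduce and overlap with it.

```python
from dataclasses import dataclass, field


@dataclass
class Op:
    name: str
    kind: str  # hypothetical op kinds, e.g. "matmul", "allreduce", "prefetch"
    inputs: list = field(default_factory=list)
    # Bytes of each HBM-resident input (weight or KV-cache tensor) this op reads.
    hbm_inputs: dict = field(default_factory=dict)


def insert_prefetch(ops, l2_bytes):
    """Return a new schedule with a prefetch op placed just before each allreduce.

    The prefetch targets the HBM-resident inputs of the first op after the
    allreduce that reads any. It greedily picks tensors, in order, that still
    fit in the remaining L2 budget. A runtime would issue the prefetch
    asynchronously so the HBM->L2 copy overlaps with the allreduce.
    """
    schedule = []
    for i, op in enumerate(ops):
        if op.kind == "allreduce":
            consumer = next((o for o in ops[i + 1:] if o.hbm_inputs), None)
            if consumer is not None:
                chosen, budget = [], l2_bytes
                for tensor, size in consumer.hbm_inputs.items():
                    if size <= budget:  # skip tensors that no longer fit in L2
                        chosen.append(tensor)
                        budget -= size
                if chosen:
                    schedule.append(Op(f"prefetch_{consumer.name}", "prefetch", inputs=chosen))
        schedule.append(op)
    return schedule


ops = [
    Op("o_proj", "matmul", hbm_inputs={"w_o": 32}),
    Op("allreduce_0", "allreduce"),
    Op("mlp_up", "matmul", hbm_inputs={"w_up": 64, "w_gate": 64}),
]
schedule = insert_prefetch(ops, l2_bytes=100)
print([op.name for op in schedule])
```

With a 100-unit budget, only `w_up` fits, so the pass emits `prefetch_mlp_up` right before `allreduce_0`. A real framework would also need to handle cache eviction and stream synchronization, which this sketch omits.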

Abstract

Large language models (LLMs) are typically served from clusters of GPUs/NPUs that consist of a large number of devices. Unfortunately, communication between these devices incurs significant overhead, increasing inference latency and cost while limiting scalability. Prior work addressed this issue by overlapping communication with compute, but this approach is severely limited by the data dependencies between these operations. In this paper, we propose PRESERVE, a novel framework that prefetches model weights and KV-cache from off-chip HBM memory to the on-chip cache of AI accelerators during communication operations. Because weights and KV-cache do not depend on the result of the communication, this prefetching avoids the data-dependency limits of prior overlapping methods. Through extensive experiments on commercial AI accelerators, we demonstrate up to 1.6x end-to-end speedup on state-of-the-art open-source LLMs. Additionally, we perform a design space exploration that identifies the optimal hardware configuration for the proposed method, showing a further 1.25x improvement in performance per cost by selecting the optimal L2 cache size. Our results show that PRESERVE has the potential to mitigate memory bottlenecks and communication overheads, offering a way to improve the performance and scalability of LLM inference systems.
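A first-order timing model shows why prefetching during communication can help a memory-bound op. This is our own simplification, not the paper's analytical model. It assumes that the prefetch and the allreduce do not contend for HBM bandwidth and that reads from L2 are faster than reads from HBM. All bandwidths and sizes below are made-up illustrative values.

```python
def time_vanilla(comm_s, bytes_needed, hbm_bw):
    """Allreduce, then a memory-bound op that streams all of its bytes from HBM."""
    return comm_s + bytes_needed / hbm_bw


def time_with_prefetch(comm_s, bytes_needed, hbm_bw, l2_bw, l2_bytes):
    """Prefetch HBM->L2 during the allreduce, then run the op.

    Assumes the prefetch and the allreduce do not contend for HBM bandwidth.
    """
    # Bytes movable during the allreduce, capped by L2 capacity and by what the op needs.
    prefetched = min(bytes_needed, l2_bytes, hbm_bw * comm_s)
    remaining = bytes_needed - prefetched
    # The prefetch finishes within comm_s by construction, so it is fully hidden.
    return comm_s + prefetched / l2_bw + remaining / hbm_bw


# Made-up numbers: 50 us allreduce, 160 MB of weights, 1.6 TB/s HBM, 4.8 TB/s L2, 100 MB L2.
t0 = time_vanilla(50e-6, 160e6, 1.6e12)
t1 = time_with_prefetch(50e-6, 160e6, 1.6e12, 4.8e12, 100e6)
print(f"vanilla {t0 * 1e6:.1f} us, prefetch {t1 * 1e6:.1f} us, speedup {t0 / t1:.2f}x")
```

For these toy values, the allreduce window is the binding cap: only 80 MB can be moved during it. The model gives 150 us versus about 116.7 us (roughly 1.29x). The numbers only illustrate the mechanism and are not results from the paper.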

Paper Structure

This paper contains 29 sections, 10 figures, 2 tables, and 1 algorithm.

Figures (10)

  • Figure 1: Conceptual comparison between A) vanilla inference, B) inference with fused GEMM+Allreduce kernels [Rashidi21, Hoefler08, Punniyamurthy24, Chang24, WangWei23], and C) PRESERVE (this work).
  • Figure 2: Memory footprint of the Attention and MLP layers of various LLMs for a varying number of devices. The models are assumed to be distributed using tensor parallelism, and both weights and KV-cache are stored in int8 precision. Batch size and context length are set to 8 and 16k, respectively. Horizontal lines represent the L2 capacity of various state-of-the-art NPUs and GPUs, namely Ascend 910B [ascend], TPUv5p [Vahdat2023tpuv5p], TPUv5e [Vahdat2023tpuv5e], and H100 [h100].
  • Figure 3: An overview of the proposed PRESERVE framework.
  • Figure 4: Speedups obtained with the proposed method for varying batch sizes and maximum sequence lengths across various models. Both activations and weights are in int8 precision. The experiments are performed on four NPUs. N/A indicates insufficient HBM capacity.
  • Figure 5: Comparison against the fused GEMM+Allreduce operator baseline for various batch sizes (y-axis) and maximum sequence lengths (x-axis) for Llama3-70b in fp16. Red squares indicate that PRESERVE is faster than the fused GEMM+Allreduce baseline. N/A denotes out-of-memory.
  • ...and 5 more figures
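The per-device footprint plotted in Figure 2 can be approximated with simple arithmetic. This sketch rests on several assumptions of ours:

- The model uses Llama-style grouped-query attention and a SwiGLU MLP.
- Tensor parallelism shards every tensor evenly.
- Footprints are counted per layer.
- The attention footprint includes the KV-cache.

The function name and this breakdown are our own, not the paper's. The example shapes are the public Llama3-70B configuration: hidden size 8192, 64 query heads, 8 KV heads, and MLP width 28672. The sketch also reads "16k" as 16384 tokens.

```python
def per_device_footprint(hidden, n_heads, n_kv_heads, ffn, batch, ctx, n_dev, bytes_per_elem=1):
    """Per-layer, per-device bytes for attention (weights + KV-cache) and a SwiGLU MLP.

    Assumes tensor parallelism shards every tensor evenly over n_dev devices.
    """
    head_dim = hidden // n_heads
    kv_dim = n_kv_heads * head_dim
    attn_weights = (2 * hidden * hidden + 2 * hidden * kv_dim) * bytes_per_elem  # Q, O and K, V
    kv_cache = 2 * batch * ctx * kv_dim * bytes_per_elem  # one K and one V entry per token
    mlp_weights = 3 * hidden * ffn * bytes_per_elem  # gate, up and down projections
    return (attn_weights + kv_cache) / n_dev, mlp_weights / n_dev


MiB = 2**20
# Llama3-70B shapes, int8 (1 byte per element), batch 8, 16384-token context, 4 devices.
attn, mlp = per_device_footprint(8192, 64, 8, 28672, batch=8, ctx=16384, n_dev=4)
print(f"attention {attn / MiB:.0f} MiB, MLP {mlp / MiB:.0f} MiB per device")
```

Under these assumptions, 4 devices give 100 MiB (attention) and 168 MiB (MLP) per layer. Both are larger than the 50 MB L2 of an H100, so on that device only part of a layer's data could be held in L2 at once.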