Inference Optimization of Foundation Models on AI Accelerators

Youngsuk Park, Kailash Budhathoki, Liangfu Chen, Jonas Kübler, Jiaji Huang, Matthäus Kleindessner, Jun Huan, Volkan Cevher, Yida Wang, George Karypis

TL;DR

This work surveys inference optimization techniques for foundation models on AI accelerators, addressing the memory-bound nature of decoder-only, autoregressive generation and the resulting latency and cost challenges. It catalogs system-level optimizations (KV-cache, FlashAttention, continuous batching), structured Transformer architectures (MQA/GQA, MoE, SWT) and model compression (quantization, pruning, distillation), as well as fast decoding methods (speculative decoding) and distributed strategies (tensor/pipeline/sequence/expert parallelism). The contributions include a detailed taxonomy of techniques, hardware considerations, and practical trade-offs between latency, throughput, and accuracy to guide deployment at scale. The findings highlight that combining memory-efficient attention, architecture-aware designs, and lightweight decoding can substantially reduce inference costs while maintaining satisfactory performance, though extreme long-context and cross-device deployments remain open challenges.
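
As a rough illustration of why autoregressive decoding is memory-bound, the sketch below estimates the arithmetic intensity of one linear layer. Arithmetic intensity here means FLOPs per byte of weight traffic. The sketch rests on several assumptions:
- weights are fp16/bf16 (2 bytes per parameter);
- the layer is a hypothetical 4096×4096 layer;
- activation and KV-cache traffic are ignored.

The numbers are therefore back-of-the-envelope estimates, not measurements.

```python
def linear_arithmetic_intensity(d_in: int, d_out: int, tokens: int,
                                bytes_per_param: int = 2) -> float:
    """FLOPs per byte of weight traffic for y = x @ W, with x of shape (tokens, d_in).

    Simplifying assumptions: W is read from memory once per call, activation and
    KV-cache traffic are ignored, and one multiply-accumulate counts as 2 FLOPs.
    """
    flops = 2 * tokens * d_in * d_out
    weight_bytes = d_in * d_out * bytes_per_param
    return flops / weight_bytes


# Autoregressive decoding with batch size 1 processes one token per step, so every
# loaded weight is used for a single multiply-accumulate.
decode = linear_arithmetic_intensity(4096, 4096, tokens=1)
# Prefilling a 2048-token prompt (or decoding a large batch) reuses each loaded
# weight across many tokens.
prefill = linear_arithmetic_intensity(4096, 4096, tokens=2048)
print(f"decode: {decode:.0f} FLOP/byte, prefill: {prefill:.0f} FLOP/byte")
```

Recent GPUs can perform on the order of a hundred or more FLOPs per byte moved from HBM. An intensity near 1 therefore leaves the compute units mostly idle while waiting on memory. Processing more tokens per weight load, by batching or by verifying several drafted tokens at once, is one way to raise it.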

Abstract

Powerful foundation models with Transformer architectures, including large language models (LLMs), have ushered in a new era of Generative AI across various industries. The industry and research communities have witnessed a large number of new applications based on these foundation models, such as question answering, customer service, image and video generation, and code completion. However, as the number of model parameters reaches hundreds of billions, their deployment incurs prohibitive inference costs and high latency in real-world scenarios. As a result, the demand for cost-effective and fast inference using AI accelerators is higher than ever. To this end, our tutorial offers a comprehensive discussion of complementary inference optimization techniques using AI accelerators. Beginning with an overview of basic Transformer architectures and deep learning system frameworks, we dive deep into system optimization techniques for fast and memory-efficient attention computations and discuss how they can be implemented efficiently on AI accelerators. Next, we describe architectural elements that are key for fast Transformer inference. Finally, we examine various model compression and fast decoding strategies in the same context.
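
Much of the inference cost described above comes down to memory. The following back-of-the-envelope sketch estimates the memory needed for model weights and for the KV cache. The model shape is a hypothetical configuration chosen only for illustration, not one taken from the tutorial:
- 70B parameters;
- 80 layers;
- 8 KV heads with head dimension 128;
- fp16 storage.

```python
GIB = 1024 ** 3


def weight_bytes(num_params: float, bytes_per_param: int = 2) -> float:
    """Memory for the model weights (2 bytes per parameter for fp16/bf16)."""
    return num_params * bytes_per_param


def kv_cache_bytes(layers: int, kv_heads: int, head_dim: int, seq_len: int,
                   batch: int, bytes_per_elem: int = 2) -> int:
    """KV-cache memory: one key and one value vector per layer, KV head and cached token."""
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem


# Hypothetical 70B-parameter configuration, for illustration only.
w = weight_bytes(70e9)
kv = kv_cache_bytes(layers=80, kv_heads=8, head_dim=128, seq_len=4096, batch=1)
print(f"weights: {w / 1e9:.0f} GB, KV cache for one 4096-token sequence: {kv / GIB:.2f} GiB")
```

Under these assumptions, the weights alone exceed the memory of a single accelerator with 80 GB. The KV cache grows linearly with both sequence length and batch size; 64 such sequences would need 80 GiB. This is what motivates model parallelism, quantization, KV-cache management and grouped-query attention.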

Paper Structure

This paper contains 24 sections and 7 figures.

Figures (7)

  • Figure 1: Original Transformer architecture adapted from vaswani2017attention, comprising an encoder (left) and a decoder (right). Tokens are first mapped into an embedding space, and a positional encoding adds information about the token positions. Modern LLM architectures are decoder-only, with a backbone of repeated layers containing masked attention and a feed-forward neural network (FFN). The masked attention first applies linear transformations to a sequence of embeddings to obtain query ($Q$), key ($K$), and value ($V$) matrices. It then computes $\mathrm{Attention}(Q, K, V) = \mathrm{softmax}(\frac{QK^T}{\sqrt{d_k}})V$, thereby relating the tokens to each other; the mask ensures that each token attends only to itself and its predecessors. The FFN is applied to each token independently. Both attention and FFN add their outputs to the embedding, which is carried forward through the skip connections.
  • Figure 2: FlashAttention by dao2022flashattention. The outer loop iterates over K and V blocks and loads them into fast SRAM. For each such block, the inner loop iterates over Q blocks, loading them into SRAM and writing the attention output back to HBM.
  • Figure 3: Types of memory fragmentation, from kwon2023efficient. The figure depicts the memory space used when decoding two sequences. Internal fragmentation refers to allocated KV cache blocks that are not occupied by the sequences. External fragmentation refers to free memory space that is not allocated.
  • Figure 4: Overview of the grouped-query attention (GQA) method by ainslie2023gqa.
  • Figure 5: Instead of the dense feed-forward network (FFN) layer of the traditional Transformer (left, blue), fedus2022switch introduce a sparse Switch FFN layer (right, blue). This layer operates independently on each token of the sequence.
  • ...and 2 more figures
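
The masked attention in Figure 1 is what makes KV caching possible. Token i attends only to positions 0..i, so its output depends only on keys and values that already exist when it is generated. The pure-Python sketch below uses toy 2-dimensional Q, K and V values, hard-coded for illustration; in a real model they come from the linear projections. It computes causal attention over a whole sequence with an explicit mask. It then recomputes the result one token at a time from a growing K/V cache and checks that both agree.

```python
import math


def softmax(xs):
    """Numerically stable softmax; entries equal to -inf receive weight 0."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def weighted_sum(weights, values):
    """Sum of value rows weighted by the attention weights."""
    return [sum(w * v[j] for w, v in zip(weights, values)) for j in range(len(values[0]))]


def masked_attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V over the whole sequence with a causal mask."""
    d_k = len(Q[0])
    out = []
    for i, q in enumerate(Q):
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) if j <= i else -math.inf
                  for j, k in enumerate(K)]  # mask: positions j > i are hidden from token i
        out.append(weighted_sum(softmax(scores), V))
    return out


def decode_with_kv_cache(Q, K, V):
    """Token-by-token decoding: each step appends one key/value row to the cache and
    attends to everything cached so far, so no mask and no recomputation is needed."""
    d_k = len(Q[0])
    k_cache, v_cache, out = [], [], []
    for q, k, v in zip(Q, K, V):
        k_cache.append(k)
        v_cache.append(v)
        scores = [sum(a * b for a, b in zip(q, kc)) / math.sqrt(d_k) for kc in k_cache]
        out.append(weighted_sum(softmax(scores), v_cache))
    return out


Q = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4]]
K = [[0.2, 0.0], [-0.3, 0.6], [0.1, 0.1]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
full = masked_attention(Q, K, V)
incremental = decode_with_kv_cache(Q, K, V)
print(all(math.isclose(a, b) for r1, r2 in zip(full, incremental) for a, b in zip(r1, r2)))
```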
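
The loop order in Figure 2 works because softmax can be computed incrementally ("online"). For each query, one keeps three pieces of running state:
- a maximum score;
- a sum of exponentials;
- an unnormalized output.

Whenever a new K/V block raises the maximum, all three are rescaled by exp(old_max - new_max). The sketch below is a pure-Python illustration of that arithmetic only. It makes simplifying assumptions: no causal mask, Python lists instead of on-chip tiles, and an arbitrary block size. It does not model SRAM or HBM and is not the FlashAttention kernel itself.

```python
import math


def naive_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V computed one query row at a time (no mask)."""
    d = len(Q[0])
    out = []
    for q in Q:
        s = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K]
        m = max(s)
        p = [math.exp(x - m) for x in s]
        total = sum(p)
        out.append([sum(pi * v[j] for pi, v in zip(p, V)) / total for j in range(len(V[0]))])
    return out


def tiled_attention(Q, K, V, block=2):
    """Same result, computed block by block with an online softmax.

    Loop order follows Figure 2: outer loop over K/V blocks, inner loop over Q blocks.
    The full score matrix is never materialized; per-query running statistics are
    updated after every K/V block instead."""
    d, dv = len(Q[0]), len(V[0])
    run_max = [-math.inf] * len(Q)
    run_sum = [0.0] * len(Q)
    acc = [[0.0] * dv for _ in Q]
    for kb in range(0, len(K), block):                    # outer loop: one K/V block
        K_blk, V_blk = K[kb:kb + block], V[kb:kb + block]
        for qb in range(0, len(Q), block):                # inner loop: one Q block
            for i in range(qb, min(qb + block, len(Q))):
                s = [sum(a * b for a, b in zip(Q[i], k)) / math.sqrt(d) for k in K_blk]
                new_max = max(run_max[i], max(s))
                scale = math.exp(run_max[i] - new_max)    # rescale earlier partial results
                p = [math.exp(x - new_max) for x in s]
                run_sum[i] = run_sum[i] * scale + sum(p)
                acc[i] = [a * scale + sum(pj * v[j] for pj, v in zip(p, V_blk))
                          for j, a in enumerate(acc[i])]
                run_max[i] = new_max
    return [[a / run_sum[i] for a in acc[i]] for i in range(len(Q))]


Q = [[0.1, 0.2], [0.3, -0.1], [0.5, 0.4]]
K = [[0.2, 0.0], [-0.3, 0.6], [0.1, 0.1], [0.4, -0.2], [0.0, 0.3]]
V = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [0.2, 0.8], [0.9, 0.1]]
ref, tiled = naive_attention(Q, K, V), tiled_attention(Q, K, V, block=2)
print(all(math.isclose(a, b, abs_tol=1e-12) for r1, r2 in zip(ref, tiled) for a, b in zip(r1, r2)))
```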
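
The fragmentation types in Figure 3 are what the block-based ("paged") KV-cache management of kwon2023efficient targets. The toy allocator below is a hypothetical illustration, not the API of any serving library. It works as follows:
- the cache is split into fixed-size blocks;
- each sequence keeps a block table of physical block ids that need not be contiguous;
- a new block is taken only when a sequence's last block is full.

Internal fragmentation is therefore bounded by block_size - 1 slots per sequence. Because any free block can serve any sequence, external fragmentation does not arise in this model.

```python
class PagedKVAllocator:
    """Toy paged KV-cache allocator (hypothetical, for illustration only)."""

    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.free_blocks = list(range(num_blocks))  # physical block ids
        self.block_tables = {}  # seq_id -> physical block ids, in logical order
        self.lengths = {}       # seq_id -> number of tokens currently cached

    def append_token(self, seq_id: str) -> None:
        """Reserve a KV slot for one new token, allocating a block only when needed."""
        n = self.lengths.get(seq_id, 0)
        if n % self.block_size == 0:  # no block yet, or the last block is full
            if not self.free_blocks:
                raise MemoryError("out of KV-cache blocks")
            self.block_tables.setdefault(seq_id, []).append(self.free_blocks.pop())
        self.lengths[seq_id] = n + 1

    def free(self, seq_id: str) -> None:
        """Return all blocks of a finished sequence to the free pool."""
        self.free_blocks.extend(self.block_tables.pop(seq_id, []))
        self.lengths.pop(seq_id, None)

    def internal_fragmentation(self) -> int:
        """Token slots that are allocated but not occupied."""
        return sum(len(blocks) * self.block_size - self.lengths[s]
                   for s, blocks in self.block_tables.items())


alloc = PagedKVAllocator(num_blocks=16, block_size=4)
for step in range(10):            # decode two sequences in an interleaved fashion
    alloc.append_token("A")       # sequence A reaches 10 tokens
    if step < 5:
        alloc.append_token("B")   # sequence B reaches 5 tokens
print(alloc.block_tables)              # {'A': [15, 13, 11], 'B': [14, 12]}
print(alloc.internal_fragmentation())  # 5 (2 unused slots for A, 3 for B)
```

For comparison, reserving a contiguous 32-token region per sequence up front would leave 2 × 32 - 15 = 49 slots allocated but unused for the same two sequences.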
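
The grouped-query attention in Figure 4 lets several query heads share one key/value head, which shrinks the KV cache. The sketch below shows a single decoding step in pure Python with assumed toy shapes:
- 4 query heads and 2 KV heads;
- head dimension 2;
- three cached tokens, with hard-coded values.

With H query heads and G KV heads, query head h reads KV head h // (H // G). Setting G = H recovers standard multi-head attention, and G = 1 recovers multi-query attention (MQA).

```python
import math


def attend(q, keys, values):
    """Scaled dot-product attention of one query vector over cached keys/values."""
    scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(len(q)) for k in keys]
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    total = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / total for j in range(len(values[0]))]


def gqa_decode_step(q_heads, k_heads, v_heads):
    """One decoding step of grouped-query attention.

    q_heads: H query vectors for the new token.
    k_heads, v_heads: G cached [seq_len][head_dim] lists, one per KV head (H % G == 0).
    Query head h reads KV head h // (H // G)."""
    num_q, num_kv = len(q_heads), len(k_heads)
    if num_q % num_kv != 0:
        raise ValueError("number of query heads must be a multiple of KV heads")
    group = num_q // num_kv
    return [attend(q, k_heads[h // group], v_heads[h // group]) for h, q in enumerate(q_heads)]


def kv_elems_per_token(num_kv_heads, head_dim):
    """KV-cache entries per token and layer: one key and one value vector per KV head."""
    return 2 * num_kv_heads * head_dim


q_heads = [[0.1, 0.2], [0.0, 0.5], [-0.2, 0.3], [0.4, 0.1]]  # H = 4 query heads
k_heads = [[[0.2, 0.0], [-0.3, 0.6], [0.1, 0.1]],             # G = 2 KV heads,
           [[0.5, 0.5], [0.0, -0.2], [0.3, 0.3]]]             # 3 cached tokens each
v_heads = [[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],
           [[0.2, 0.8], [0.9, 0.1], [0.4, 0.6]]]
out = gqa_decode_step(q_heads, k_heads, v_heads)
print(len(out))                                          # 4: one output per query head
print([kv_elems_per_token(g, 128) for g in (32, 8, 1)])  # [8192, 2048, 256]: MHA, GQA, MQA
```

The last line assumes 32 query heads of dimension 128. Grouping them onto 8 KV heads cuts per-token cache entries by 4×, and a single shared KV head (MQA) cuts them by 32×.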
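
The Switch FFN in Figure 5 routes each token to a single expert (top-1 routing) chosen by a learned router. It scales that expert's output by the router probability, so only one expert's parameters are used per token. The sketch below is a toy pure-Python version under stated assumptions:
- two experts, each a single ReLU linear layer rather than a full two-layer FFN;
- hand-picked, hypothetical router weights;
- no expert capacity limit and no load-balancing loss.

```python
import math


def softmax(xs):
    m = max(xs)
    e = [math.exp(x - m) for x in xs]
    total = sum(e)
    return [v / total for v in e]


def make_expert(W):
    """Toy expert: ReLU(x @ W). Real Switch experts are full two-layer FFNs."""
    def expert(x):
        return [max(0.0, sum(x[i] * W[i][j] for i in range(len(x)))) for j in range(len(W[0]))]
    return expert


def switch_ffn(tokens, router_W, experts):
    """Top-1 routing: each token independently picks the expert with the highest router
    probability, and the expert output is scaled by that probability.
    Expert capacity limits and the load-balancing loss are omitted."""
    outputs, assignments = [], []
    for x in tokens:
        logits = [sum(x[i] * router_W[i][e] for i in range(len(x))) for e in range(len(experts))]
        probs = softmax(logits)
        best = max(range(len(experts)), key=lambda e: probs[e])
        outputs.append([probs[best] * y for y in experts[best](x)])
        assignments.append(best)
    return outputs, assignments


router_W = [[1.0, -1.0],
            [-1.0, 1.0]]                 # d_model = 2, two experts (hypothetical values)
experts = [make_expert([[1.0, 0.0], [0.0, 1.0]]),
           make_expert([[2.0, 0.0], [0.0, 2.0]])]
tokens = [[1.0, 0.0], [0.0, 1.0]]
outputs, assignments = switch_ffn(tokens, router_W, experts)
print(assignments)  # [0, 1]: each token is processed by exactly one expert
```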