Flextron: Many-in-One Flexible Large Language Model

Ruisi Cai, Saurav Muralidharan, Greg Heinrich, Hongxu Yin, Zhangyang Wang, Jan Kautz, Pavlo Molchanov

TL;DR

Flextron addresses the cost and inflexibility of deploying large language models with a nested, many-in-one elastic architecture in which both the MLP and multi-head attention (MHA) layers are elastic. A pretrained LLM is converted after training into an input-adaptive elastic network with static and dynamic routers; a surrogate model stabilizes router training, and sub-networks are selected automatically under latency constraints. A sample-efficient elastic continued-training procedure, followed by staged router training, needs only a single pretraining run that consumes 7.63% of the tokens used in the original pretraining. On the GPT-3 and Llama-2 model families, Flextron matches or surpasses end-to-end trained variants and state-of-the-art elastic networks, and it offers a range of latency and parameter-count trade-offs across diverse data domains.

Abstract

Training modern LLMs is extremely resource-intensive, and customizing them through repeated training for deployment scenarios with limited compute and memory is impractical. In this paper, we introduce Flextron, a network architecture and post-training model optimization framework supporting flexible model deployment. The Flextron architecture utilizes a nested elastic structure to rapidly adapt to specific user-defined latency and accuracy targets during inference with no additional fine-tuning required. It is also input-adaptive, and can automatically route tokens through its sub-networks for improved performance and efficiency. We present a sample-efficient training method and associated routing algorithms for systematically transforming an existing trained LLM into a Flextron model. We evaluate Flextron on the GPT-3 and Llama-2 families of LLMs, and demonstrate superior performance over multiple end-to-end trained variants and other state-of-the-art elastic networks, all with a single pretraining run that consumes a mere 7.63% of the tokens used in the original pretraining.
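
To make the idea of a nested elastic structure concrete, here is a minimal sketch in plain Python. It assumes that a narrower sub-network reuses the first `width` hidden units of the full MLP, which is one common way to nest sub-networks; the abstract does not spell out the exact selection rule, so the class name `NestedElasticMLP`, the prefix-slicing rule, and the random weights are all illustrative assumptions rather than the authors' implementation.

```python
import random


def relu(x):
    return x if x > 0.0 else 0.0


class NestedElasticMLP:
    """Toy two-layer MLP whose sub-networks share a prefix of the hidden units.

    Hypothetical illustration only: the prefix rule is an assumption.
    """

    def __init__(self, d_model, d_hidden, seed=0):
        rng = random.Random(seed)
        # w_in[j]: input weights of hidden unit j; w_out[j]: its output weights.
        self.w_in = [[rng.uniform(-1, 1) for _ in range(d_model)]
                     for _ in range(d_hidden)]
        self.w_out = [[rng.uniform(-1, 1) for _ in range(d_model)]
                      for _ in range(d_hidden)]
        self.d_model = d_model
        self.d_hidden = d_hidden

    def forward(self, x, width):
        """Run the sub-network that keeps only the first `width` hidden units."""
        if not 1 <= width <= self.d_hidden:
            raise ValueError("width must be in [1, d_hidden]")
        out = [0.0] * self.d_model
        for j in range(width):
            h = relu(sum(w * xi for w, xi in zip(self.w_in[j], x)))
            for i in range(self.d_model):
                out[i] += h * self.w_out[j][i]
        return out

    def num_params(self, width):
        """Weights touched by the sub-network of the given width (no biases)."""
        return 2 * width * self.d_model


mlp = NestedElasticMLP(d_model=4, d_hidden=8, seed=0)
x = [0.5, -1.0, 0.25, 2.0]
small = mlp.forward(x, width=2)  # cheaper sub-network
full = mlp.forward(x, width=8)   # full model, same weight storage
```

Because every width reads from the same weight arrays, choosing a sub-network at inference time requires no extra copy of the parameters and no fine-tuning step in this toy setting.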

Paper Structure (32 sections, 12 equations, 14 figures, 5 tables)

Figures (14)

  • Figure 1: High-level overview of the Flextron framework. As shown in the top half of the figure, Flextron enables fast, zero-shot generation of hardware- and input-adaptive sub-networks targeting various accuracy, latency, and parameter constraints. The bottom half of the figure shows how we convert a trained LLM into an elastic network with input-adaptive routing.
  • Figure 2: Illustration of the elastic continued-training phase.
  • Figure 3: Illustration of how routers are trained via a Surrogate Model (SM). The SM is trained to approximate the LLM language loss value given only the router logits. If the error of the SM is smaller than a predefined threshold, the routers are updated. Updates are based on (i) the latency loss, which ensures that the requested latency matches the real overall latency obtained from a Lookup Table (LUT), and (ii) the loss from minimizing the SM output. The SM serves as a proxy for the full model's language loss and allows for simpler backpropagation due to its smaller size. Once the routers are trained, we discard the SM and fine-tune the LLM and routers jointly.
  • Figure 4: The Flextron-Llama2-7B model family demonstrates superior MMLU [hendrycks2020measuring] performance compared to both open-source models and existing post-hoc compression methods. Specifically, we compare against models from the Pythia [biderman2023pythia] and OpenLLaMA-v2 [openlm2023openllama] families, and against Sheared-LLaMA [xia2023sheared], Compresso [guo2023compresso], LLM-Pruner [ma2023llm], SliceGPT [ashkboos2024slicegpt], and LaCo [yang2024laco]. The $\times$ suffix indicates the remaining latency of the model.
  • Figure 5: Pareto curves for language modeling loss vs. latency (left) and number of non-embedding parameters (right). The curves are fitted with the model scaling equation. Flextron outperforms MatFormer and even end-to-end-trained smaller models (843M). Performance is measured by language modeling validation loss, averaged over 7 representative datasets: (1) English datasets: Arxiv, Books3 [pile], Wikipedia [wikidump]; (2) multilingual datasets: Korean and German; and (3) code data: HTML and Java. We measure model latency with the Megatron framework [shoeybi2019megatron] using a batch size of 2 and a sequence length of 4096 in the context prefilling stage on an NVIDIA A100 GPU.
  • ...and 9 more figures
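
The surrogate-gated router update described in the Figure 3 caption can be sketched as a toy training step. Everything concrete here is an assumption made for illustration: a single router with three choices, a hypothetical latency table `LATENCY_LUT`, a stand-in `toy_llm_loss` in place of a real LLM forward pass, a linear surrogate, and a squared-error latency loss. The caption fixes only the overall logic: fit the surrogate to the language loss from router logits, and update the routers with the latency loss plus the surrogate output only when the surrogate error is below a threshold.

```python
import math

# Hypothetical lookup table: latency of each router choice relative to the full model.
LATENCY_LUT = [1.0, 0.75, 0.5]
# Hypothetical per-choice language loss, standing in for the real LLM forward pass.
TOY_CHOICE_LOSS = [2.0, 2.2, 2.6]


def softmax(logits):
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def toy_llm_loss(probs):
    """Stand-in for the LLM language loss under a soft router decision."""
    return sum(p * l for p, l in zip(probs, TOY_CHOICE_LOSS))


def train_step(logits, sm_w, sm_b, target_latency, threshold,
               lr_sm=0.05, lr_router=0.5):
    """One toy step: fit the surrogate, then update routers only if it is accurate."""
    probs = softmax(logits)
    llm_loss = toy_llm_loss(probs)

    # Linear surrogate (assumed form): predicts the loss from router logits alone.
    sm_pred = sum(w * v for w, v in zip(sm_w, logits)) + sm_b
    err = sm_pred - llm_loss

    # One SGD step on the surrogate's squared error.
    new_w = [w - lr_sm * 2.0 * err * v for w, v in zip(sm_w, logits)]
    new_b = sm_b - lr_sm * 2.0 * err

    routers_updated = abs(err) < threshold
    new_logits = list(logits)
    if routers_updated:
        expected_latency = sum(p * t for p, t in zip(probs, LATENCY_LUT))
        gap = expected_latency - target_latency
        for k in range(len(logits)):
            # Gradient of (expected_latency - target)^2 through the softmax.
            grad_latency = 2.0 * gap * probs[k] * (LATENCY_LUT[k] - expected_latency)
            # Gradient of the linear surrogate output w.r.t. logit k is its weight.
            grad_surrogate = sm_w[k]
            new_logits[k] = logits[k] - lr_router * (grad_latency + grad_surrogate)
    return new_logits, new_w, new_b, routers_updated


# An untrained surrogate is far from the true loss, so the routers stay frozen.
logits, w, b, updated = train_step([0.0, 0.0, 0.0], [0.0, 0.0, 0.0], 0.0,
                                   target_latency=0.5, threshold=0.1)
```

In this sketch the router gradient never touches the stand-in LLM: it flows only through the small surrogate and the lookup table, which mirrors the caption's point that the SM allows simpler backpropagation. Gating on the surrogate error keeps the routers from following a proxy that does not yet track the real loss.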