Flextron: Many-in-One Flexible Large Language Model
Ruisi Cai, Saurav Muralidharan, Greg Heinrich, Hongxu Yin, Zhangyang Wang, Jan Kautz, Pavlo Molchanov
TL;DR
Flextron addresses the cost and inflexibility of deploying large language models by introducing a nested, many-in-one elastic architecture that supports both elastic MLP and elastic MHA layers. It enables post-training transformation of a pretrained LLM into an input-adaptive elastic network with static and dynamic routers, guided by a surrogate model to stabilize training, and it automatically selects sub-networks that satisfy latency constraints. A sample-efficient elastic continued-training procedure and a staged router-training workflow let a single training run reach strong performance while consuming only 7.63% of the tokens used in the original pretraining for the GPT-3 and Llama-2 model families. Empirically, Flextron matches or surpasses end-to-end trained variants and state-of-the-art elastic networks on GPT-3 and Llama-2, while offering a range of latency and parameter-count trade-offs and practical deployment benefits across diverse data domains.
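The nested ("many-in-one") idea can be illustrated with a minimal NumPy sketch of an elastic MLP, assuming that each width option simply uses the first k hidden neurons, so that smaller sub-networks are prefixes of larger ones and all of them share one set of weights. The class name `ElasticMLP`, the ReLU activation, and the width fractions are hypothetical choices for illustration, not Flextron's implementation.

```python
import numpy as np

class ElasticMLP:
    """Hypothetical nested elastic MLP: every width option reuses a prefix of the hidden units."""

    def __init__(self, d_model, d_hidden, width_fracs, seed=0):
        rng = np.random.default_rng(seed)
        # One shared set of weights backs all sub-networks.
        self.w_in = rng.standard_normal((d_model, d_hidden)) / np.sqrt(d_model)
        self.w_out = rng.standard_normal((d_hidden, d_model)) / np.sqrt(d_hidden)
        self.widths = sorted(int(round(f * d_hidden)) for f in width_fracs)

    def forward(self, x, width):
        # Only the first `width` hidden neurons are active, so a smaller
        # sub-network is a strict subset (prefix) of every larger one.
        h = np.maximum(x @ self.w_in[:, :width], 0.0)  # ReLU, chosen for simplicity
        return h @ self.w_out[:width, :]

    def num_params(self, width):
        # Parameters touched by the sub-network of the given width.
        return self.w_in[:, :width].size + self.w_out[:width, :].size


mlp = ElasticMLP(d_model=8, d_hidden=32, width_fracs=[0.25, 0.5, 1.0])
x = np.ones((2, 8))
for w in mlp.widths:
    print(w, mlp.num_params(w), mlp.forward(x, w).shape)
```

With these settings the loop reports widths 8, 16, and 32 with 128, 256, and 512 active parameters, each producing an output of shape (2, 8). The point of the design is that no sub-network stores weights of its own; choosing a smaller model only means slicing fewer rows and columns.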
Abstract
Training modern LLMs is extremely resource intensive, and customizing them through repeated training for deployment scenarios with limited compute and memory is impractical. In this paper, we introduce Flextron, a network architecture and post-training model optimization framework that supports flexible model deployment. The Flextron architecture uses a nested elastic structure to rapidly adapt to user-defined latency and accuracy targets during inference, with no additional fine-tuning required. It is also input-adaptive and can automatically route tokens through its sub-networks for improved performance and efficiency. We present a sample-efficient training method and associated routing algorithms for systematically transforming an existing trained LLM into a Flextron model. We evaluate Flextron on the GPT-3 and Llama-2 families of LLMs and demonstrate superior performance over multiple end-to-end trained variants and other state-of-the-art elastic networks, all with a single pretraining run that consumes only 7.63% of the tokens used in the original pretraining.
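Latency-aware and input-adaptive selection can be sketched in the same hedged spirit. Below, a hypothetical static selector picks the widest sub-network whose latency fits a user budget, and a hypothetical linear router then chooses a width per token among the sub-networks that fit. The latency table, `static_select`, and `dynamic_route` are illustrative assumptions; Flextron's learned routers are not reproduced here.

```python
import numpy as np

# Hypothetical (width fraction, latency in ms) pairs. These numbers are
# illustrative placeholders, not measurements from the paper.
CANDIDATES = [(0.25, 4.0), (0.5, 6.5), (0.75, 9.0), (1.0, 12.0)]

def static_select(latency_budget_ms):
    """Return the largest width whose latency fits the budget (stand-in for a static router)."""
    feasible = [w for w, lat in CANDIDATES if lat <= latency_budget_ms]
    if not feasible:
        raise ValueError("no sub-network meets the latency budget")
    return max(feasible)

def dynamic_route(tokens, router_w, allowed):
    """Pick one width per token with a linear router, restricted to widths <= allowed.

    tokens: (n, d) token features; router_w: (d, len(CANDIDATES)) router weights.
    """
    logits = tokens @ router_w
    widths = np.array([w for w, _ in CANDIDATES])
    # Disallowed (too slow) sub-networks can never win the argmax.
    logits = np.where(widths <= allowed, logits, -np.inf)
    return widths[np.argmax(logits, axis=1)]


allowed = static_select(10.0)  # widest option within a 10 ms budget
rng = np.random.default_rng(0)
tokens = rng.standard_normal((5, 16))
router_w = rng.standard_normal((16, len(CANDIDATES)))
print(allowed, dynamic_route(tokens, router_w, allowed))
```

Splitting the decision this way keeps the latency guarantee in the static step, while the per-token step only redistributes compute among sub-networks that already satisfy the budget.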
