Compressing Large Language Models with Automated Sub-Network Search
Rhea Sanjay Sukthanker, Benedikt Staffler, Frank Hutter, Aaron Klein
TL;DR
The paper addresses the growing inference cost of large language models with an automated sub-network search: a two-stage neural architecture search (NAS) that identifies Pareto-optimal sparse architectures balancing accuracy against on-device latency. It introduces a joint search space for decoder-only Transformers, a calibrated sampling strategy, and importance-based sorting. It combines these with parameter-efficient fine-tuning (LoRA) and in-place knowledge distillation so that many architectures can be explored within a single training run. The resulting Pareto-optimal sub-networks consistently outperform structural pruning baselines and smaller models across 11 downstream tasks, reducing on-device latency by up to about 22% while preserving or improving accuracy. This makes large models more practical to deploy on resource-constrained devices and provides a scalable, automated pipeline for LLM compression that reduces cost and energy use. By combining NAS, importance-driven sub-network selection, and parameter-efficient fine-tuning (PEFT), the work advances automated, hardware-aware model optimization.
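One way to read "importance-based sorting" is that, once the units of a layer are ordered by importance, extracting a sub-network of a given width reduces to keeping a prefix of that ordering. The sketch below illustrates this for a single MLP block under stated assumptions: importance is taken to be the L1 norm of each neuron's incoming weights, which is a hypothetical stand-in because this section does not say how the paper scores importance. The function names `neuron_importance`, `sort_by_importance`, and `extract_subnetwork` are likewise hypothetical.

```python
from typing import List, Tuple

Matrix = List[List[float]]


def neuron_importance(w_up: Matrix) -> List[float]:
    """Hypothetical importance score: L1 norm of each neuron's incoming weights.

    w_up has one row per hidden neuron (shape: hidden x d_model).
    """
    return [sum(abs(x) for x in row) for row in w_up]


def sort_by_importance(w_up: Matrix) -> List[int]:
    """Return hidden-neuron indices ordered from most to least important."""
    scores = neuron_importance(w_up)
    return sorted(range(len(w_up)), key=lambda i: scores[i], reverse=True)


def extract_subnetwork(
    w_up: Matrix, w_down: Matrix, order: List[int], width: int
) -> Tuple[Matrix, Matrix]:
    """Keep the `width` most important neurons of an MLP block.

    w_down has one column per hidden neuron (shape: d_model x hidden), so the
    same selection is applied to its columns to keep the block consistent.
    """
    keep = order[:width]
    sub_up = [w_up[i] for i in keep]
    sub_down = [[row[i] for i in keep] for row in w_down]
    return sub_up, sub_down


if __name__ == "__main__":
    # Toy weights, invented for illustration only.
    w_up = [[0.1, -0.2], [1.0, 0.5], [-0.3, 0.3]]
    w_down = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    order = sort_by_importance(w_up)
    print(order)  # [1, 2, 0]
    print(extract_subnetwork(w_up, w_down, order, width=2))
```

Because the ordering is computed once, every candidate width corresponds to a prefix of the same permutation, so a search only has to choose a width per layer rather than an arbitrary subset of neurons.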
Abstract
Large Language Models (LLMs) demonstrate exceptional reasoning abilities, enabling strong generalization across diverse tasks such as commonsense reasoning and instruction following. However, as LLMs scale, inference costs become increasingly prohibitive and accumulate significantly over a model's life cycle. In this paper, we consider model compression for LLMs to reduce model size while improving downstream task performance. We phrase this as a neural architecture search problem that automatically prunes structural components, such as attention heads, neurons, and layers, by searching for the Pareto-optimal set of sub-networks that balance performance against on-device latency. Compared to state-of-the-art structural pruning approaches and fine-tuned smaller sub-networks extracted from the pre-trained model, our method achieves an average improvement of up to 9.85% across 11 diverse downstream tasks, while improving on-device latency by up to 22%.
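The search returns a Pareto-optimal set rather than a single model: a sub-network is kept only if no other candidate is at least as accurate and at least as fast while being strictly better on one of the two objectives. A minimal sketch of that dominance filter follows. The `Candidate` type and the candidate names, accuracies, and latencies are hypothetical and invented for illustration; they are not results from the paper.

```python
from typing import List, NamedTuple


class Candidate(NamedTuple):
    name: str
    accuracy: float    # higher is better
    latency_ms: float  # lower is better


def dominates(a: Candidate, b: Candidate) -> bool:
    """True if `a` is no worse than `b` on both objectives and strictly better on one."""
    no_worse = a.accuracy >= b.accuracy and a.latency_ms <= b.latency_ms
    better = a.accuracy > b.accuracy or a.latency_ms < b.latency_ms
    return no_worse and better


def pareto_front(cands: List[Candidate]) -> List[Candidate]:
    """Return the non-dominated candidates, sorted by increasing latency."""
    front = [c for c in cands if not any(dominates(o, c) for o in cands)]
    return sorted(front, key=lambda c: c.latency_ms)


if __name__ == "__main__":
    # Hypothetical sub-networks; the numbers are made up for illustration.
    cands = [
        Candidate("a", 0.60, 10.0),
        Candidate("b", 0.55, 8.0),
        Candidate("c", 0.50, 9.0),   # dominated by b: less accurate and slower
        Candidate("d", 0.62, 12.0),
        Candidate("e", 0.60, 11.0),  # dominated by a: equally accurate but slower
    ]
    print([c.name for c in pareto_front(cands)])  # ['b', 'a', 'd']
```

The quadratic scan is adequate for the modest number of candidates that are profiled on-device; a deployment can then pick the most accurate point on the front that fits its latency budget.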
