Composer: A Search Framework for Hybrid Neural Architecture Design
Bilge Acun, Prasoon Sinha, Newsha Ardalani, Sangmin Bae, Alicia Golden, Chien-Yu Lin, Meghana Madhyastha, Fei Sun, Neeraja J. Yadwadkar, Carole-Jean Wu
TL;DR
Composer is a principled hybrid neural architecture search framework for large language models. It searches over interleavings of Attention and MLP primitives at small scale and extrapolates the best discoveries to models roughly 1000x larger. The pipeline is decomposed into a Search Engine, an Evaluator, an Aggregator, and an Extrapolator, which together enable efficient exploration through One-Shot Bayesian optimization, N_c clustering, and stacking/stretching extrapolation. The resulting Composite LLMs, characterized by a 1:2 Attention-to-MLP ratio, outperform Llama 3.2 at the 350M-3B scales in both validation loss and downstream tasks, while training and serving faster. The framework is robust across datasets and scales, showing strong rank correlation between small-scale search results and large-scale performance, and it can be extended with additional primitives to broaden the search space.
Abstract
Hybrid model architectures that combine computational primitives (e.g., Attention, MLP) in different ratios have shown promising performance beyond Transformers. Some studies have shown that the way these primitives are interleaved can also affect model quality. However, prior work explores the hybrid model architecture design space manually. Because the design space is large and training is costly, discovering hybrid models that combine key computational primitives for pre-training is challenging. In this work, we take a principled approach to designing a modular hybrid model architecture search framework, Composer. Composer explores model architectures at a small scale and extrapolates the top-performing architectures to a larger scale using our proposed scaling strategies. Using Composer, we discover new hybrid LLM architectures that outperform Llama 3.2. Compared to Llama 3.2 and previous state-of-the-art baselines, the new model architectures consistently reduce validation loss at parameter scales of 350M-3B and improve evaluation accuracy on downstream tasks by up to 2.8-8.3% (1.1-3.1% on average), while improving both training and inference efficiency.
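As a rough illustration of extrapolating a small searched architecture to a deeper model, the Python sketch below grows a toy layer pattern in two ways. The abstract does not define the scaling strategies, so the semantics here are assumptions: "stacking" is taken to mean repeating the whole pattern, and "stretching" is taken to mean lengthening each run of identical primitives in place. The function names, the "A"/"M" encoding, and the three-layer example pattern are all hypothetical and are not taken from Composer's implementation.

```python
from itertools import groupby

# Hypothetical encoding: "A" = Attention layer, "M" = MLP layer.

def stack(pattern, factor):
    """Assumed 'stacking': repeat the whole small pattern `factor` times."""
    return list(pattern) * factor

def stretch(pattern, factor):
    """Assumed 'stretching': make each run of identical primitives `factor` times longer."""
    grown = []
    for primitive, run in groupby(pattern):
        run_length = len(list(run))
        grown.extend([primitive] * (run_length * factor))
    return grown

def attention_to_mlp_ratio(pattern):
    """Number of Attention layers divided by number of MLP layers."""
    return pattern.count("A") / pattern.count("M")

# Toy small-scale pattern with a 1:2 Attention-to-MLP ratio (illustrative only).
small = ["A", "M", "M"]
stacked = stack(small, 2)      # pattern repeated end to end
stretched = stretch(small, 2)  # each run doubled in place

print("".join(stacked), attention_to_mlp_ratio(stacked))
print("".join(stretched), attention_to_mlp_ratio(stretched))
```

Under these assumptions, both strategies preserve the Attention-to-MLP ratio of the small pattern and differ only in how the primitives are interleaved in the larger model.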
