MatFormer: Nested Transformer for Elastic Inference

Devvrit, Sneha Kudugunta, Aditya Kusupati, Tim Dettmers, Kaifeng Chen, Inderjit Dhillon, Yulia Tsvetkov, Hannaneh Hajishirzi, Sham Kakade, Ali Farhadi, Prateek Jain

TL;DR

MatFormer introduces a matryoshka-inspired Transformer with nested FFN blocks, trained jointly so that multiple accurate submodels can be extracted without any additional training. By nesting FFN sub-blocks and using Mix'n'Match to combine granularities across layers, a single universal model can serve a spectrum of compute budgets, enabling elastic inference for both language and vision tasks. Empirical results show that MatLM and MatViT match or exceed independently trained baselines, follow scaling trends similar to those of standard Transformers, and enable practical benefits such as faster speculative decoding and adaptive retrieval. This elasticity promises lower latency and deployment cost across diverse hardware while keeping accuracy and behavior consistent across submodels.
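To make the nesting concrete, here is a minimal sketch in plain Python, assuming a ReLU FFN and exponentially spaced hidden widths d_ff/8, d_ff/4, d_ff/2 and d_ff for the S, M, L and XL granularities. The function names, the activation, and the width schedule are illustrative assumptions, not the authors' code. Each granularity uses only the first m hidden units, so every smaller FFN is a prefix of the larger one and shares its weights.

```python
from typing import List, Sequence

Matrix = Sequence[Sequence[float]]


def granularity_widths(d_ff: int, g: int = 4) -> List[int]:
    """Hidden widths for g nested granularities (S, M, L, XL when g = 4).

    Assumption: exponentially spaced widths d_ff / 2**(g-1), ..., d_ff / 2, d_ff.
    """
    return [d_ff // (2 ** (g - 1 - i)) for i in range(g)]


def nested_ffn(x: Sequence[float], w1: Matrix, w2: Matrix, m: int) -> List[float]:
    """FFN forward pass for the submodel that keeps the first m hidden units.

    w1[j] is the d_model-dim input weight of hidden unit j and w2[j] is its
    d_model-dim output weight; both have d_ff rows. ReLU is an assumed activation.
    """
    if not 1 <= m <= len(w1):
        raise ValueError("m must lie in [1, d_ff]")
    d_model = len(x)
    out = [0.0] * d_model
    for j in range(m):  # only the prefix of hidden units [0, m) is used
        h = max(0.0, sum(w1[j][k] * x[k] for k in range(d_model)))
        for k in range(d_model):
            out[k] += h * w2[j][k]
    return out
```

Calling `nested_ffn(x, w1, w2, m)` for each `m` in `granularity_widths(d_ff)` evaluates all four submodels from one set of weights; changing a hidden unit at index m or beyond never affects the width-m output.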

Abstract

Foundation models are applied in a broad spectrum of settings with different inference constraints, from massive multi-accelerator clusters to resource-constrained standalone mobile devices. However, the substantial costs associated with training these models often limit the number of unique model sizes that can be offered. Consequently, practitioners are compelled to select a model that may not be optimally aligned with their specific latency and cost requirements. We present MatFormer, a novel Transformer architecture designed to provide elastic inference across diverse deployment constraints. MatFormer achieves this by incorporating a nested Feed Forward Network (FFN) block structure within a standard Transformer model. During training, we optimize the parameters of multiple nested FFN blocks with varying sizes, enabling the extraction of hundreds of accurate smaller models without incurring additional computational costs. We empirically validate the efficacy of MatFormer across different model classes (decoders and encoders) and modalities (language and vision), demonstrating its potential for real-world deployment. We show that an 850M decoder-only MatFormer language model (MatLM) allows us to extract multiple smaller models spanning from 582M to 850M parameters, each exhibiting better validation loss and one-shot downstream evaluations than independently trained counterparts. Furthermore, we observe that smaller encoders extracted from a universal MatFormer-based ViT (MatViT) encoder preserve the metric-space structure for adaptive large-scale retrieval. Finally, we showcase that speculative decoding with the accurate and consistent submodels extracted from MatFormer can lead to a significant reduction in inference latency. Project website: https://devvrit.github.io/matformer/
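The training step described above, optimizing several nested FFN blocks at once, can be sketched as a weighted sum of the losses of all nested submodels on the same input. The toy below is a minimal sketch, assuming a scalar-output FFN, squared-error loss, uniform weights λ_i, and that only the output weights v are trained; all names are hypothetical. It shows how the nested units are shared: hidden unit j receives gradient from every submodel whose width exceeds j.

```python
from typing import List, Sequence, Tuple


def hidden_units(x: Sequence[float], w: Sequence[Sequence[float]]) -> List[float]:
    """ReLU activations of all d_ff hidden units (ReLU is an assumption)."""
    return [max(0.0, sum(wj[k] * x[k] for k in range(len(x)))) for wj in w]


def joint_loss_and_grad(
    x: Sequence[float],
    y: float,
    w: Sequence[Sequence[float]],
    v: Sequence[float],
    widths: Sequence[int],
    lambdas: Sequence[float],
) -> Tuple[float, List[float]]:
    """Weighted sum of squared errors over all nested submodels, and its gradient w.r.t. v.

    Submodel i predicts sum_{j < widths[i]} v[j] * h[j]; only v is trained here.
    """
    h = hidden_units(x, w)
    loss = 0.0
    grad_v = [0.0] * len(v)
    for m, lam in zip(widths, lambdas):
        err = sum(v[j] * h[j] for j in range(m)) - y
        loss += lam * err * err
        for j in range(m):
            # Unit j gets gradient from every submodel that contains it,
            # so the shared prefix receives the most training signal.
            grad_v[j] += lam * 2.0 * err * h[j]
    return loss, grad_v


def sgd_step(x, y, w, v, widths, lambdas, lr: float) -> Tuple[List[float], float]:
    """One gradient step on the joint objective; returns (new_v, loss_before_step)."""
    loss, grad = joint_loss_and_grad(x, y, w, v, widths, lambdas)
    return [vj - lr * gj for vj, gj in zip(v, grad)], loss
```

With uniform weights over widths [1, 2, 4, 8], the first hidden unit is updated by all four submodels while the last is updated only by the largest, which illustrates how every granularity is optimized within a single training run.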

Paper Structure (37 sections, 2 equations, 11 figures, 14 tables)

Figures (11)

  • Figure 1: MatFormer introduces nested structure into the Transformer's FFN block & trains all the submodels, enabling free extraction of hundreds of accurate submodels for elastic inference.
  • Figure 2: Validation loss & one-shot downstream evaluation scores for the 850M MatLM & baseline models. Mix'n'Match helps generate accurate and more consistent models from MatLM that lie on the performance-vs-compute curve spanned by the explicitly optimized submodels.
  • Figure 3: We train various decoder-only MatLM models at a range of sizes from 78M to 850M parameters and observe the scaling trends of all granularities (S, M, L, XL) for validation loss and 1-shot downstream evaluation scores. We find that the MatLM-XL models across scales mimic the training trends of Baseline-XL models. Interestingly, we also note that validation loss and downstream evaluations follow the scaling trends of the XL models across all granularities.
  • Figure 4: MatViT variants match or outperform standard ViT models on ImageNet-1K classification and provide free extracted models that span the accuracy-compute curve through Mix'n'Match.
  • Figure 5: MatViT natively enables elastic encoders for adaptive retrieval that can be used for real-time query-side computation while retaining strong accuracy on ImageNet-1K, unlike the baselines.
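Mix'n'Match (Figure 2) picks a granularity independently for each layer, producing models between the explicitly trained sizes. The sketch below counts FFN parameters per choice and brute-forces the configuration with the most FFN parameters that fits a budget. The parameter formula (two d_model x m matrices, biases ignored), the budget-based selection rule, and all names are assumptions for illustration, not the paper's procedure.

```python
from itertools import product
from typing import Optional, Sequence, Tuple


def ffn_params(d_model: int, m: int) -> int:
    """Parameters of one FFN block of hidden width m: two d_model x m matrices (biases ignored)."""
    return 2 * d_model * m


def count_configs(n_widths: int, n_layers: int) -> int:
    """Number of Mix'n'Match configurations when each layer picks its own width."""
    return n_widths ** n_layers


def mix_n_match(
    d_model: int, widths: Sequence[int], n_layers: int, budget: int
) -> Optional[Tuple[Tuple[int, ...], int]]:
    """Brute-force search for per-layer widths with the most FFN parameters under budget.

    Hypothetical selection rule for illustration; the search is exponential in n_layers.
    """
    best = None
    for config in product(widths, repeat=n_layers):
        cost = sum(ffn_params(d_model, m) for m in config)
        if cost <= budget and (best is None or cost > best[1]):
            best = (config, cost)
    return best
```

With four widths per layer, `count_configs(4, n_layers)` grows as 4 to the power of the layer count, which is why a single trained model exposes far more submodels than the four explicitly optimized ones, and why this brute-force search is only practical for a handful of layers.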