
MatMamba: A Matryoshka State Space Model

Abhinav Shukla, Sai Vemprala, Aditya Kusupati, Ashish Kapoor

TL;DR

A state space model that combines Matryoshka-style learning with Mamba2 by modifying the block to contain nested dimensions, enabling joint training and adaptive inference. This makes MatMamba a practically viable option for deploying large-scale models elastically, based on the available inference compute.

Abstract

State Space Models (SSMs) like Mamba2 are a promising alternative to Transformers, with faster theoretical training and inference times, especially for long context lengths. Recent work on Matryoshka Representation Learning, and its application to Transformer backbones in works like MatFormer, showed how to introduce nested granularities of smaller submodels in one universal elastic model. In this work, we present MatMamba: a state space model that combines Matryoshka-style learning with Mamba2 by modifying the block to contain nested dimensions, enabling joint training and adaptive inference. MatMamba allows for efficient and adaptive deployment across various model sizes. We train a single large MatMamba model and obtain a number of smaller nested models for free, while maintaining or improving upon the performance of baseline smaller models trained from scratch. We train language and image models at a variety of parameter sizes from 35M to 1.4B. Our results on ImageNet and FineWeb show that MatMamba models scale comparably to Transformers, while having more efficient inference characteristics. This makes MatMamba a practically viable option for deploying large-scale models elastically, based on the available inference compute. Code and models are open-sourced at https://github.com/ScaledFoundations/MatMamba
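To make "nested dimensions" and "joint training" concrete, here is a minimal pure-Python sketch on a toy two-layer projection, not the MatMamba block itself. Assumptions: a submodel of width `m` uses only the first `m` hidden units (the first `m` rows of the input projection and the first `m` columns of the output projection), and joint training simply sums a toy squared-error loss over a few chosen widths. The function names (`make_weights`, `nested_forward`, `joint_loss`, `mix_n_match`) are hypothetical; which Mamba2 dimensions are actually nested, and how the losses are weighted, are specified in the paper and repository, not here.

```python
import random
from itertools import product


def make_weights(d_in, d_hidden, d_out, seed=0):
    """Random weights for a toy projection d_in -> d_hidden -> d_out (hypothetical)."""
    rng = random.Random(seed)
    w_in = [[rng.uniform(-1, 1) for _ in range(d_in)] for _ in range(d_hidden)]    # d_hidden x d_in
    w_out = [[rng.uniform(-1, 1) for _ in range(d_hidden)] for _ in range(d_out)]  # d_out x d_hidden
    return w_in, w_out


def relu(v):
    return v if v > 0 else 0.0


def nested_forward(x, w_in, w_out, width):
    """Run the submodel that uses only the first `width` hidden units.

    Every width reads a prefix of the SAME weight matrices, so smaller
    submodels are nested inside the larger one and share its parameters.
    """
    hidden = [relu(sum(w * xi for w, xi in zip(row, x))) for row in w_in[:width]]
    return [sum(w * h for w, h in zip(row[:width], hidden)) for row in w_out]


def joint_loss(x, target, w_in, w_out, widths):
    """Toy joint-training objective: sum of squared errors over several granularities."""
    total = 0.0
    for m in widths:
        y = nested_forward(x, w_in, w_out, m)
        total += sum((yi - ti) ** 2 for yi, ti in zip(y, target))
    return total


def mix_n_match(widths, n_layers):
    """Assumed MatFormer-style mix'n'match: pick one width per layer of a deeper stack."""
    return list(product(widths, repeat=n_layers))


# Usage: one set of weights serves widths 2, 4 and 8.
w_in, w_out = make_weights(d_in=4, d_hidden=8, d_out=2)
x, target = [0.5, -1.0, 2.0, 0.1], [1.0, 0.0]
full = nested_forward(x, w_in, w_out, 8)
small = nested_forward(x, w_in, w_out, 2)
loss = joint_loss(x, target, w_in, w_out, widths=[2, 4, 8])
configs = mix_n_match([2, 4, 8], n_layers=3)  # 3 widths ** 3 layers combinations
```

The design point the sketch tries to show is that no submodel stores its own weights: extracting a smaller model is just slicing a prefix, and mixing widths across layers multiplies the number of extractable submodels beyond the few granularities that were trained explicitly.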

Paper Structure

This paper contains 15 sections, 7 equations, 7 figures, 4 tables.

Figures (7)

  • Figure 1: MatMamba introduces a nested Matryoshka [kusupati2022matryoshka] structure in a Mamba2 [dao2024transformers] block. We jointly train a few chosen granularities to get a single model from which we can flexibly extract a large number of nested submodels for adaptive inference based on the available deployment compute.
  • Figure 2: MatMamba layers for vision tasks. Similar to a ViT [dosovitskiy2020image], we convert an image into a tensor of embedded patches. Because the Mamba2 block is causal, we append the [CLS] token to the end of the sequence, so that its output can depend on every patch. We intentionally keep the design simple to better study the properties of the MatMamba block.
  • Figure 3: ImageNet-1K classification: MatMamba-Vision is as accurate as explicitly trained baselines across various compute constraints, while also spanning the Pareto-optimal accuracy-vs-compute curve through mix'n'match submodels.
  • Figure 4: Inference speed and memory usage at batch size 1 on an H100 for nested MatMamba-Vision models and a ViT baseline. MatMamba's speed and memory characteristics compare more favorably at larger resolutions.
  • Figure 5: Adaptive image retrieval on ImageNet-1K: submodels obtained from the largest MatMamba-Vision model preserve the metric space of its embeddings, resulting in accurate and adaptive query processing at scale, while the baselines struggle to work across models without distillation.
  • ...and 2 more figures
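The Figure 2 caption motivates suffixing the [CLS] token by the causality of the Mamba2 block. The toy sketch below illustrates why, under loudly stated simplifications: a fixed-decay linear recurrence stands in for the Mamba2 scan (the real block uses a selective, input-dependent state space update), each patch is collapsed to a scalar (a stand-in for a learned patch embedding), and `cls_value` stands in for a learned [CLS] embedding. All function names are hypothetical. The check is simply which position's output can "see" a change to the image.

```python
def patchify(image, patch):
    """Split an H x W grayscale image (list of rows) into flattened patches, row-major."""
    h, w = len(image), len(image[0])
    assert h % patch == 0 and w % patch == 0, "image size must be divisible by patch size"
    patches = []
    for r in range(0, h, patch):
        for c in range(0, w, patch):
            patches.append([image[r + i][c + j] for i in range(patch) for j in range(patch)])
    return patches


def causal_scan(xs, decay=0.9):
    """Toy recurrence h_t = decay * h_{t-1} + x_t; output t depends only on inputs 0..t."""
    h, out = 0.0, []
    for x in xs:
        h = decay * h + x
        out.append(h)
    return out


def cls_output(image, patch, cls_value=0.0, suffix=True):
    """Output at the [CLS] position when it is appended (suffix) or prepended (prefix)."""
    tokens = [sum(p) for p in patchify(image, patch)]  # toy scalar "embedding" per patch
    seq = tokens + [cls_value] if suffix else [cls_value] + tokens
    out = causal_scan(seq)
    return out[-1] if suffix else out[0]


# Usage: perturb the first pixel and see which [CLS] placement notices.
img = [[float(r * 4 + c) for c in range(4)] for r in range(4)]
perturbed = [row[:] for row in img]
perturbed[0][0] += 1.0
suffix_changed = cls_output(img, 2) != cls_output(perturbed, 2)
prefix_changed = cls_output(img, 2, suffix=False) != cls_output(perturbed, 2, suffix=False)
```

With a prefixed [CLS], the causal scan has processed no patches when it reaches that position, so its output is independent of the image; only a suffixed [CLS] summarizes the whole sequence.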