Matryoshka Representation Learning

Aditya Kusupati, Gantavya Bhatt, Aniket Rege, Matthew Wallingford, Aditya Sinha, Vivek Ramanujan, William Howard-Snyder, Kaifeng Chen, Sham Kakade, Prateek Jain, Ali Farhadi

TL;DR

Matryoshka Representation Learning (MRL) addresses the rigidity of fixed-capacity representations by embedding multiple granularities of information within a single high-dimensional vector. It optimizes the training loss jointly over a nested set of representation sizes $\mathcal{M}$, with $|\mathcal{M}|\le \lfloor\log(d)\rfloor$. It trains a classifier on each prefix $z_{1:m}$ for $m \in \mathcal{M}$ and can optionally tie the classifier weights across sizes to form MRL-E. Empirically, MRL achieves accuracy comparable to independently trained fixed-feature baselines at each nesting size. It also enables adaptive classification and retrieval that reduce compute (up to $\sim14\times$) and memory footprint on ImageNet-1K and web-scale datasets, and it extends across vision, vision-language, and NLP models (e.g., ResNet, ViT, ALIGN, BERT). The approach remains robust in out-of-domain settings, improves long-tail few-shot performance, and offers insights into information bottlenecks and hierarchical class structure; code and pretrained models are open-source. Overall, MRL enables practical, scalable adaptive deployment of large-scale representations without extra inference cost.
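The objective described above applies the usual loss to every prefix $z_{1:m}$ and sums the results over $m \in \mathcal{M}$. Below is a minimal NumPy sketch of that summed loss for classification. It assumes softmax cross-entropy as $\mathcal{L}$ and unit importance weights by default. The names `mrl_loss` and `cross_entropy`, their arguments, and the $(m, L)$ weight layout (applied as `z @ W` rather than $Wz$) are illustrative choices, not the authors' code; their implementation lives in the linked repository. With `tied=True`, every size reads the first $m$ rows of one shared matrix, which corresponds to the weight tying that MRL-E refers to.

```python
import numpy as np


def cross_entropy(logits, labels):
    """Mean softmax cross-entropy. logits: (N, L); labels: (N,) integer classes."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return float(-log_probs[np.arange(labels.shape[0]), labels].mean())


def mrl_loss(z, labels, sizes, weights, tied=False, c=None):
    """Sum over nesting sizes m of c_m * CE(z_{1:m} @ W^(m), labels).

    z:       (N, d) embeddings from the backbone.
    sizes:   nesting sizes m in M, each <= d.
    weights: tied=False -> dict {m: (m, L) array}, one classifier per size (MRL);
             tied=True  -> one (d, L) array whose first m rows act as W^(m) (MRL-E).
    c:       optional dict of per-size importance weights; defaults to 1.0 each.
    """
    total = 0.0
    for m in sizes:
        w_m = weights[:m] if tied else weights[m]
        logits = z[:, :m] @ w_m  # the classifier sees only the first m coordinates
        c_m = 1.0 if c is None else c[m]
        total += c_m * cross_entropy(logits, labels)
    return total
```

Because the prefixes are nested, a single forward pass yields all $|\mathcal{M}|$ embeddings. In this sketch only the small classifier heads grow with $|\mathcal{M}|$, and tying them (MRL-E) removes even that growth.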

Abstract

Learned representations are a central component in modern ML systems, serving a multitude of downstream tasks. When training such representations, it is often the case that computational and statistical constraints for each downstream task are unknown. In this context, rigid, fixed-capacity representations can be either over- or under-accommodating to the task at hand. This leads us to ask: can we design a flexible representation that can adapt to multiple downstream tasks with varying computational resources? Our main contribution is Matryoshka Representation Learning (MRL), which encodes information at different granularities and allows a single embedding to adapt to the computational constraints of downstream tasks. MRL minimally modifies existing representation learning pipelines and imposes no additional cost during inference and deployment. MRL learns coarse-to-fine representations that are at least as accurate and rich as independently trained low-dimensional representations. The flexibility within the learned Matryoshka Representations offers: (a) up to 14x smaller embedding size for ImageNet-1K classification at the same level of accuracy; (b) up to 14x real-world speed-ups for large-scale retrieval on ImageNet-1K and 4K; and (c) up to 2% accuracy improvements for long-tail few-shot classification, all while being as robust as the original representations. Finally, we show that MRL extends seamlessly to web-scale datasets (ImageNet, JFT) across various modalities: vision (ViT, ResNet), vision + language (ALIGN), and language (BERT). MRL code and pretrained models are open-sourced at https://github.com/RAIVNLab/MRL.
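The retrieval speed-ups in (b) rely on one embedding serving several budgets. As we understand the paper's adaptive retrieval, a cheap prefix is used for a first pass over the whole database, and a longer prefix is used only to rerank a shortlist. Below is a minimal sketch of that two-stage idea, assuming cosine similarity over an in-memory NumPy array. The helper `adaptive_retrieve` and its parameters `d_short`, `d_rerank`, and `k_short` are hypothetical, and the paper's actual search setup (distance, index, shortlist sizes) may differ.

```python
import numpy as np


def normalize(x):
    """Scale vectors (last axis) to unit L2 norm so dot products are cosine similarities."""
    return x / np.linalg.norm(x, axis=-1, keepdims=True)


def adaptive_retrieve(query, database, d_short, d_rerank, k_short, k):
    """Two-stage search over Matryoshka embeddings (hypothetical helper).

    query:    (d,) embedding; database: (n, d) embeddings.
    Stage 1 scores all n items on the first d_short dims and keeps the top k_short.
    Stage 2 rescores only that shortlist on the first d_rerank dims; returns top-k indices.
    """
    sims = normalize(database[:, :d_short]) @ normalize(query[:d_short])
    shortlist = np.argsort(-sims)[:k_short]
    rerank = normalize(database[shortlist, :d_rerank]) @ normalize(query[:d_rerank])
    return shortlist[np.argsort(-rerank)[:k]]
```

The similarity computation costs about $n \cdot d_{\text{short}}$ operations in stage 1 and only $k_{\text{short}} \cdot d_{\text{rerank}}$ in stage 2. With $k_{\text{short}} \ll n$, the short prefix therefore dominates the cost. A large-scale system would replace the full `argsort` with an approximate nearest-neighbour index.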

Paper Structure (49 sections, 1 equation, 12 figures, 31 tables, 2 algorithms)

Figures (12)

  • Figure 1: Matryoshka Representation Learning is adaptable to any representation learning setup and begets a Matryoshka Representation $z$ by optimizing the original loss $\mathcal{L}(\cdot)$ at $O(\log(d))$ chosen representation sizes. A Matryoshka Representation can be utilized effectively for adaptive deployment across environments and downstream tasks.
  • Figure 2: ImageNet-1K linear classification accuracy of ResNet50 models. MRL is as accurate as the independently trained FF models for every representation size.
  • Figure 3: ImageNet-1K 1-NN accuracy of ResNet50 models, measuring representation quality for downstream tasks. MRL outperforms all the baselines across all representation sizes.
  • Figure 4: ImageNet-1K 1-NN accuracy for ViT-B/16 models trained on JFT-300M and as part of ALIGN. MRL scales seamlessly to web-scale with minimal training overhead.
  • Figure 5: Despite optimizing MRL for only $O(\log(d))$ dimensions in ResNet50 and ViT-B/16 models, the accuracy at intermediate dimensions shows interpolating behaviour.
  • ...and 7 more figures
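Figures 1, 3, and 5 refer to training at $O(\log(d))$ representation sizes and to measuring the 1-NN accuracy of prefixes, including intermediate sizes that were never directly supervised. Below is a minimal sketch with two assumptions. First, sizes are obtained by repeatedly halving $d$ down to a smallest size of 8, which for ResNet50's $d = 2048$ gives $\{8, 16, \dots, 2048\}$. Second, 1-NN uses squared Euclidean distance. `nesting_sizes` and `one_nn_accuracy` are hypothetical helper names.

```python
import numpy as np


def nesting_sizes(d, m_min=8):
    """Halve d until m_min; returns the O(log d) sizes in increasing order."""
    sizes = []
    m = d
    while m >= m_min:
        sizes.append(m)
        m //= 2
    return sorted(sizes)


def one_nn_accuracy(train_z, train_y, test_z, test_y, m):
    """1-NN accuracy using only the first m dims; m need not be a nesting size."""
    a, b = train_z[:, :m], test_z[:, :m]
    # Squared Euclidean distance between every test row and every train row.
    d2 = (b ** 2).sum(axis=1)[:, None] - 2.0 * b @ a.T + (a ** 2).sum(axis=1)[None, :]
    pred = train_y[d2.argmin(axis=1)]
    return float((pred == test_y).mean())
```

Calling `one_nn_accuracy` with an `m` that lies between two nesting sizes (e.g., 96 when $\mathcal{M}$ contains 64 and 128) is how one would probe the interpolating behaviour described in Figure 5.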