Matryoshka Quantization

Pranav Nair, Puranjay Datta, Jeff Dean, Prateek Jain, Aditya Kusupati

TL;DR

MatQuant introduces a multi-scale quantization framework that leverages the nested Matryoshka structure of integer bit-widths to train a single model that can be deployed at multiple precisions ($8$, $4$, and $2$ bits) and interpolated to intermediate widths. By unifying Quantization Aware Training (QAT) and OmniQuant within a shared MSB framework, MatQuant jointly optimizes across bit-widths and employs a slicing-based operator to extract lower-precision models without retraining, while also enabling layerwise Mix'n'Match for flexible accuracy-cost trade-offs. Empirical results on Gemma-2 and Mistral-7B show that MatQuant preserves int8/int4 accuracy comparable to explicit baselines and yields substantial gains for int2 (up to ~7% on some tasks), with strong interpolative performance to int6 and int3. The approach supports co-distillation and a Single Precision variant to further boost low-bit performance, and it highlights deployment advantages such as elastic serving across hardware with varying bit-width support, though extending to floating point remains a future challenge.

Abstract

Quantizing model weights is critical for reducing the communication and inference costs of large models. However, quantizing models, especially to low precisions like int4 or int2, requires a trade-off in model quality; int2, in particular, is known to severely degrade model quality. Consequently, practitioners are often forced to maintain multiple models with different quantization levels or serve a single model that best satisfies the quality-latency trade-off. On the other hand, integer data types, such as int8, inherently possess a nested (Matryoshka) structure where smaller bit-width integers, like int4 or int2, are nested within the most significant bits. Leveraging this insight, in this paper, we propose Matryoshka Quantization (MatQuant), a novel multi-scale quantization technique that alleviates the aforementioned challenge. This technique allows us to train and maintain a single quantized model but serve it with the precision demanded by the deployment. Furthermore, leveraging MatQuant's co-training and co-distillation regularization, int2 precision models extracted by MatQuant outperform standard int2 quantization by up to 4% and 7% with OmniQuant and QAT as base algorithms, respectively. Finally, we demonstrate that by using an extra bit to represent outliers, a model with an effective precision of 2.05-bit gives an additional 6% improvement with OmniQuant as the base algorithm.
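The nesting described in the abstract can be illustrated with a small sketch. It assumes unsigned 8-bit codes and plain truncation to the most significant bits; the function names `slice_msb` and `rescale` are hypothetical, and the operator actually used by MatQuant may differ (for example by rounding and clamping before rescaling).

```python
def slice_msb(q: int, c: int = 8, r: int = 4) -> int:
    """Return the r most significant bits of an unsigned c-bit code q.

    Assumption: plain truncation (right shift). This only illustrates the
    nested structure; it is not claimed to be the paper's exact operator.
    """
    if not 0 < r <= c:
        raise ValueError("need 0 < r <= c")
    if not 0 <= q < (1 << c):
        raise ValueError("q must fit in c unsigned bits")
    return q >> (c - r)


def rescale(sliced: int, c: int = 8, r: int = 4) -> int:
    """Map an r-bit code back onto the c-bit grid so scales stay comparable."""
    return sliced << (c - r)


q8 = 0b10110110              # an int8 code (182)
q4 = slice_msb(q8, 8, 4)     # top 4 bits: 0b1011
q2 = slice_msb(q8, 8, 2)     # top 2 bits: 0b10

# Nesting: the int2 code is also the top 2 bits of the int4 code.
nested_ok = slice_msb(q4, 4, 2) == q2

print(q8, q4, q2, nested_ok)
print(rescale(q4, 8, 4), rescale(q2, 8, 2))
```

Because every lower-precision code is a prefix of the int8 code, a single stored int8 tensor is enough to serve int4 or int2 variants; intermediate widths such as int6 or int3 follow by choosing a different `r`.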

Paper Structure

This paper contains 41 sections, 8 equations, 7 figures, 30 tables.

Figures (7)

  • Figure 2: Mix'n'Match on Gemma-2 9B model trained using ${\rm MatQuant}$ with OmniQuant allows elastic accuracy-vs-cost model extraction for free during deployment.
  • Figure 3: Mix'n'Match on Gemma-2 9B model trained using ${\rm Extra\text{ } Precision\text{ }MatQuant}$ with OmniQuant as the base algorithm allows elastic Pareto-optimal accuracy-vs-cost model extraction for free during deployment.
  • Figure 4: Weight distribution for Gemma-2 9B when trained with ${\rm Single\text{ } Precision\text{ }MatQuant}$ for int2 quantization. The right-shifted quantized weight distribution is a consequence of ${\rm Single\text{ } Precision\text{ }MatQuant}$'s training mechanism, which heavily optimizes for the first 2 MSBs of the int8 representation.
  • ...and 4 more figures
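The cost axis of the Mix'n'Match figures can be made concrete with a short sketch of how a layerwise precision assignment translates into an average number of bits per weight. The helper `effective_bits`, the layer sizes, and the example assignment are all hypothetical; how MatQuant actually chooses which layers get which precision is not stated in this summary.

```python
from typing import Sequence


def effective_bits(layer_bits: Sequence[int], layer_params: Sequence[int]) -> float:
    """Parameter-weighted average bit-width of a layerwise precision assignment.

    layer_bits[i]   : bit-width chosen for layer i (e.g. 8, 4, or 2)
    layer_params[i] : number of quantized weights in layer i
    """
    if len(layer_bits) != len(layer_params) or not layer_bits:
        raise ValueError("need one bit-width per layer and at least one layer")
    total_params = sum(layer_params)
    if total_params <= 0:
        raise ValueError("total parameter count must be positive")
    total_bits = sum(b * n for b, n in zip(layer_bits, layer_params))
    return total_bits / total_params


# Hypothetical 4-layer model with equally sized layers.
bits = [8, 4, 4, 2]
params = [1000, 1000, 1000, 1000]
avg = effective_bits(bits, params)
print(avg)
```

Sweeping over such assignments and plotting task accuracy against the resulting average bit-width gives the kind of accuracy-vs-cost curve the captions for Figures 2 and 3 describe.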