
Quamba2: A Robust and Scalable Post-training Quantization Framework for Selective State Space Models

Hung-Yueh Chiang, Chi-Chih Chang, Natalia Frumkin, Kai-Chiang Wu, Mohamed S. Abdelfattah, Diana Marculescu

TL;DR

Quamba2 addresses the deployment challenge of selective State Space Models by introducing a robust post-training quantization framework that supports W8A8, W4A8, and W4A16 for both Mamba1 and Mamba2 backbones. The approach leverages channel order preservation and activation persistence to apply sort-and-cluster quantization to the input $x$ and per-state-group quantization to the input-dependent parameters $B$ and $C$. These are complemented by offline Hadamard matrix fusion and cluster-aware weight reordering, which maintain compute-invariance. For Quamba2-8B, empirical results show 1.3× prefill and 3× generation speedups, together with a 4× memory reduction and only about a 1.6% average accuracy drop on zero-shot tasks, while MMLU results indicate strong generalization and robustness. The framework also supports head-to-toe quantization and mixed-precision strategies to optimize accuracy-latency trade-offs, enabling edge and cloud deployments of large selective SSMs.
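The TL;DR mentions offline Hadamard matrix fusion. The sketch below illustrates only the general rotation identity such a fusion relies on, assuming an orthonormal Hadamard matrix $H$ built with Sylvester's construction: fusing $H^\top$ into a weight offline ($W' = WH^\top$) and rotating the activation online ($x' = Hx$) gives $W'x' = Wx$, while spreading an outlier channel across all channels. Where Quamba2 actually places these rotations is not stated here, and the helpers `hadamard`, `transpose`, `matmul`, and `matvec` are hypothetical.

```python
import math

def hadamard(n):
    """Orthonormal n x n Hadamard matrix via Sylvester's construction (n a power of two)."""
    assert n > 0 and n & (n - 1) == 0, "n must be a power of two"
    h = [[1.0]]
    while len(h) < n:
        # [[H, H], [H, -H]] doubles the size at each step.
        h = [row + row for row in h] + [row + [-v for v in row] for row in h]
    scale = 1.0 / math.sqrt(n)
    return [[v * scale for v in row] for row in h]

def transpose(m):
    return [list(col) for col in zip(*m)]

def matmul(a, b):
    """Dense product of nested-list matrices a (p x q) and b (q x r)."""
    cols = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols] for row in a]

def matvec(m, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

W = [[0.5, -1.0, 2.0, 0.25],
     [1.5, 0.0, -0.75, 3.0]]
x = [0.1, 40.0, -0.2, 0.3]          # channel 1 is an outlier

H = hadamard(len(x))
W_fused = matmul(W, transpose(H))   # computed once, offline
x_rot = matvec(H, x)                # rotation applied to the activation at run time
y_ref = matvec(W, x)
y_fused = matvec(W_fused, x_rot)    # equals y_ref because H^T H = I
```

In this toy case the rotation roughly halves the largest activation magnitude (from 40.0 to about 20.2), which is why a rotated tensor tends to waste fewer quantization levels on a single outlier.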

Abstract

State Space Models (SSMs) are emerging as a compelling alternative to Transformers because of their constant memory usage and high performance. Despite this, scaling up SSMs on cloud services or resource-limited devices is challenging due to their storage and compute requirements. Quantizing SSMs to low bit-width data formats can reduce model size and take advantage of hardware acceleration. Because SSMs are prone to quantization-induced errors, recent efforts have focused on optimizing a particular model or bit-width for efficiency without sacrificing performance. However, different scenarios call for different bit-width configurations, such as W4A8 for boosting large-batch decoding speed and W4A16 for enhancing generation speed in short-prompt, single-user applications. To this end, we present Quamba2, which supports W8A8, W4A8, and W4A16 for both Mamba1 and Mamba2 backbones, addressing the growing demand for SSM deployment on various platforms. Building on the channel-order-preserving and activation-persistence properties of SSMs, we propose an offline approach that quantizes the inputs of the linear recurrence to 8 bits by sorting and clustering the input $x$, combined with per-state-group quantization for the input-dependent parameters $B$ and $C$. To ensure compute-invariance in the SSM output, we rearrange weights offline according to the clustering sequence. The experiments show that Quamba2-8B outperforms two state-of-the-art SSM quantization methods and delivers 1.3$\times$ and 3$\times$ speed-ups in the pre-filling and generation stages, respectively, while offering a 4$\times$ memory reduction with only a $1.6\%$ average accuracy drop. The evaluation on MMLU shows the generalizability and robustness of our framework. The code and quantized models will be released at: https://github.com/enyac-group/Quamba.
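The compute-invariance argument in the abstract can be checked on a toy model. The sketch below assumes a single SSD head with a scalar decay and per-token $B_t$, $C_t$ shared by all channels, a simplification of Mamba2 that omits $\Delta$, the gate $z$, normalization, and $D$. The names `ssd_single_head`, `matvec`, and the fixed permutation `perm` are hypothetical. It checks two things: permuting the channels of $x$ permutes the channels of $y$ in the same way (channel order preserving), and permuting the rows of the input projection together with the columns of the output projection leaves the block output unchanged up to floating-point rounding.

```python
import random

def matvec(m, v):
    """Dense matrix-vector product on nested lists."""
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def ssd_single_head(xs, Bs, Cs, decay):
    """Toy single-head selective scan:
    h[p][n] <- decay * h[p][n] + B_t[n] * x_t[p];  y_t[p] = sum_n C_t[n] * h[p][n].
    Channel p only touches its own state row, so the scan is channel-wise."""
    P, N = len(xs[0]), len(Bs[0])
    h = [[0.0] * N for _ in range(P)]
    ys = []
    for x, B, C in zip(xs, Bs, Cs):
        for p in range(P):
            for n in range(N):
                h[p][n] = decay * h[p][n] + B[n] * x[p]
        ys.append([sum(C[n] * h[p][n] for n in range(N)) for p in range(P)])
    return ys

def rand_mat(r, c):
    return [[random.uniform(-1, 1) for _ in range(c)] for _ in range(r)]

random.seed(0)
K, P, N, T, DECAY = 3, 4, 2, 5, 0.9
W_x, W_o = rand_mat(P, K), rand_mat(K, P)   # input projection (-> x), output projection (y ->)
us, Bs, Cs = rand_mat(T, K), rand_mat(T, N), rand_mat(T, N)

# Hypothetical channel order, e.g. obtained by sorting calibrated per-channel maxima.
perm = [2, 0, 3, 1]

# Channel order preserving: permuting the input channels permutes the output channels identically.
xs = [matvec(W_x, u) for u in us]
ys = ssd_single_head(xs, Bs, Cs, DECAY)
ys_re = ssd_single_head([[x[i] for i in perm] for x in xs], Bs, Cs, DECAY)

# Offline reorder: permute the rows of W_x and the columns of W_o with the same perm.
W_x_re = [W_x[i] for i in perm]
W_o_re = [[row[i] for i in perm] for row in W_o]
out_ref = [matvec(W_o, y) for y in ys]
out_re = [matvec(W_o_re, y)
          for y in ssd_single_head([matvec(W_x_re, u) for u in us], Bs, Cs, DECAY)]
```

In a full Mamba2 block, any per-channel parameters on the permuted path (and the per-head $\Delta$ and $A$ when whole heads are reordered) would presumably need the same permutation for the invariance to hold.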

Paper Structure

This paper contains 55 sections, 3 equations, 14 figures, 13 tables.

Figures (14)

  • Figure 1: (Quamba2-8B memory and throughput.) Head-to-toe (H2T) quantization enables the deployment of Mamba2-8B on edge platforms. Quamba2 delivers $3\times$ throughput on an Nvidia A5000 and 13 tokens per second (TPS) on an Nvidia Nano 8G.
  • Figure 2: (SSD flow with sorted heads and activation persistence.) We sort the head channels before applying quantization scaling factors. The orange blocks on the right indicate the activated channels, which have larger values, in the input and output SSD heads. Because the SSD computation is channel-wise, it retains the channel order between input $x$ and output $y$, a property we call channel order preserving. The blue and green blocks represent the activated states of the input-dependent parameters $B$ and $C$. Our study shows that the activated channels and states remain consistent across time steps and input samples, properties we call channel persistence and state persistence, respectively.
  • Figure 3: (Channel order preserving and activation persistence.) We show the activations in the last block of Mamba2-8B. For an input with $t$ tokens, $x$ remains sorted by the calibrated per-channel maximum (a). Because the SSD calculation is channel-wise, the channel order of the output $y$ matches that of the input $x$ (b). For $B$ and $C$, the activated states remain consistent over time steps $t$ (c-d) and across input samples (e-f). We leverage these observations to design our techniques, sort-and-cluster and per-state-group quantization, which increase the quantization precision of $x$ (a), $B$, and $C$ (c-f).
  • Figure 4: (Sort-and-cluster.) We leverage the channel persistence of SSMs to sort the channels by their calibrated maximum (a-c). The sorted heads disentangle the embedding, as shown in (c-1) and (c-2), which enables clustering of the heads. We cluster the sorted heads into $m$ groups ($m=8$ in (d)) and reorder the weights offline to match the clustering results. Then, we apply clustering again within each head group to cluster the channels into $n$ groups ($n=4$ in (e)). A scaling factor is calculated for each group, resulting in $m \times n$ factors used to quantize $x_t$ to 8 bits.
  • Figure 5: (Quamba2 precision.) The detailed precision mapping of W4A8 and W8A8 Quamba2. We reorder the weights offline to match the sorting and clustering indices of $\bar{x}^s_t$, and apply per-state-group quantization on $\bar{B}^g_t$ and $\bar{C}^g_t$.
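Figures 4 and 5 describe how the scaling factors for $x$, $B$, and $C$ are obtained. The sketch below is one possible reading under stated assumptions: the paper's clustering algorithm is not given in this summary, so a plain 1-D Lloyd's k-means (`kmeans_1d`) stands in for it; channels inside a head group are clustered by their own calibrated maxima; and per-state-group quantization is taken to mean one symmetric 8-bit scale per group $g$ of $B_t[g][n]$. The functions `sort_and_cluster`, `quantize_x`, and `per_state_group_scales`, as well as the synthetic calibration data, are hypothetical.

```python
import random

def kmeans_1d(values, k, iters=50):
    """Lloyd's k-means on scalars (a stand-in for the paper's clustering step).
    Returns one cluster label per value."""
    s = sorted(values)
    centers = [s[(2 * i + 1) * len(s) // (2 * k)] for i in range(k)]  # quantile init
    labels = [0] * len(values)
    for _ in range(iters):
        labels = [min(range(k), key=lambda c: abs(v - centers[c])) for v in values]
        for c in range(k):
            members = [v for v, lab in zip(values, labels) if lab == c]
            if members:
                centers[c] = sum(members) / len(members)
    return labels

def sort_and_cluster(ch_max, m, n, qmax=127):
    """ch_max[h][p]: calibrated max |x| of channel p in head h.
    Returns the head order used for offline weight reordering, each channel's
    (head group, channel cluster) key, and one scale per key (at most m * n)."""
    head_max = [max(row) for row in ch_max]
    head_order = sorted(range(len(ch_max)), key=lambda h: head_max[h])
    head_group = kmeans_1d(head_max, m)                  # head -> head-group id
    key_of, scales = {}, {}
    for g in range(m):
        members = [(h, p) for h in range(len(ch_max)) if head_group[h] == g
                   for p in range(len(ch_max[h]))]
        if not members:
            continue
        for (h, p), c in zip(members, kmeans_1d([ch_max[h][p] for h, p in members], n)):
            key_of[(h, p)] = (g, c)
            # The scale must cover the largest calibrated value in its cluster.
            scales[(g, c)] = max(scales.get((g, c), 1e-8), ch_max[h][p] / qmax)
    return head_order, key_of, scales

def quantize_x(x, key_of, scales, qmax=127):
    """Symmetric int8 quantization of x[h][p] with its channel's cluster scale."""
    return [[max(-qmax - 1, min(qmax, round(v / scales[key_of[(h, p)]])))
             for p, v in enumerate(row)] for h, row in enumerate(x)]

def per_state_group_scales(calib_B, qmax=127):
    """calib_B: calibration tokens, each B_t[g][n] with G state groups.
    Returns one scale per state group g (assumed granularity)."""
    G = len(calib_B[0])
    return [max(1e-8, max(abs(v) for B in calib_B for v in B[g]) / qmax) for g in range(G)]

random.seed(0)
H, P, G, N = 8, 4, 2, 3
ch_max = [[random.uniform(0.1, 1.0) * (20.0 if h in (1, 6) else 1.0) for _ in range(P)]
          for h in range(H)]                              # heads 1 and 6 carry outliers
head_order, key_of, scales = sort_and_cluster(ch_max, m=2, n=2)
x = [[random.uniform(-c, c) for c in row] for row in ch_max]  # values inside the calibrated range
q = quantize_x(x, key_of, scales)
calib_B = [[[random.gauss(0, 1 + g) for _ in range(N)] for g in range(G)] for _ in range(16)]
b_scales = per_state_group_scales(calib_B)
```

Because every scale covers the largest calibrated value in its cluster, in-range activations never clip, and the reconstruction error per element stays within half of that element's scale. Separating the outlier heads into their own group keeps the scales of the remaining heads small.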