Optimal Quantization for Matrix Multiplication

Or Ordentlich, Yury Polyanskiy

TL;DR

This work derives fundamental limits and practical lattice-based encoders for quantizing matrices to approximate their products. It establishes a non-asymptotic lower bound for iid Gaussian matrices and constructs universal nested lattice quantizers whose distortion is controlled by $\|\bar{A}^\top \bar{B}\|_F$ and the Frobenius norms $\|\bar{A}\|_F$, $\|\bar{B}\|_F$, achieving asymptotic optimality in the Gaussian case. A phase transition at $R \approx 0.906$ bit/entry emerges, signaling the necessity of Johnson-Lindenstrauss-style dimensionality reduction at low rates, while a practical low-complexity lattice scheme with rotation, dithering, and Hadamard projections yields near-optimal performance. The framework extends to arbitrary matrices via a robust lattice-quantization scheme and provides a concrete path toward fast matrix multiplication in memory-bandwidth-bound ML workloads. Collectively, the results quantify the rate-distortion tradeoffs for matrix multiplication and offer implementations that bridge theory and practice for efficient inference in large models.

Abstract

Recent work in the machine learning community has proposed multiple methods for performing lossy compression (quantization) of large matrices. This quantization is important for accelerating matrix multiplication (the main component of large language models), which is often bottlenecked by the speed of loading these matrices from memory. Unlike classical vector quantization and rate-distortion theory, the goal of these new compression algorithms is to approximate not the matrices themselves, but their matrix product. Specifically, given a pair of real matrices $A,B$, an encoder (compressor) is applied to each of them independently, producing descriptions with $R$ bits per entry. These representations are subsequently used by the decoder to estimate the matrix product $A^\top B$. In this work, we provide a non-asymptotic lower bound on the mean squared error of this approximation (as a function of the rate $R$) for the case of matrices $A,B$ with iid Gaussian entries. Algorithmically, we construct a universal quantizer based on nested lattices with an explicit guarantee on the approximation error for any (non-random) pair of matrices $A$, $B$ in terms of only the Frobenius norms $\|\bar{A}\|_F, \|\bar{B}\|_F$ and $\|\bar{A}^\top \bar{B}\|_F$, where $\bar{A},\bar{B}$ are versions of $A,B$ with zero-centered columns, respectively. For iid Gaussian matrices our quantizer achieves the lower bound and is, thus, asymptotically optimal. A practical low-complexity version of our quantizer achieves performance quite close to optimal. In addition, we derive the rate-distortion function for matrix multiplication of iid Gaussian matrices, which exhibits an interesting phase transition at $R\approx 0.906$ bit/entry, showing the necessity of Johnson-Lindenstrauss dimensionality reduction (sketching) in the low-rate regime.
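
The problem setup in the abstract can be illustrated with a small numerical sketch. The code below is not the paper's quantizer: it uses a plain clipped uniform scalar quantizer as a stand-in encoder, applied independently to $A$ and $B$, and measures the squared error of the estimated product $A^\top B$. The function names, the clipping range, and the normalization of the error by $n$ are assumptions chosen for illustration.

```python
import numpy as np

def uniform_quantize(X, rate_bits, clip=4.0):
    """Toy encoder/decoder pair (an assumption, not the paper's scheme):
    clip entries to [-clip, clip] and map them to 2**rate_bits uniform cells."""
    levels = 2 ** rate_bits
    step = 2.0 * clip / levels
    # Integer cell index in {0, ..., levels-1}: the R-bit description of each entry.
    idx = np.clip(np.floor((X + clip) / step), 0, levels - 1).astype(np.int64)
    # Decoder side: reconstruct each entry at the midpoint of its cell.
    X_hat = (idx + 0.5) * step - clip
    return idx, X_hat

def normalized_product_mse(n=512, a=8, b=8, rate_bits=3, seed=0):
    """Quantize iid Gaussian A (n x a) and B (n x b) independently and return
    the mean squared error of the estimate of A^T B, divided by n."""
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((n, a))
    B = rng.standard_normal((n, b))
    _, A_hat = uniform_quantize(A, rate_bits)
    _, B_hat = uniform_quantize(B, rate_bits)
    err = A_hat.T @ B_hat - A.T @ B
    return float(np.mean(err ** 2) / n)

if __name__ == "__main__":
    for R in (1, 2, 3, 4):
        print(R, normalized_product_mse(rate_bits=R))
```

Running the script prints the normalized error for a few rates; the error should shrink as the rate grows, which is the tradeoff the lower bound and the lattice construction characterize precisely.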

Paper Structure

This paper contains 33 sections, 25 theorems, 227 equations, 3 figures, 3 algorithms.

Key Result

Theorem 1

For any $\varepsilon>0$ and sufficiently large $n$, there exist randomized encoders $f_1:\mathbb{R}^{n\times a}\to[2^{na R}]$, $f_2:\mathbb{R}^{n\times b}\to[2^{nbR}]$, and decoders $g: [2^{na R}]\times [2^{nbR}]\to \mathbb{R}^{a\times b}$ and $g_{1-\mathrm{sided}}:[2^{na R}]\times \mathbb{R}^{n\tim

Figures (3)

  • Figure 1: Encoders for matrix multiplication. Each column of $A$ is encoded by the same encoder, and each column of $B$ is encoded by the same encoder. The encoder used for columns of $A$ and that used for columns of $B$ are also the same, except that for $A$ we use the dither vector $Z_1\in\mathbb{R}^{\kappa n}$, whereas for $B$ we use the dither vector $Z_2\in\mathbb{R}^{\kappa n}$. We illustrate the operation of the encoders on the $i$th column of $A$, $a_i\in\mathbb{R}^n$, and on the $j$th column of $B$, $b_j\in\mathbb{R}^n$. The block $S$ corresponds to left multiplication by the rotation matrix $S\in\mathbb{R}^{n\times n}$, and the block $\mathcal{P}_{\kappa n}$ corresponds to projecting the vector $U_i\in\mathbb{R}^n$ (respectively $V_j\in\mathbb{R}^n$) to $\mathbb{R}^{\kappa n}$, $\kappa\in\frac{1}{n}\cdot\{0,1,\ldots,n\}$, by keeping only its first $\kappa n$ coordinates. Here, $\kappa$ is the time-sharing/sparsification parameter, determining the fraction of coordinates in each vector that are actually "described" to the decoder. The lattices $\Lambda_c\subset\Lambda_f\subset\mathbb{R}^{\kappa n}$ are nested. The component $Q_{\Lambda_f}(\cdot)$ is a lattice quantizer which maps a point in $\mathbb{R}^{\kappa n}$ to the closest lattice point in $\Lambda_f$. The component $\bmod\Lambda_c$ maps a point $x\in\mathbb{R}^{\kappa n}$ to $x-Q_{\Lambda_c}(x)\in\mathcal{V}_c$, where $\mathcal{V}_c$ is the Voronoi region of $\Lambda_c$. The binary representation $W_{a_i}$ (respectively $W_{b_j}$) is an encoding of $\tilde{U}_{i,[\kappa n]}\in(\Lambda_f\cap \mathcal{V}_c)\cong\Lambda_f/\Lambda_c$ (respectively $\tilde{V}_{j,[\kappa n]}\in\Lambda_f/\Lambda_c$) using $\log|\Lambda_f/\Lambda_c|$ bits. The scalars $\widehat{\frac{1}{n}\mathbf{1}^\top a_i},\widehat{\|\bar{a}_i\|}$ (respectively, $\widehat{\frac{1}{n}\mathbf{1}^\top b_j},\widehat{\|\bar{b}_j\|}$) are high-resolution descriptions of $\frac{1}{n}\mathbf{1}^\top a_i,\|\bar{a}_i\|$ (respectively, $\frac{1}{n}\mathbf{1}^\top b_j,\|\bar{b}_j\|$), which require only $O(\log n)$ bits. The dither vectors $Z_1,Z_2$ must be known to the decoder. They can be drawn randomly by the encoders and the decoder, which requires shared randomness between them (in practice, we just store the random seed with the matrices). The matrix $S$ need not be known by the decoder. The operations marked in red correspond to zero-centering the column vectors, and may be avoided altogether. The effect of avoiding those operations on the performance is to replace $\bar{A}$ with $A$ and $\bar{B}$ with $B$ in the MSE upper bounds in Theorems \ref{thm:generalMatMul}, \ref{thm:generalMatMulNoMMSE} and \ref{thm:MostgeneralMatMul}.
  • Figure 2: Decoder for the matrix multiplication problem. We illustrate the estimation of $(A^\top B)_{ij}$. The component $\Lambda_f/\Lambda_c$-decoder maps $\log|\Lambda_f/\Lambda_c|$ bits to points in $\Lambda_f\cap\mathcal{V}_c\subset\mathbb{R}^{\kappa n}$, where $\mathcal{V}_c$ is the Voronoi region of the lattice $\Lambda_c$. The component $\langle\cdot,\cdot\rangle$ computes the inner product $\hat{U}_{i,[\kappa n]}^\top \hat{V}_{j,[\kappa n]}$, and $\alpha\in[0,1]$ is a (MMSE-like) scaling coefficient. The operation marked in red need only be implemented if the encoders implemented the corresponding zero-centering operations marked in red in Figure 1. Note that we can estimate the entire product $A^\top B$ by first decoding $\hat{\tilde{A}}=[\hat{U}_{1,[\kappa n]}|\cdots|\hat{U}_{a,[\kappa n]}]$ and $\hat{\tilde{B}}=[\hat{V}_{1,[\kappa n]}|\cdots|\hat{V}_{b,[\kappa n]}]$, computing the matrix $\alpha \hat{\tilde{A}}^\top \hat{\tilde{B}}$, then computing its entrywise (Hadamard) product with the rank-1 matrix $N$ whose $ij$th entry is $N_{ij}=\frac{1}{n}\widehat{\|\bar{a}_i\|}\widehat{\|\bar{b}_j\|}$, and adding to it the rank-1 matrix $\mu$ whose $ij$th entry is $\mu_{ij}=n\cdot\widehat{\frac{1}{n}\mathbf{1}^\top a_i}\cdot \widehat{\frac{1}{n}\mathbf{1}^\top b_j}$.
  • Figure 3: The approximation error of the $D_3$-based product nested lattice coding scheme with $q=6$, for random iid Gaussian matrices $A,B\in \mathbb{R}^{n\times n}$, $n=3\cdot 2^{11}$. We plot the histogram of the entries of $\frac{1}{\sqrt{n}}(\widehat{A^\top B}-A^\top B)$ in blue. For comparison, we also plot the histogram of the entries of $\frac{1}{\sqrt{n}}(\widehat{A^\top B}-A^\top B)$ for a $3$-bit scalar quantizer in red.
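
The encoder and decoder pipelines in Figures 1 and 2 can be sketched in a heavily simplified form. The code below is an illustrative assumption, not the scheme analyzed in the paper: it uses the scalar lattices $\Lambda_f=\mathbb{Z}$ and $\Lambda_c=q\mathbb{Z}$ (rather than $D_3$ or high-dimensional lattices), sets $\kappa=1$ and $\alpha=1$, skips zero-centering, and passes the exact column norms to the decoder instead of quantized ones. The names `encode`, `decode`, `estimate_inner_product` and the step size `beta` are hypothetical.

```python
import numpy as np

def random_rotation(n, rng):
    """Random orthogonal matrix from a QR factorization (stand-in for S)."""
    Q, _ = np.linalg.qr(rng.standard_normal((n, n)))
    return Q

def encode(x, S, z, q, beta):
    """Rotate, scale to norm sqrt(n), add dither, round to Z, reduce mod qZ.
    Returns log2(q) bits per coordinate (as integers in {0..q-1}) and the norm."""
    n = x.size
    norm = np.linalg.norm(x)
    u = np.sqrt(n) * (S @ x) / norm
    w = np.mod(np.round(u / beta + z), q).astype(np.int64)
    return w, norm

def decode(w, z, q, beta):
    """Subtract the dither and reduce to the centered cell of qZ."""
    t = w - z
    t = t - q * np.round(t / q)
    return beta * t

def estimate_inner_product(x, y, q=32, beta=0.5, seed=0):
    """Estimate <x, y> from the two independently produced descriptions."""
    rng = np.random.default_rng(seed)  # shared seed plays the role of shared randomness
    n = x.size
    S = random_rotation(n, rng)
    z1 = rng.uniform(-0.5, 0.5, n)     # dither for the "A" side
    z2 = rng.uniform(-0.5, 0.5, n)     # dither for the "B" side
    wx, nx = encode(x, S, z1, q, beta)
    wy, ny = encode(y, S, z2, q, beta)
    u_hat = decode(wx, z1, q, beta)
    v_hat = decode(wy, z2, q, beta)
    # Undo the normalization: <x, y> = (|x| |y| / n) <u, v>.
    return (nx * ny / n) * float(u_hat @ v_hat)
```

As long as no rotated coordinate overloads the coarse cell (roughly $|u_k| < q\beta/2$), each decoded coordinate differs from the rotated one by at most $\beta/2$. Note that the decoder never uses $S$, since applying the same rotation to both vectors leaves their inner product unchanged.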

Theorems & Definitions (31)

  • Theorem 1
  • Theorem 2
  • Theorem 3
  • Remark 1: On zero-centering
  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Theorem 4
  • Theorem 5
  • Lemma 1
  • ...and 21 more