
Getting Free Bits Back from Rotational Symmetries in LLMs

Jiajun He, Gergely Flamich, José Miguel Hernández-Lobato

TL;DR

This paper proposes a format based on bits-back coding for storing rotationally symmetric Transformer weights more efficiently than the usual array layout at the same floating-point precision.

Abstract

Current methods for compressing neural network weights, such as decomposition, pruning, quantization, and channel simulation, often overlook the inherent symmetries within these networks and thus waste bits on encoding redundant information. In this paper, we propose a format based on bits-back coding for storing rotationally symmetric Transformer weights more efficiently than the usual array layout at the same floating-point precision. We evaluate our method on Large Language Models (LLMs) pruned by SliceGPT (Ashkboos et al., 2024) and achieve a 3-5% reduction in total bit usage for free across different model sizes and architectures without impacting model performance within a certain numerical precision.
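The rotational symmetry referred to above can be checked numerically. The sketch below is a minimal NumPy illustration, not the authors' code. It assumes an RMSNorm-style normalization with no learnable scale and no epsilon (roughly what remains once SliceGPT has absorbed ${\mathbf{M}}$ and $\text{diag}(\bm{\alpha})$ into adjacent weights) and row-vector hidden states. The names `rms_norm`, `random_orthogonal`, `w_in` and `w_out` are hypothetical and do not come from the paper. Because this normalization commutes with any orthogonal matrix Q, rotating the hidden state by Q and compensating in the weights leaves every output unchanged.

```python
import numpy as np


def rms_norm(x):
    """Scale each row to unit root-mean-square (no learnable scale, no epsilon)."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True))


def random_orthogonal(d, rng):
    """Draw an orthogonal matrix via QR of a Gaussian matrix (with sign fix)."""
    q, r = np.linalg.qr(rng.standard_normal((d, d)))
    return q * np.sign(np.diag(r))


rng = np.random.default_rng(0)
n, d, k = 4, 8, 16
x = rng.standard_normal((n, d))      # residual-stream hidden states (rows = tokens)
h = rng.standard_normal((n, k))      # some block's internal activations
w_in = rng.standard_normal((d, k))   # weight that reads the normalized hidden state
w_out = rng.standard_normal((k, d))  # weight that writes back into the residual stream
q = random_orthogonal(d, rng)

# Reading: the normalization commutes with q, so (q.T @ w_in) undoes the rotation.
y = rms_norm(x) @ w_in
y_rot = rms_norm(x @ q) @ (q.T @ w_in)

# Writing: with w_out replaced by w_out @ q, the residual stream stays rotated by q.
res = x + h @ w_out
res_rot = x @ q + h @ (w_out @ q)

print(np.allclose(y, y_rot))          # True
print(np.allclose(res @ q, res_rot))  # True
```

Since every orthogonal Q yields the same function, the particular rotation baked into the stored weights says nothing about the model's behavior; these are the redundant bits the bits-back format aims to recover.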

Paper Structure (11 sections, 17 equations, 5 figures, 1 table, 6 algorithms)

This paper contains 11 sections, 17 equations, 5 figures, 1 table, and 6 algorithms.

Figures (5)

  • Figure 1: Visualization of a Standard Transformer Block and a SliceGPT-Pruned Transformer Block. (a) The standard Transformer block first maps the input through an attention layer; it then applies LayerNorm (Ba et al., 2016) and a 1-layer Feedforward Network (FFN). Two residual connections are added, one after the attention layer and one after the FFN. Here, we adopt the notation of Ashkboos et al. (2024), where ${\mathbf{M}}={\mathbf{I}} - \frac{1}{D}\bm{1}\bm{1}^\top$ is the operator that subtracts the mean of each row. (b) SliceGPT (Ashkboos et al., 2024) first absorbs ${\mathbf{M}}$ and $\text{diag}(\bm{\alpha})$ into the weights before and after the normalization layer. It then rotates these weights by applying PCA to the hidden states, aligning them with their principal components (PCs). Finally, SliceGPT prunes the rows and columns corresponding to the least significant PCs, indicated by gray shading. Note that the weights in (b) differ from those in (a) because of the absorption of ${\mathbf{M}}$ and $\text{diag}(\bm{\alpha})$ and the rotation. Additionally, because SliceGPT introduces two weight matrices, ${\mathbf{Q}}_{\text{skip\_mlp}}$ and ${\mathbf{Q}}_{\text{skip\_att}}$, on the skip connections, it has more rotational symmetries than the standard Transformer in (a). For a more detailed explanation of SliceGPT, see Figure 4 of Ashkboos et al. (2024).
  • Figure 2: Histogram and empirical CDF of the error between the reconstructed weights and the original weights before encoding, using ${\mathbf{W}}_o$ in the final layer of OPT-6.7B as an example. The pattern in this plot generalizes well to other weights and models. As shown, only a small fraction of the weights exhibit relatively large deviations. Therefore, we can allocate a negligible number of bits to transmit the positions and true values of these weights, effectively correcting the error caused by numerical inaccuracies.
  • Figure 3: Effectiveness of the correction codes at different thresholds. A threshold of around 0.005-0.01 recovers all of the performance lost to numerical inaccuracies while still using significantly fewer bits than encoding without bits-back.
  • Figure : Rotate Transformer to its Canonical Direction.
  • Figure : Bits-back Encoding for Transformers (processed by SliceGPT). Red denotes bits added to the bitstream; green denotes bits removed from the bitstream.
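The mean-subtraction matrix ${\mathbf{M}}$ from Figure 1 and its relationship to LayerNorm can be sketched as follows. This is an illustration only, assuming a LayerNorm without bias and with epsilon set to zero; `layer_norm` and `rms_norm` are hypothetical helpers written for this example. It checks that right-multiplying row vectors by M subtracts each row's mean, that M is idempotent, and that LayerNorm equals an RMS normalization of the centered input followed by diag(α). Because M and diag(α) are linear maps, they can be multiplied into neighboring weight matrices, which is the absorption step described in panel (b).

```python
import numpy as np

D = 6
M = np.eye(D) - np.ones((D, D)) / D  # M = I - (1/D) 1 1^T

rng = np.random.default_rng(1)
x = rng.standard_normal((3, D))      # rows are hidden states
alpha = rng.standard_normal(D)       # LayerNorm scale vector


def layer_norm(x, alpha):
    """LayerNorm without bias or epsilon: center, divide by std, scale by alpha."""
    mu = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mu) / np.sqrt(var) * alpha


def rms_norm(x):
    """Divide each row by its root-mean-square."""
    return x / np.sqrt(np.mean(x * x, axis=-1, keepdims=True))


# Right-multiplying each row by M subtracts that row's mean.
centered = x @ M
print(np.allclose(centered, x - x.mean(axis=1, keepdims=True)))  # True

# Centering twice changes nothing: M is idempotent.
print(np.allclose(M @ M, M))  # True

# LayerNorm = RMS normalization of the centered input, then diag(alpha).
print(np.allclose(layer_norm(x, alpha), rms_norm(x @ M) @ np.diag(alpha)))  # True
```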
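The correction step described in Figures 2 and 3 can be sketched as a sparse patch: record the position and true value of every weight whose reconstruction error exceeds a threshold, then overwrite those entries after decoding. This sketch is hypothetical. The function names, the synthetic error pattern, and the cost model (a fixed-width index plus a 16-bit value per correction) are assumptions made for illustration, not the paper's actual encoding.

```python
import math

import numpy as np


def build_corrections(w, w_hat, threshold):
    """Flat positions and true values of entries whose |error| exceeds threshold."""
    err = np.abs(w_hat - w).ravel()
    pos = np.flatnonzero(err > threshold)
    return pos, w.ravel()[pos].copy()


def apply_corrections(w_hat, pos, values):
    """Overwrite the flagged entries of the reconstruction with their true values."""
    fixed = w_hat.copy()
    fixed.flat[pos] = values
    return fixed


def correction_bits(n_entries, n_corrections, value_bits=16):
    """Assumed naive cost: fixed-width index plus raw value for each correction."""
    index_bits = math.ceil(math.log2(n_entries))
    return n_corrections * (index_bits + value_bits)


rng = np.random.default_rng(0)
w = rng.standard_normal((256, 256)).astype(np.float32)
# Synthetic reconstruction: tiny noise everywhere plus a few large deviations.
w_hat = w + rng.normal(scale=1e-4, size=w.shape).astype(np.float32)
w_hat.flat[[3, 500, 40000]] += 0.05

threshold = 0.005
pos, values = build_corrections(w, w_hat, threshold)
fixed = apply_corrections(w_hat, pos, values)

print(pos)
print(np.abs(fixed - w).max() <= threshold)  # True
print(correction_bits(w.size, len(pos)))     # 96
```

Raising the threshold reduces the number of corrections (and their bit cost) at the price of leaving larger residual errors, which is the trade-off that Figure 3 examines.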

Theorems & Definitions (1)

  • Remark 3.1