
3BASiL: An Algorithmic Framework for Sparse plus Low-Rank Compression of LLMs

Mehdi Makni, Xiang Meng, Rahul Mazumder

TL;DR

This work introduces 3BASiL-TM, an efficient one-shot post-training method for Sparse plus Low-Rank $(\mathbf{S} + \mathbf{LR})$ decomposition of LLMs. It reduces the WikiText2 perplexity gap to the dense model by over 30% and achieves over 2.5x faster compression runtime on an A100 GPU than the SOTA $(\mathbf{S} + \mathbf{LR})$ method.

Abstract

Sparse plus Low-Rank $(\mathbf{S} + \mathbf{LR})$ decomposition of Large Language Models (LLMs) has emerged as a promising direction in model compression. It aims to decompose pre-trained model weights into a sum of sparse and low-rank matrices $(\mathbf{W} \approx \mathbf{S} + \mathbf{LR})$. Despite recent progress, existing methods often suffer from substantial performance degradation compared to dense models. In this work, we introduce 3BASiL-TM, an efficient one-shot post-training method for $(\mathbf{S} + \mathbf{LR})$ decomposition of LLMs that addresses this gap. Our approach first introduces a novel 3-Block Alternating Direction Method of Multipliers (ADMM) algorithm, termed 3BASiL, which minimizes the layer-wise reconstruction error with convergence guarantees. We then design an efficient transformer-matching (TM) refinement step that jointly optimizes the sparse and low-rank components across transformer layers. This step minimizes a novel memory-efficient loss that aligns outputs at the transformer level. Notably, the TM procedure is universal: it can enhance any $(\mathbf{S} + \mathbf{LR})$ decomposition, including pure sparsity. Our numerical experiments show that, compared to prior methods, 3BASiL-TM reduces the WikiText2 perplexity gap relative to the dense Llama3-8B model by over 30% under a (2:4 Sparse + 64 LR) configuration. Moreover, our method achieves over 2.5x faster compression runtime on an A100 GPU than the SOTA $(\mathbf{S} + \mathbf{LR})$ method. Our code is available at https://github.com/mazumder-lab/3BASiL.
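The layer-wise problem can be illustrated with a small NumPy sketch. This is not the authors' 3BASiL algorithm (their code is in the linked repository). It is a generic three-block ADMM splitting that we assume for illustration. An auxiliary copy $\mathbf{W}$ of the weights carries the quadratic loss $\|\mathbf{X}(\widehat{\mathbf{W}} - \mathbf{W})\|_F^2$. The constraint $\mathbf{W} = \mathbf{S} + \mathbf{L}$ couples it to an N:M-sparse block $\mathbf{S}$ and a rank-$r$ block $\mathbf{L}$, and each of those is updated by Euclidean projection. The following are all hypothetical choices, not details from the paper: the function names, the (d_in × d_out) weight layout, grouping the N:M pattern along the input dimension, and the geometric penalty schedule.

```python
import numpy as np


def project_nm(A, n, m):
    """Euclidean projection onto N:M sparsity: keep the n largest |a| in every
    group of m consecutive input rows (axis 0) of each output column."""
    d_in, d_out = A.shape
    G = A.T.reshape(d_out, d_in // m, m)           # groups along the input dim
    keep = np.argsort(-np.abs(G), axis=-1)[..., :n]
    mask = np.zeros(G.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=-1)
    return (G * mask).reshape(d_out, d_in).T


def project_rank(A, r):
    """Euclidean projection onto rank <= r (truncated SVD, Eckart-Young)."""
    U, s, Vt = np.linalg.svd(A, full_matrices=False)
    return (U[:, :r] * s[:r]) @ Vt[:r]


def layer_loss(X, W_hat, W):
    """Layer-wise reconstruction error ||X W_hat - X W||_F^2."""
    return float(np.linalg.norm(X @ (W_hat - W)) ** 2)


def admm_s_plus_lr(X, W_hat, n=2, m=4, r=8, rho0=1.0, growth=1.2, iters=60):
    """Hypothetical 3-block ADMM sketch for
        min ||X (W_hat - W)||_F^2 / N  s.t.  W = S + L,
        S is N:M sparse, rank(L) <= r,
    using a non-decreasing, geometrically growing penalty rho_t."""
    H = X.T @ X / X.shape[0]                      # loss Hessian is 2H
    eye = np.eye(H.shape[0])
    W, V = W_hat.copy(), np.zeros_like(W_hat)     # V: dual for W = S + L
    L, rho = np.zeros_like(W_hat), rho0
    for _ in range(iters):
        S = project_nm(W - L + V / rho, n, m)     # block 1: sparse part
        L = project_rank(W - S + V / rho, r)      # block 2: low-rank part
        W = np.linalg.solve(2 * H + rho * eye,    # block 3: exact quadratic step
                            2 * H @ W_hat - V + rho * (S + L))
        V = V + rho * (W - S - L)                 # dual ascent
        rho *= growth                             # sum of 1/rho_t stays finite
    return S, L


rng = np.random.default_rng(0)
X = rng.standard_normal((256, 32))       # calibration activations (N x d_in)
W_hat = rng.standard_normal((32, 16))    # dense layer weights (d_in x d_out)
S, L = admm_s_plus_lr(X, W_hat)
rel = layer_loss(X, W_hat, S + L) / layer_loss(X, W_hat, np.zeros_like(W_hat))
print(f"relative layer reconstruction error: {rel:.4f}")
```

The smooth block $\mathbf{W}$ is updated last. As a result, after every iteration the dual variable equals $2\mathbf{H}(\widehat{\mathbf{W}} - \mathbf{W})$, where $\mathbf{H} = \mathbf{X}^\top\mathbf{X}/N$ as in the code. The constraint residual therefore shrinks roughly like $1/\rho_t$ as the penalty grows, provided the iterates stay bounded.

For a sense of scale, take a 4096 × 4096 weight matrix. A 2:4 pattern keeps 8,388,608 of its weights, and rank-64 factors add 64 × (4096 + 4096) = 524,288 parameters. A (2:4 + 64 LR) layer therefore stores 53.125% of the dense parameter count, before counting the 2:4 index metadata.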

Paper Structure (23 sections, 4 theorems, 38 equations, 5 figures, 14 tables)

Key Result

Theorem 1

Let $\left\{\mathbf{S}^{(t)}\right\}_{t=0}^\infty$ and $\left\{\mathbf{L}^{(t)}\right\}_{t=0}^\infty$ be the sequences generated by the update rule (eq:finalupdate). Suppose the penalty parameter $\rho_t$ chosen at iteration $t$ is non-decreasing and satisfies $\sum_{t=0}^\infty 1/\rho_t < \infty$. … where $C$ is a constant depending on $\mathbf{X}$, $\widehat{\mathbf{W}}$, $\lambda$, $\rho_0$, and …
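The penalty conditions of Theorem 1 can be illustrated with two schedules. These are our own examples, not necessarily the schedule used in the paper. A geometric schedule is non-decreasing and summable, whereas a linear one is non-decreasing but violates the summability condition:

$$\rho_t = \rho_0\beta^t,\ \beta > 1:\qquad \sum_{t=0}^\infty \frac{1}{\rho_t} = \frac{1}{\rho_0}\sum_{t=0}^\infty \beta^{-t} = \frac{\beta}{\rho_0(\beta - 1)} < \infty,$$

$$\rho_t = \rho_0(1 + t):\qquad \sum_{t=0}^\infty \frac{1}{\rho_t} = \frac{1}{\rho_0}\sum_{k=1}^\infty \frac{1}{k} = \infty.$$

For example, $\rho_0 = 1$ and $\beta = 1.2$ give $\sum_{t} 1/\rho_t = 1.2/0.2 = 6$.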

Figures (5)

  • Figure 1: Overview of the proposed 3BASiL framework. (Left) For each layer in a Transformer, we employ multi-Block ADMM to efficiently decompose weights into high-quality Sparse plus Low-Rank components by minimizing the layer reconstruction objective. (Right) At the Transformer level, we apply gradient-based optimization to jointly refine all sparse and low-rank components across layers to match the original transformer's output, with the resulting low-rank components serving as smart initialization for subsequent LoRA fine-tuning.
  • Figure 2: Our transformer matching (TM) procedure improves any one-shot $(\mathbf{S} + \mathbf{L}\mathbf{R})$ decomposition method (see the baselines in the experimental results section) with a small computational overhead. Circled markers represent standard $(\mathbf{S} + \mathbf{L}\mathbf{R})$ methods, while filled markers indicate their TM-enhanced versions. Black arrows illustrate the performance gains due to TM. Compression runtimes are reported in hours. Llama3-8B models were run on an A100 GPU, while Llama3.2-3B models were run on an L40 GPU. Our proposal, 3BASiL-TM, remains significantly faster: (left) over 2$\times$ speedup on an A100 80GB for the Llama3-8B model decomposed to the (2:4+64LR) configuration, and (right) over 3$\times$ speedup on an L40 48GB for the Llama3.2-3B model decomposed to the (4:8+64LR) configuration (both compared to Hf-ALPS).
  • Figure 3: One-shot C4 perplexity analysis of Llama3-8B under different (N:M + 64LR) configurations.
  • Figure 4: C4 perplexity performance of Llama3-8B & Llama3.2-1B before/after LoRA fine-tuning.
  • Figure 5: Comparison of the true loss values of the original objective (eq:original) across different $(\mathbf{S} + \mathbf{L}\mathbf{R})$ methods. Lower values indicate better optimization quality. 3BASiL consistently outperforms the other methods, particularly for attention layers.
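The transformer-matching idea from Figures 1 and 2 is to refine all sparse and low-rank components jointly so that the compressed block's output matches the dense one. It can be sketched in NumPy under heavy simplifying assumptions. A toy ReLU MLP stands in for a transformer block. Plain gradient descent on an output mean-squared error replaces the paper's memory-efficient loss and optimizer. The N:M support is kept fixed. Each low-rank term is stored as factors A @ B, the form that could later initialize LoRA adapters. All function names and hyperparameters here are hypothetical.

```python
import numpy as np


def nm_mask(W, n=2, m=4):
    """Mask keeping the n largest |w| in every group of m consecutive input rows."""
    d_in, d_out = W.shape
    G = np.abs(W.T.reshape(d_out, d_in // m, m))
    keep = np.argsort(-G, axis=-1)[..., :n]
    mask = np.zeros(G.shape, dtype=bool)
    np.put_along_axis(mask, keep, True, axis=-1)
    return mask.reshape(d_out, d_in).T


def split_lowrank(R, r):
    """Best rank-r approximation of R, returned as LoRA-style factors A @ B."""
    U, s, Vt = np.linalg.svd(R, full_matrices=False)
    root = np.sqrt(s[:r])
    return U[:, :r] * root, root[:, None] * Vt[:r]


def block(X, W1, W2):
    """Hypothetical stand-in for a transformer block: a two-layer ReLU MLP."""
    pre = X @ W1
    return pre, np.maximum(pre, 0.0) @ W2


def tm_refine(X, W1, W2, r=4, lr=0.02, steps=300):
    """Jointly refine sparse values and low-rank factors of both layers so the
    compressed block output matches the dense block output."""
    _, Y_dense = block(X, W1, W2)
    params = []
    for W in (W1, W2):  # one-shot init: 2:4 magnitude mask + SVD of the residual
        M = nm_mask(W)
        A, B = split_lowrank(W - W * M, r)
        params.append([M, W * M, A, B])
    n, losses = X.shape[0], []
    for step in range(steps + 1):
        Wc = [S + A @ B for (_, S, A, B) in params]
        pre, Y = block(X, *Wc)
        diff = Y - Y_dense
        losses.append(float(np.sum(diff ** 2) / n))   # block-output matching loss
        if step == steps:
            break                                     # last pass only records the loss
        dY = 2.0 * diff / n
        dW2 = np.maximum(pre, 0.0).T @ dY             # backprop through the MLP
        dW1 = X.T @ ((dY @ Wc[1].T) * (pre > 0))
        for p, dW in zip(params, (dW1, dW2)):
            M, S, A, B = p
            p[1] = S - lr * dW * M     # only entries on the fixed N:M support move
            p[2] = A - lr * dW @ B.T   # chain rule through W = S + A @ B
            p[3] = B - lr * A.T @ dW
    return params, losses


rng = np.random.default_rng(0)
X = rng.standard_normal((512, 16))
W1 = rng.standard_normal((16, 32)) / 4.0
W2 = rng.standard_normal((32, 16)) / np.sqrt(32)
params, losses = tm_refine(X, W1, W2)
print(f"block-output MSE: {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Matching at the block level lets an error in the first layer be partly compensated by the second. A purely layer-wise objective cannot do this, because it scores each layer in isolation.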

Theorems & Definitions (8)

  • Theorem 1
  • proof
  • Lemma A.1
  • Lemma A.2
  • Lemma A.3
  • proof
  • proof
  • proof