SLoPe: Double-Pruned Sparse Plus Lazy Low-Rank Adapter Pretraining of LLMs
Mohammad Mozaffari, Amir Yazdanbakhsh, Zhao Zhang, Maryam Mehri Dehnavi
TL;DR
SLoPe introduces a double-pruned sparse plus lazy low-rank adapter pretraining scheme for LLMs. It combines a static $N:M$ sparsity pattern with a backward pass that is pruned twice, which accelerates both forward and backward computations while preserving accuracy. Low-rank adapters are added only in the final 1% of pretraining, so each weight is decomposed as $W_{dense} = W_{sparse} + LR$. This increases model capacity with minimal overhead. Key contributions include a double-pruned backward pass with convergence guarantees, lazy low-rank adapters, and optimized CUDA kernels. Together these yield end-to-end speedups of up to $1.25\times$ for training and $1.54\times$ for inference, and they reduce memory usage to as little as $0.63\times$ and $0.61\times$ of the dense baseline, respectively. Empirical results on GPT-2 and BERT-like setups show better pretraining perplexity and downstream task performance than prior sparse pretraining methods, which supports the practicality of sparse plus low-rank pretraining for very large models.
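The lazy sparse-plus-low-rank decomposition can be illustrated with a small NumPy layer. This is a hypothetical sketch and not SLoPe's implementation. The class `SparsePlusLazyLowRank`, its parameters, the fixed 2:4 mask, and the zero initialization of $L$ (borrowed from common LoRA practice) are all assumptions. Only two ideas come from the text above: the structure $W_{dense} = W_{sparse} + LR$, and enabling the adapter in the final 1% of iterations.

```python
import numpy as np


class SparsePlusLazyLowRank:
    """Hypothetical linear layer with weight W_sparse + L @ R, where the
    low-rank adapter L @ R only participates in the final fraction of training."""

    def __init__(self, w_sparse, rank, total_iters, adapter_fraction=0.01, seed=0):
        out_features, in_features = w_sparse.shape
        rng = np.random.default_rng(seed)
        self.w_sparse = w_sparse
        # Assumption: L starts at zero, so switching the adapter on does not
        # change the layer output until L receives gradient updates.
        self.L = np.zeros((out_features, rank))
        self.R = rng.standard_normal((rank, in_features)) / np.sqrt(in_features)
        # First iteration at which the adapter is active (final 1% by default).
        self.adapter_start = total_iters - round(total_iters * adapter_fraction)

    def adapter_active(self, step):
        return step >= self.adapter_start

    def effective_weight(self):
        # Dense-equivalent weight: W_dense = W_sparse + L R
        return self.w_sparse + self.L @ self.R

    def forward(self, x, step):
        y = x @ self.w_sparse.T
        if self.adapter_active(step):
            # Apply the two factors separately: rank * (in + out) multiply-adds
            # per input row instead of forming the dense in x out matrix.
            y = y + (x @ self.R.T) @ self.L.T
        return y


rng = np.random.default_rng(0)
mask_2_4 = np.tile([1.0, 1.0, 0.0, 0.0], (8, 4))  # fixed 2:4 pattern, shape (8, 16)
w_sparse = rng.standard_normal((8, 16)) * mask_2_4
layer = SparsePlusLazyLowRank(w_sparse, rank=4, total_iters=10_000)
x = rng.standard_normal((3, 16))

print(layer.adapter_start)  # 9900
```

Before `adapter_start`, the layer behaves like a purely sparse layer. After that point, its output equals `x @ layer.effective_weight().T`, which corresponds to the dense-equivalent weight $W_{sparse} + LR$. The stored parameters remain a sparse matrix plus two thin factors.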
Abstract
We propose SLoPe, a Double-Pruned Sparse Plus Lazy Low-rank Adapter Pretraining method for LLMs. SLoPe improves the accuracy of sparse LLMs, accelerates their pretraining and inference, and reduces their memory footprint. Sparse pretraining of LLMs reduces model accuracy; to overcome this, prior work uses dense models during fine-tuning. SLoPe instead improves the accuracy of sparsely pretrained models by adding low-rank adapters in the final 1% of pretraining iterations, without adding significant overhead to model pretraining or inference. In addition, SLoPe uses a double-pruned backward pass formulation that prunes the transposed weight matrix using $N:M$ sparsity structures, which enables an accelerated sparse backward pass. SLoPe accelerates the training and inference of models with billions of parameters by up to $1.25\times$ and $1.54\times$, respectively (OPT-33B and OPT-66B). It also reduces their memory usage to as little as $0.63\times$ and $0.61\times$ of the dense baseline for training and inference, respectively.
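The double-pruning step can be sketched in NumPy. This is a minimal illustration under stated assumptions and not SLoPe's CUDA implementation. The helpers `prune_n_m_rows`, `double_prune`, and `satisfies_n_m` are hypothetical, magnitude is assumed as the pruning criterion, and the speedup from sparse hardware is not modeled: the matrices stay dense with explicit zeros. The sketch also assumes how the two pruned copies are used. The forward pass takes the row-wise pruned weight. The input-gradient computation takes the double-pruned copy, whose columns are also $N:M$ sparse.

```python
import numpy as np


def prune_n_m_rows(w: np.ndarray, n: int = 2, m: int = 4) -> np.ndarray:
    """Keep the n largest-magnitude entries in every group of m consecutive
    elements along each row and zero the rest (magnitude criterion assumed)."""
    rows, cols = w.shape
    if cols % m != 0:
        raise ValueError("number of columns must be divisible by m")
    groups = w.reshape(rows, cols // m, m)
    # Indices of the (m - n) smallest-magnitude entries in each group.
    drop = np.argsort(np.abs(groups), axis=-1)[..., : m - n]
    mask = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    return (groups * mask).reshape(rows, cols)


def double_prune(w: np.ndarray, n: int = 2, m: int = 4):
    """Hypothetical helper: prune rows N:M, then prune the transpose N:M again."""
    w_row = prune_n_m_rows(w, n, m)              # N:M along rows
    w_double = prune_n_m_rows(w_row.T, n, m).T   # additionally N:M along columns
    return w_row, w_double


def satisfies_n_m(w: np.ndarray, n: int, m: int) -> bool:
    """True if every group of m consecutive row elements has at most n nonzeros."""
    rows, cols = w.shape
    counts = (w.reshape(rows, cols // m, m) != 0).sum(axis=-1)
    return bool((counts <= n).all())


rng = np.random.default_rng(0)
w = rng.standard_normal((8, 16))       # weight of shape (out_features, in_features)
x = rng.standard_normal((4, 16))       # batch of inputs
w_row, w_double = double_prune(w)

y = x @ w_row.T                        # forward: each output sums over an N:M row of w_row
grad_y = rng.standard_normal(y.shape)  # stand-in for the upstream gradient
grad_x = grad_y @ w_double             # backward: each entry sums over an N:M column of w_double

print(satisfies_n_m(w_row, 2, 4))       # True
print(satisfies_n_m(w_double.T, 2, 4))  # True
print(satisfies_n_m(w_double, 2, 4))    # True
```

The second pruning only removes entries, so `w_double` stays $N:M$ along rows as well as columns. It keeps fewer weights than `w_row`, however, and in this sketch that extra loss is confined to the input-gradient computation.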
