Sirius: Contextual Sparsity with Correction for Efficient LLMs

Yang Zhou, Zhuoming Chen, Zhaozhuo Xu, Victoria Lin, Beidi Chen

TL;DR

Sirius is an efficient correction mechanism that significantly recovers the quality of contextual sparsity (CS) models on reasoning tasks while maintaining their efficiency gain. Evaluated on 6 models with 8 difficult generation tasks in reasoning, math, and coding, it shows consistent effectiveness and efficiency.

Abstract

With the rise of large language models (LLMs), inference efficiency is becoming increasingly important. Various approximation methods have been proposed to reduce inference-time cost. Contextual Sparsity (CS) is appealing for its training-free nature and its ability to reach a higher compression ratio seemingly without quality degradation. However, after a comprehensive evaluation of contextual sparsity methods on various complex generation tasks, we find that although CS succeeds in prompt-understanding tasks, it significantly degrades model performance on reasoning, deduction, and knowledge-based tasks. Despite the gap in end-to-end accuracy, we observe that sparse models often share the general problem-solving logic of the full model and require only a few token corrections to recover the original model performance. This paper introduces Sirius, an efficient correction mechanism that significantly recovers the quality of CS models on reasoning tasks while maintaining their efficiency gain. Sirius is evaluated on 6 models with 8 difficult generation tasks in reasoning, math, and coding and shows consistent effectiveness and efficiency. We also carefully develop a system implementation for Sirius and show that it achieves roughly a 20% reduction in latency for an 8B model on-chip and a 35% reduction for a 70B model with offloading. We open-source our implementation of Sirius at https://github.com/Infini-AI-Lab/Sirius.git.

Paper Structure (32 sections, 3 equations, 10 figures, 20 tables, 1 algorithm)


Figures (10)

  • Figure 1: Contextually sparse models struggle on challenging text generation tests that require high-level reasoning and understanding, e.g. GSM8K, where they suffer significant quality degradation. (a) We contrast CS Llama-3-8B-Instruct on GSM8K (green) and CNN/DailyMail (coral). (b) Contextually sparse Llama-3-70B-Instruct crashes at 50% global sparsity, making the smaller dense model Llama-3-8B-Instruct (green star) a significantly more efficient choice than the sparse 70B model. (c) The sparse model's failures on reasoning tasks follow patterns, and ideally, correcting only the 11% of tokens deemed unlikely fully recovers the sparse model's performance.
  • Figure 2: Overview of Sirius. Contextual Sparsity requires the full model weights to be placed in GPU memory. Because the sparse model doesn't perform well on complex reasoning tasks, Sirius uses the full model to correct the sparse model. The full model is called fairly infrequently. During correction, the full model rewrites the KV cache, interleaves high-quality tokens into the sparse outputs, and rolls back only when a token is deemed extremely unlikely by the full model.
  • Figure 3: Speculative decoding has limited efficiency when correcting sparse models.
  • Figure 4: For models with similar parameter counts, the better trained the model is, the worse the degradation. (Compare the figures vertically between the Llama-3 and Llama-2 family models.)
  • Figure 5: We contrast Contextual Sparsity on a prompt-understanding task with complex generation tasks that require reasoning. (a) Both CSparse and FSparse are robust on CNN/DailyMail across various sparsity levels; (b) and (c) show that both CSparse and FSparse crash on GSM8K and HumanEval at global sparsity levels at which they remain robust on prompt-understanding tasks.
  • ...and 5 more figures
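
The correction loop described in the Figure 2 caption can be illustrated with a toy sketch. This is not the authors' implementation: the two "models" are hypothetical deterministic functions over a 10-token vocabulary, the names `full_model`, `sparse_model`, `generate_with_correction`, `period`, and `threshold` are assumptions made for illustration, and there is no real KV cache (truncating the token list stands in for the rollback). It only shows the control flow: the sparse model decodes, the full model is consulted once every few tokens, and generation rolls back to the first token the full model considers very unlikely.

```python
from typing import Callable, Dict, List, Tuple

Dist = Dict[int, float]              # token id -> probability
Model = Callable[[List[int]], Dist]  # context -> next-token distribution
VOCAB = 10


def _peaked(token: int) -> Dist:
    """Distribution putting 0.91 on `token` and 0.01 on every other token."""
    return {t: (0.91 if t == token else 0.01) for t in range(VOCAB)}


def full_model(ctx: List[int]) -> Dist:
    """Toy full model: strongly prefers (last token + 1) mod VOCAB."""
    return _peaked((ctx[-1] + 1) % VOCAB)


def sparse_model(ctx: List[int]) -> Dist:
    """Toy sparse model: agrees with the full model, except for an
    injected error whenever the context length is a multiple of 5."""
    nxt = (ctx[-1] + 1) % VOCAB
    if len(ctx) % 5 == 0:
        nxt = (nxt + 3) % VOCAB
    return _peaked(nxt)


def argmax(dist: Dist) -> int:
    return max(dist, key=dist.get)


def generate_with_correction(
    sparse: Model, full: Model, prompt: List[int], n_new: int,
    period: int = 4, threshold: float = 0.1,
) -> Tuple[List[int], int]:
    """Greedy decoding with the sparse model, checked by the full model
    every `period` tokens. Returns (tokens, number of full-model checks)."""
    seq = list(prompt)
    checked = len(seq)          # tokens before this index are verified
    target = len(prompt) + n_new
    full_calls = 0
    while len(seq) < target:
        seq.append(argmax(sparse(seq)))
        if len(seq) - checked >= period or len(seq) == target:
            full_calls += 1     # one (conceptually batched) full-model check
            for i in range(checked, len(seq)):
                dist = full(seq[:i])
                if dist[seq[i]] < threshold:
                    # Roll back to position i and keep the full model's token.
                    seq = seq[:i] + [argmax(dist)]
                    break
            checked = len(seq)
    return seq, full_calls


corrected, calls = generate_with_correction(sparse_model, full_model, [0], 8)
uncorrected, _ = generate_with_correction(
    sparse_model, full_model, [0], 8, threshold=0.0  # threshold 0 disables rollback
)
print(corrected, calls)
print(uncorrected)
```

In this toy setting, a larger `period` means fewer full-model checks but more sparse tokens discarded on each rollback, which is the kind of trade-off a real system would have to tune.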