CATS: Contextually-Aware Thresholding for Sparsity in Large Language Models

Donghyun Lee, Je-Yong Lee, Genghan Zhang, Mo Tiwari, Azalia Mirhoseini

TL;DR

Large language models incur high inference costs, motivating sparsity-driven acceleration. CATS introduces a contextually aware thresholding activation for Gated-MLP blocks, enabling controllable activation sparsity without significant performance loss, and translates that sparsity into wall-clock speedups with a custom GPU kernel. Across Mistral-7B and Llama2-7B, CATS achieves downstream task performance close to the base models at 50–70% sparsity, converges faster with LoRA fine-tuning, and delivers roughly 15% lower token-generation latency along with higher throughput. These results demonstrate a practical path to reducing inference costs in modern LLMs without extensive fine-tuning.
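The thresholding idea can be sketched in a few lines of NumPy. This is a minimal illustration, not the authors' implementation: it assumes the SwiGLU-style Gated-MLP used in Llama2 and Mistral, `(SiLU(x W_gate) * (x W_up)) W_down`, and applies the threshold `t` to the magnitude of the gate activation, zeroing entries with |value| <= t. The helper names `silu`, `cats`, and `gated_mlp_cats` are hypothetical.

```python
import numpy as np


def silu(z: np.ndarray) -> np.ndarray:
    """SiLU (swish) activation: z * sigmoid(z)."""
    return z / (1.0 + np.exp(-z))


def cats(z: np.ndarray, t: float) -> np.ndarray:
    """Hypothetical CATS-style threshold: keep entries with |z| > t, zero the rest."""
    return np.where(np.abs(z) > t, z, 0.0)


def gated_mlp_cats(x, w_gate, w_up, w_down, t):
    """Gated-MLP forward pass with a thresholded gate (assumed formulation).

    Shapes: x (d_model,), w_gate and w_up (d_model, d_ff), w_down (d_ff, d_model).
    Returns the block output and the sparse gate vector.
    """
    gate = cats(silu(x @ w_gate), t)  # entries at or below t become exactly 0.0
    return (gate * (x @ w_up)) @ w_down, gate


# Usage with random weights; a larger t zeroes more gate entries.
rng = np.random.default_rng(0)
d_model, d_ff = 8, 32
x = rng.standard_normal(d_model)
w_gate = rng.standard_normal((d_model, d_ff))
w_up = rng.standard_normal((d_model, d_ff))
w_down = rng.standard_normal((d_ff, d_model))
y, gate = gated_mlp_cats(x, w_gate, w_up, w_down, t=0.5)
print(f"gate sparsity: {np.mean(gate == 0.0):.2f}")
```

With `t = 0.0` the block reduces to the ordinary dense Gated-MLP, so the threshold acts as a single knob that trades accuracy for sparsity.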

Abstract

Large Language Models (LLMs) have dramatically advanced AI applications, yet their deployment remains challenging due to their immense inference costs. Recent studies reduce the computational costs of LLMs by increasing their activation sparsity, but suffer from significant performance degradation on downstream tasks. In this work, we introduce a new framework for sparsifying the activations of base LLMs and reducing inference costs, dubbed Contextually Aware Thresholding for Sparsity (CATS). CATS is relatively simple, easy to implement, and highly effective. At the heart of our framework is a new non-linear activation function. We demonstrate that CATS can be applied to various base models, including Mistral-7B and Llama2-7B, and outperforms existing sparsification techniques in downstream task performance. More precisely, CATS-based models often come within 1-2% of their base models' downstream task performance at 50% activation sparsity, without any fine-tuning. Furthermore, CATS-based models converge faster and display better task performance than competing techniques when fine-tuning is applied. Finally, we develop a custom GPU kernel for efficient implementation of CATS that translates the activation sparsity of CATS into real wall-clock speedups. Our custom kernel implementation of CATS reduces the wall-clock latency of token generation by ~15% on both Llama-7B and Mistral-7B.
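Why activation sparsity can save time: once the gate vector is mostly zeros, the columns of W_up and rows of W_down that line up with zeroed entries contribute nothing and never need to be loaded. The NumPy sketch below only illustrates this equivalence on the CPU; it is not the paper's GPU kernel, and the function name `sparse_mlp_tail` is hypothetical. W_gate must still be read in full to find which entries survive the threshold, so the savings apply only to the other two projections.

```python
import numpy as np


def sparse_mlp_tail(x, gate, w_up, w_down):
    """Compute (gate * (x @ w_up)) @ w_down while reading only active neurons.

    x: (d_model,), gate: (d_ff,) thresholded gate activations,
    w_up: (d_model, d_ff), w_down: (d_ff, d_model).
    """
    idx = np.flatnonzero(gate)          # indices of neurons that survived
    up = x @ w_up[:, idx]               # up-projection for active columns only
    return (gate[idx] * up) @ w_down[idx, :]


rng = np.random.default_rng(1)
d_model, d_ff = 16, 64
x = rng.standard_normal(d_model)
w_up = rng.standard_normal((d_model, d_ff))
w_down = rng.standard_normal((d_ff, d_model))
gate = rng.standard_normal(d_ff)
gate[rng.random(d_ff) < 0.5] = 0.0      # roughly 50% activation sparsity

dense = (gate * (x @ w_up)) @ w_down
sparse = sparse_mlp_tail(x, gate, w_up, w_down)
print(np.allclose(dense, sparse), f"active neurons: {np.count_nonzero(gate)}/{d_ff}")
```

On real hardware the benefit comes from skipping memory reads of the inactive weight rows and columns, which is what a dedicated kernel has to exploit; plain NumPy fancy indexing copies the selected slices and is not itself faster.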

Paper Structure

This paper contains 19 sections, 8 equations, 8 figures, 5 tables, and 3 algorithms.

Figures (8)

  • Figure 1: Histograms of post-MLP activations of different layers in different models. Subfigures (a), (b), and (c) correspond to Layers 0, 15, and 31 in Llama2-7B, respectively. Subfigures (d), (e), and (f) correspond to Layers 0, 15, and 31 in Mistral-7B, respectively. The absolute threshold marks the 50% sparsity level: values whose magnitude falls below it are considered negligible in our technique and are zeroed out.
  • Figure 2: Downstream task performance of the base model, CATS models with different sparsity levels, and ReLUfication across varying numbers of fine-tuning steps on the RefinedWeb dataset, applied to Mistral-7B (left) and Llama2-7B (right). The CATS models exhibit faster convergence and greater fine-tuning efficiency than the ReLUfication variants. Furthermore, CATS-50% and CATS-70% perform comparably to the base models even without fine-tuning (0 fine-tuning steps).
  • Figure 3: Latency of the original Mistral-7B MLP block (left, "Dense"), Llama-7B MLP block (right, "Dense"), and their CATS-based variants at different sparsity levels, compared to "Optimal." Our custom GPU kernel reduces the latency of the CATS-based variants and comes close to "Optimal" at most reasonable sparsity levels.
  • Figure 4: Throughput of dense Mistral-7B (left, "Dense") and Llama2-7B (right, "Dense") compared with their CATS-50% variants running the custom GPU kernel. CATS-50% achieves significantly higher throughput.
  • Figure 5: CATS-based models still exhibit sparsity after general fine-tuning on the RefinedWeb dataset.
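Figure 1 describes the threshold as an absolute cutoff chosen so that a target fraction of activations, 50% in the plotted case, falls below it in magnitude. Assuming a per-layer threshold calibrated from activations gathered on sample inputs, one way to pick it is an empirical quantile of the absolute values. The function name `calibrate_threshold` and this calibration setup are illustrative assumptions, not the paper's code.

```python
import numpy as np


def calibrate_threshold(activations: np.ndarray, target_sparsity: float) -> float:
    """Return t so that about `target_sparsity` of |activations| is <= t.

    Entries with |value| <= t would then be zeroed by the thresholding step.
    """
    if not 0.0 <= target_sparsity < 1.0:
        raise ValueError("target_sparsity must be in [0, 1)")
    magnitudes = np.abs(np.asarray(activations, dtype=np.float64)).ravel()
    return float(np.quantile(magnitudes, target_sparsity))


# Illustration on synthetic, heavy-tailed "activations" (not real model data).
rng = np.random.default_rng(0)
acts = rng.laplace(scale=0.1, size=10_001)
for s in (0.5, 0.7):
    t = calibrate_threshold(acts, s)
    achieved = np.mean(np.abs(acts) <= t)
    print(f"target {s:.0%}: t = {t:.4f}, achieved {achieved:.2%}")
```

Because the threshold is fixed after calibration, the sparsity seen at inference time only approximates the target and can vary from token to token.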