BWTA: Accurate and Efficient Binarized Transformer by Algorithm-Hardware Co-design

Yifu Ding, Xianglong Liu, Shenghao Jin, Jinyang Guo, Jiwen Lu

Abstract

Ultra-low-bit quantization brings substantial efficiency to Transformer-based models, but accuracy degradation and limited GPU support hinder its wide adoption. In this paper, we analyze zero-point distortion in binarization and propose a Binary Weights & Ternary Activations (BWTA) quantization scheme, which projects tiny values to zero and preserves the accuracy of extremely low-bit models. For training, we propose Smooth Multi-Stage Quantization, combining a Levelwise Degradation Strategy and a Magnitude-Alignment Projection Factor to enable stable and fast convergence. For inference, we develop a BWTA MatMul CUDA kernel with instruction-level parallel bit-packing and comprehensive binary/ternary MatMul implementations for both linear and attention operators, allowing seamless integration across Transformer architectures. Experiments show that BWTA approaches full-precision performance for BERT, with an average 3.5% drop on GLUE and less than a 2% drop on five tasks, and achieves comparable perplexity and accuracy for LLMs. In terms of efficiency, it delivers 16 to 24 times kernel-level speedup over FP16 on NVIDIA GPUs, and an end-to-end prefill throughput of 216 to 330 tokens/s with a lower memory footprint on LLMs. As an algorithm-hardware co-design, BWTA demonstrates practical, low-latency ultra-low-bit inference without sacrificing model quality.
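
To make the quantization scheme concrete, here is a minimal NumPy sketch of binary-weight and ternary-activation quantization under common assumptions: weights are binarized to their sign times a per-tensor mean-absolute scale, and activations whose magnitude falls below a threshold are projected to exact zero. The threshold rule (`delta_ratio = 0.7`) and the scale choices are illustrative assumptions, not the paper's Magnitude-Alignment Projection Factor or Levelwise Degradation Strategy.

```python
import numpy as np

def binarize_weights(w):
    # Binary weights: sign(w) scaled by the per-tensor mean absolute value.
    # The per-tensor scale is a common convention and an assumption here.
    alpha = np.mean(np.abs(w))
    return np.where(w >= 0, 1.0, -1.0) * alpha

def ternarize_activations(a, delta_ratio=0.7):
    # Ternary activations: magnitudes below the threshold are projected to
    # exact zero; the rest keep their sign and share one scale beta.
    # delta_ratio follows classic ternary heuristics; it is an assumption,
    # not the threshold rule used in the paper.
    delta = delta_ratio * np.mean(np.abs(a))
    nonzero = np.abs(a) > delta
    beta = np.mean(np.abs(a[nonzero])) if nonzero.any() else 0.0
    q = np.sign(a).astype(np.float64)
    q[~nonzero] = 0.0
    return q * beta
```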

Paper Structure

This paper contains 43 sections, 9 equations, 14 figures, 12 tables.

Figures (14)

  • Figure 1: Overview of the Binary Weight & Ternary Activation (BWTA) framework. The left part is the training algorithm, Smooth Multi-Stage Quantization, designed for ultra-low-bit Transformers with ternary activations, including (1) the binary/ternary definitions used to build a BWTA model, (2) smooth strategies for stable convergence, and (3) distillation for fast training. The right part is the full-stack GPU support for the custom BWTA MatMul kernel, including (1) binarized weights for storage reduction, (2) instruction-level parallel bit-packing for runtime quantization, and (3) low-bit BWTA MatMul kernels for fast computation.
  • Figure 2: Histograms of binary/ternary activations in the self-attention structure before and after quantization. (a) shows the value matrix $V_\mathrm{bitwise}$ and the activation $A_\mathrm{bitwise}$ of the linear layer following the multiplication of the attention score and value, trained with the bitwise degradation strategy, while (b) shows the distributions of $V_\mathrm{levelwise}$ and $A_\mathrm{levelwise}$ under the levelwise strategy.
  • Figure 3: Illustrations of two practices for projecting $0_{\mathrm{fp}}$ to the integer space and their respective issues.
  • Figure 4: (a) Illustration of the bitwise/levelwise multi-stage quantization strategies. (b) The grid points shift with the projection factors (blue arrows) and converge again in the next stage with fewer integers (gray arrows).
  • Figure 5: Data distribution across threads for weights, activations, and MMA results using the m8n8k128 data layout. Binary weights are packed and stored before inference; we pick $8\times 128$ elements for each SIMT, with four copies stored across 128 threads. Activations are packed into ternary format at runtime, with four slices per SIMT and no repeated copies (a minimal popcount-based sketch of the resulting binary/ternary product is given after this figure list).
  • ...and 9 more figures
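
As a rough, hardware-agnostic illustration of the BWTA MatMul idea behind Figure 5, the pure-Python sketch below packs binary weights into a sign bitmask and ternary activations into a sign bitmask plus a non-zero mask, then evaluates their dot product with XNOR/AND and popcount (agreements among non-zero positions minus disagreements). The encoding and the popcount identity are standard bitwise-MatMul conventions assumed here for illustration; the paper's CUDA kernel operates on warp-level m8n8k128 fragments and is not reproduced.

```python
def pack_binary(w):
    # Pack binary weights {-1, +1}: bit i is 1 when w[i] == +1.
    bits = 0
    for i, x in enumerate(w):
        if x > 0:
            bits |= 1 << i
    return bits

def pack_ternary(a):
    # Pack ternary activations {-1, 0, +1}: a sign plane (1 means +1)
    # and a non-zero mask (1 means a[i] != 0).
    sign, mask = 0, 0
    for i, x in enumerate(a):
        if x != 0:
            mask |= 1 << i
            if x > 0:
                sign |= 1 << i
    return sign, mask

def bwta_dot(w_bits, a_sign, a_mask, n):
    # dot(w, a) = (#agreements) - (#disagreements) over non-zero positions
    #           = 2 * popcount(~(w XOR a_sign) & a_mask) - popcount(a_mask)
    full = (1 << n) - 1
    agree = bin(~(w_bits ^ a_sign) & full & a_mask).count("1")
    return 2 * agree - bin(a_mask).count("1")

# Tiny check against the floating-point reference.
w = [1, -1, 1, 1]
a = [0, -1, 1, -1]
assert bwta_dot(pack_binary(w), *pack_ternary(a), len(w)) == sum(x * y for x, y in zip(w, a))
```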