Direct Quantized Training of Language Models with Stochastic Rounding

Kaiyan Zhao, Tsuguchika Tabaru, Kenichi Kobayashi, Takumi Honda, Masafumi Yamazaki, Yoshimasa Tsuruoka

TL;DR

Direct Quantized Training (DQT) tackles the memory bottleneck of training quantized LLMs by updating low-precision weights directly, avoiding the need to store high-precision parameters. Using stochastic rounding, DQT preserves critical update signals within an $n$-bit INT$n$ representation throughout training, enabling convergence even with ternary weights and robust performance under memory constraints. Empirical results on LLaMA-family models (130M, 320M, 1B) show that 8-bit DQT approaches or matches BitNet in many settings, while ternary DQT can converge with substantially reduced memory footprints; DQT also supports inference with ternary weights. Overall, DQT provides a practical, memory-efficient QAT alternative with strong deployment flexibility and competitive accuracy across diverse tasks and settings.
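The memory argument can be made concrete with back-of-the-envelope arithmetic for weight storage alone. This is a minimal sketch under stated assumptions, not the paper's measurements. It assumes an STE-style baseline keeps one FP32 master weight per parameter. It assumes direct training stores only INT8 codes, or ternary codes packed into 2 bits. It ignores gradients, optimizer states, and activations, which also contribute to the GPU totals reported in the paper. The helper `weight_storage_gb` is a hypothetical name.

```python
import math

BYTES_PER_GB = 1e9  # decimal gigabytes, to keep the arithmetic easy to check


def weight_storage_gb(num_params: int, bits_per_weight: float) -> float:
    """Storage for the weight tensors alone (no gradients or optimizer state), in GB."""
    return num_params * bits_per_weight / 8 / BYTES_PER_GB


if __name__ == "__main__":
    n = 1_000_000_000  # about the size of the largest model in the TL;DR (1B)
    # Assumed storage formats; the real training setups may differ.
    scenarios = {
        "FP32 master copy (STE-style training)": 32,
        "INT8 codes only (8-bit DQT)": 8,
        "ternary codes packed in 2 bits": 2,
        "ternary information bound (log2 3 bits)": math.log2(3),
    }
    for name, bits in scenarios.items():
        print(f"{name:42s} {weight_storage_gb(n, bits):.3f} GB")
```

Under these assumptions, dropping the FP32 master copy in favor of INT8 codes cuts weight storage by 4x, and 2-bit packed ternary codes cut it by 16x.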

Abstract

Although recent quantized Large Language Models (LLMs), such as BitNet, have paved the way for a significant reduction in memory usage during deployment with binary or ternary weights, training these models still requires a substantial memory footprint. This is partly because the high-precision (i.e., unquantized) weights required for straight-through estimation must be maintained throughout training. To address this, we explore directly updating the quantized low-precision weights without relying on straight-through estimation during backpropagation, with the aim of reducing memory usage during training. Specifically, we employ a stochastic rounding technique to minimize the information loss caused by using low-bit weights throughout training. Experimental results on our LLaMA-structured models of various sizes indicate that (1) training with only low-precision weights is feasible even when they are constrained to ternary values; (2) extending the bit width to 8 bits achieves performance on par with BitNet b1.58; (3) our models remain robust to precision scaling and memory reduction, showing minimal performance degradation when moving from FP32 to lower-memory environments (BF16/FP8); and (4) our models also support inference using ternary weights, showcasing their flexibility in deployment.
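The direct update described above can be sketched in plain Python. The sketch rests on several assumptions that the abstract does not spell out. Each weight is stored only as a signed integer code $q \in [-Q, Q]$ with $Q = 2^{n-1}-1$, so $n = 2$ gives the ternary grid $\{-1, 0, 1\}$. A fixed per-tensor scale $s$ maps codes to real values $s\,q$. The optimizer is plain SGD with learning rate $\eta$. One step then computes $q \leftarrow \mathrm{clip}(\mathrm{SR}(q - \eta g / s), -Q, Q)$. Here $\mathrm{SR}$ rounds up with probability equal to the fractional part, so the rounded value is unbiased before clipping. The functions `direct_update` and `stochastic_round` are hypothetical names, not the authors' code.

```python
import math
import random
from typing import List


def stochastic_round(x: float, rng: random.Random) -> int:
    """Round to floor(x) or floor(x) + 1, rounding up with probability x - floor(x)."""
    lower = math.floor(x)
    return lower + 1 if rng.random() < x - lower else lower


def direct_update(codes: List[int], grads: List[float], lr: float,
                  scale: float, bits: int, rng: random.Random) -> List[int]:
    """Hypothetical DQT-style SGD step applied to integer codes, with no FP master copy."""
    q_max = 2 ** (bits - 1) - 1  # symmetric grid [-q_max, q_max]; bits=2 -> {-1, 0, 1}
    new_codes = []
    for code, grad in zip(codes, grads):
        # SGD step expressed in units of the quantization step `scale`.
        target = code - lr * grad / scale
        # Stochastic rounding keeps the step unbiased; clipping keeps it on the grid.
        new_codes.append(max(-q_max, min(q_max, stochastic_round(target, rng))))
    return new_codes


if __name__ == "__main__":
    rng = random.Random(0)
    ternary = [1, 0, -1, 0]
    grads = [0.5, -0.2, 0.1, 3.0]  # made-up gradients w.r.t. the real-valued weights
    print(direct_update(ternary, grads, lr=0.1, scale=0.05, bits=2, rng=rng))
    # Averaging many 8-bit updates of a quarter step recovers the quarter step.
    n = 50_000
    moved = direct_update([0] * n, [-0.125] * n, lr=0.1, scale=0.05, bits=8, rng=rng)
    print(sum(moved) / n)
```

Because no FP32 master copy exists, the integer codes are the only persistent weight state. An update smaller than half a quantization step would be rounded away by round-to-nearest. Under stochastic rounding it still moves a weight with probability proportional to its size.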

Paper Structure

This paper contains 23 sections, 5 equations, 6 figures, and 1 table.

Figures (6)

  • Figure 1: Comparison of the training processes of BitNet and our modified version. Upper: The training process for BitNet, where the original high-precision weights are updated via the straight-through estimator in the backward pass. Lower: We directly update the low-precision weights with stochastic rounding, which eliminates the need to quantize the weight matrices at every training step and keeps them in low-bit form throughout. An 8-bit example is provided in the Supplementary Material (Figure 3).
  • Figure 2: Comparison of our DQT and other baselines across different model sizes and training datasets. The horizontal axis shows training steps and the vertical axis shows training loss. As model size increases, our DQT models, especially DQT 8-bit, become increasingly competitive with, and even better than, the reproduced BitNet b1.58.
  • Figure 3: GPU memory usage versus loss on the development set. While BitNet suffers significant performance degradation in low-precision formats, DQT demonstrates strong robustness with minimal loss increase.
  • Figure 4: Comparison of bit widths in DQT. A larger bit width $n$ yields better performance.
  • Figure 5: Left: Comparison between DQT 1.58 bits and a variant using absmax quantization for weight updates under the same learning rate. Right: Percentage of updated weights after each training step.
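Figure 5 (right) tracks what fraction of weights actually change at each training step. The sketch below shows why the rounding scheme matters for that statistic under toy conditions. The conditions are assumptions, not the paper's setup: random ternary codes, and per-weight updates drawn uniformly within ±0.05 of one quantization step. Deterministic round-to-nearest leaves every weight unchanged. It is used here only as a simple stand-in and is not the paper's absmax variant. Stochastic rounding changes a small fraction of weights whose size tracks the update magnitude. The names `fraction_updated` and `make_stochastic_rounder` are hypothetical.

```python
import math
import random
from typing import Callable, List


def fraction_updated(codes: List[int], steps: List[float],
                     rounder: Callable[[float], int], q: int = 1) -> float:
    """Share of integer codes that change after one rounded, clipped update."""
    changed = 0
    for code, step in zip(codes, steps):
        # `step` is assumed to be already expressed in quantization-step units.
        new = max(-q, min(q, rounder(code + step)))
        if new != code:
            changed += 1
    return changed / len(codes)


def make_stochastic_rounder(seed: int) -> Callable[[float], int]:
    """Return a stochastic rounding function with its own seeded generator."""
    rng = random.Random(seed)

    def sr(x: float) -> int:
        lower = math.floor(x)
        return lower + 1 if rng.random() < x - lower else lower

    return sr


if __name__ == "__main__":
    data_rng = random.Random(0)
    n = 10_000
    codes = [data_rng.choice([-1, 0, 1]) for _ in range(n)]
    # Updates much smaller than one quantization step (|step| <= 0.05).
    steps = [data_rng.uniform(-0.05, 0.05) for _ in range(n)]
    print(f"round-to-nearest: {fraction_updated(codes, steps, round):.2%}")
    print(f"stochastic:       {fraction_updated(codes, steps, make_stochastic_rounder(1)):.2%}")
```

With these toy updates, round-to-nearest reports 0% updated weights. The stochastic variant reports a small positive percentage.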