Compute-Optimal Quantization-Aware Training

Aleksandr Dremov, David Grangier, Angelos Katharopoulos, Awni Hannun

TL;DR

The paper derives a loss scaling law that predicts both optimal QAT ratios and final model performance across different QAT/FP compute allocation strategies and QAT bit widths, and proposes a novel cooldown and QAT fusion approach that performs learning rate decay jointly with quantization-aware training.

Abstract

Quantization-aware training (QAT) is a leading technique for improving the accuracy of quantized neural networks. Previous work has shown that decomposing training into a full-precision (FP) phase followed by a QAT phase yields superior accuracy compared to QAT alone. However, the optimal allocation of compute between the FP and QAT phases remains unclear. We conduct extensive experiments with various compute budgets, QAT bit widths, and model sizes from 86.0M to 2.2B to investigate how different QAT durations impact final performance. We demonstrate that, contrary to previous findings, the loss-optimal ratio of QAT to FP training increases with the total amount of compute. Moreover, the optimal fraction can be accurately predicted for a wide range of model sizes and quantization widths using the tokens-per-parameter-byte statistic. From experimental data, we derive a loss scaling law that predicts both optimal QAT ratios and final model performance across different QAT/FP compute allocation strategies and QAT bit widths. We use the scaling law to make further predictions, which we verify experimentally, including which QAT bit width is optimal under a given memory constraint and how QAT accuracy with different bit widths compares to full-precision model accuracy. Additionally, we propose a novel cooldown and QAT fusion approach that performs learning rate decay jointly with quantization-aware training, eliminating redundant full-precision model updates and achieving significant compute savings. These findings provide practical insights into efficient QAT planning and enable the training of higher-quality quantized models with the same compute budget.
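
As a rough illustration of the tokens-per-parameter-byte statistic mentioned above, the sketch below assumes it is the total training token count divided by the size of the quantized weights in bytes (parameter count times bit width, over 8). The abstract does not spell out this formula, so the definition, the function name, and the example numbers are assumptions for illustration only.

```python
def tokens_per_parameter_byte(total_tokens: float, n_params: float, qat_bits: int) -> float:
    """Assumed definition: total tokens / (parameters * bits / 8).

    This formula is an assumption for illustration; the abstract names the
    statistic but does not define it.
    """
    if n_params <= 0 or qat_bits <= 0:
        raise ValueError("n_params and qat_bits must be positive")
    # Size of the quantized weights in bytes (8 bits per byte).
    param_bytes = n_params * qat_bits / 8.0
    return total_tokens / param_bytes


# Hypothetical numbers: a 400M-parameter model trained on 8B tokens.
for bits in (1, 2, 4, 6):
    stat = tokens_per_parameter_byte(8e9, 400e6, bits)
    print(f"{bits}-bit QAT: {stat:.1f} tokens per parameter-byte")
```

Under this assumed definition, a lower bit width gives a higher statistic for the same model size and token count, which is one way models with different bit widths can be placed on a single axis.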

Paper Structure

This paper contains 50 sections, 14 equations, 17 figures, 15 tables.

Figures (17)

  • Figure 1: On the left, experimental and predicted optimal QAT fractions are shown as a function of tokens-per-parameter-byte. Different colors represent models of varying sizes, while point sizes indicate final perplexity normalized across experiments with identical total token counts for each model size. Results span multiple QAT bit widths, and optimal QAT fraction values for endpoints are displayed. The plot demonstrates that the optimal QAT fraction increases with the full training tokens-per-parameter-byte statistic. On the right, loss scaling law predictions are shown for a 4-bit QAT 396M-parameter model across varying QAT and FP training lengths. Both experimental and theoretical optima are shown. The optimal QAT fraction predicted by the loss scaling law for each total token count closely matches the experimentally observed fraction.
  • Figure 2: On the top, QAT optima for the 396M model are plotted in token coordinates. Different optima for the same total token count and different QAT bit widths can be observed. On the bottom, QAT optima for the 396M model are plotted in tokens-per-parameter-byte coordinates. With byte adjustment, different bit widths lie closer to the proposed fit line.
  • Figure 3: Visualization of the fitted loss scaling law for a 759M model, 1-bit QAT, and different $D_\text{qat}, D_\text{fp}$. Orange lines represent constant $D_\text{total} = D_\text{qat} + D_\text{fp}$ levels, and stars represent loss minima for each such level. It is clearly seen that the loss structure yields an optimal QAT fraction for a specific $D_\text{total}$. The overall phenomenon is consistent with what was discussed in section \ref{sec:opt-fraction-fit}.
  • Figure 4: Comparison of a sub-optimal QAT setup with a fixed 10% QAT fraction and the optimal QAT setup for a 1B-parameter model. Wasted token count is the number of tokens effectively wasted by not using the optimal QAT fraction. That is, if the wasted token count is $n\%$, then the same loss can be achieved with $(100 - n)\%$ of the tokens and the optimal QAT fraction. While results vary for different bit widths, the general relationship is similar, revealing high potential savings.
  • Figure 5: Difference in perplexity between FP loss scaling law and QAT loss scaling law for two model sizes. For QAT, the loss corresponding to the optimal QAT fraction is used. Values below 0 correspond to QAT performing better than the FP model. It is clearly observed that the ability of QAT to match FP loss is greatly influenced by model size and token count. In particular, larger models are able to tolerate lower QAT precision for higher total token count budgets. Additional plot information is present in appendix \ref{app:qat-vs-fp}.
  • ...and 12 more figures
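
The token bookkeeping used in the captions of Figures 3 and 4 can be sketched as follows: a total budget is split into FP and QAT tokens by a QAT fraction, and a wasted token count of n% means the same loss is reachable with (100 - n)% of the tokens at the optimal fraction. The function names and all numbers below are hypothetical and are not results from the paper.

```python
def split_tokens(total_tokens: float, qat_fraction: float) -> tuple[float, float]:
    """Split D_total into (D_fp, D_qat), where D_total = D_qat + D_fp."""
    if not 0.0 <= qat_fraction <= 1.0:
        raise ValueError("qat_fraction must lie in [0, 1]")
    d_qat = qat_fraction * total_tokens
    d_fp = total_tokens - d_qat
    return d_fp, d_qat


def equivalent_tokens(total_tokens: float, wasted_pct: float) -> float:
    """Tokens needed at the optimal QAT fraction to reach the same loss.

    Follows the Figure 4 caption: n% wasted means (100 - n)% of the
    tokens suffice when the optimal QAT fraction is used.
    """
    if not 0.0 <= wasted_pct <= 100.0:
        raise ValueError("wasted_pct must lie in [0, 100]")
    return total_tokens * (100.0 - wasted_pct) / 100.0


# Hypothetical budget of 100B tokens with the fixed 10% QAT fraction.
d_fp, d_qat = split_tokens(100e9, 0.10)
print(f"FP tokens: {d_fp:.3g}, QAT tokens: {d_qat:.3g}")

# Hypothetical wasted token count of 25% (not a value from the paper).
print(f"Tokens needed at the optimal fraction: {equivalent_tokens(100e9, 25.0):.3g}")
```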