Scaling Laws for Precision

Tanishq Kumar, Zachary Ankner, Benjamin F. Spector, Blake Bordelon, Niklas Muennighoff, Mansheej Paul, Cengiz Pehlevan, Christopher Ré, Aditi Raghunathan

TL;DR

The paper introduces precision-aware scaling laws for language model pretraining and inference, integrating low-precision training and post-training quantization (PTQ) into a unified loss framework. By modeling an effective parameter count $N_\text{eff}$ and a PTQ degradation term $\delta_\text{PTQ}$, the authors show how precision interacts with data and parameter counts to shape final loss, finding that compute-optimal training precision is often around 7–8 bits and that overtraining can make PTQ harmful. They demonstrate multiplicative, partly independent effects for weight, activation, and KV-cache quantization, derive allocation strategies under various compute constraints, and unify training and inference degradations into a single functional form, fit on over 465 pretraining runs and validated on models up to 1.7B parameters trained on up to 26B tokens. The findings have practical implications for hardware-aware training: they suggest when to train larger models in lower precision and how post-training quantization costs scale with data, which can inform the design of compute budgets and quantization strategies.

Abstract

Low precision training and inference affect both the quality and cost of language models, but current scaling laws do not account for this. In this work, we devise "precision-aware" scaling laws for both training and inference. We propose that training in lower precision reduces the model's "effective parameter count," allowing us to predict the additional loss incurred from training in low precision and post-train quantization. For inference, we find that the degradation introduced by post-training quantization increases as models are trained on more data, eventually making additional pretraining data actively harmful. For training, our scaling laws allow us to predict the loss of a model with different parts in different precisions, and suggest that training larger models in lower precision may be compute optimal. We unify the scaling laws for post and pretraining quantization to arrive at a single functional form that predicts degradation from training and inference in varied precisions. We fit on over 465 pretraining runs and validate our predictions on model sizes up to 1.7B parameters trained on up to 26B tokens.
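The two quantities named above can be sketched numerically. The snippet below is a minimal illustration, not the authors' code: it assumes a Chinchilla-style loss in which weight precision $P_\text{w}$ shrinks the parameter count to $N_\text{eff} = N(1 - e^{-P_\text{w}/\gamma_\text{w}})$, and a PTQ degradation of the form $\delta_\text{PTQ} = C_T \, (D^{\gamma_D}/N^{\gamma_N}) \, e^{-P_\text{post}/\gamma_\text{post}}$. These functional forms are not stated in this section and are assumptions here, and every constant in `PLACEHOLDER` is hypothetical rather than a fitted value from the paper.

```python
import math

# Hypothetical constants for illustration only; NOT the paper's fitted values.
PLACEHOLDER = dict(A=400.0, alpha=0.34, B=400.0, beta=0.28, E=1.7, gamma_w=2.0)
PLACEHOLDER_PTQ = dict(C_T=0.05, gamma_D=0.5, gamma_N=0.5, gamma_post=0.6)


def effective_params(n_params, weight_bits, gamma_w):
    """Assumed form: N_eff = N * (1 - exp(-P_w / gamma_w)).

    Approaches N as precision grows and shrinks toward 0 at very low precision.
    """
    return n_params * (1.0 - math.exp(-weight_bits / gamma_w))


def pretrain_loss(n_params, n_tokens, weight_bits, *, A, alpha, B, beta, E, gamma_w):
    """Chinchilla-style loss with N replaced by the effective parameter count."""
    n_eff = effective_params(n_params, weight_bits, gamma_w)
    return A * n_eff ** (-alpha) + B * n_tokens ** (-beta) + E


def ptq_degradation(n_params, n_tokens, post_bits, *, C_T, gamma_D, gamma_N, gamma_post):
    """Assumed form: grows with tokens D, shrinks with size N and with post-training bits."""
    scale = n_tokens ** gamma_D / n_params ** gamma_N
    return C_T * scale * math.exp(-post_bits / gamma_post)


def post_quant_loss(n_params, n_tokens, post_bits, train_bits=16):
    """Loss after PTQ = loss at the end of training + PTQ degradation."""
    base = pretrain_loss(n_params, n_tokens, train_bits, **PLACEHOLDER)
    return base + ptq_degradation(n_params, n_tokens, post_bits, **PLACEHOLDER_PTQ)


if __name__ == "__main__":
    # Sweep the token budget for a fixed model size and print both terms.
    for tokens in (1e9, 1e10, 1e11, 1e12):
        delta = ptq_degradation(1e8, tokens, 4, **PLACEHOLDER_PTQ)
        print(f"D={tokens:.0e}  delta_PTQ={delta:.4f}  loss={post_quant_loss(1e8, tokens, 4):.4f}")
```

Under these assumed forms the degradation term grows with $D$ while the data term $B D^{-\beta}$ shrinks, which is the mechanism by which additional pretraining data can eventually raise post-quantization loss; where the crossover falls depends entirely on the fitted constants.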

Paper Structure

This paper contains 43 sections, 37 equations, 21 figures, 2 tables.

Figures (21)

  • Figure 1: Schematic of key findings. (Left) Training a fixed model size to various data budgets in BF16 and quantizing weights at the end. We find that degradation due to post-train quantization increases with tokens seen during pretraining, so that eventually additional pretraining data can be harmful. (Right) Our scaling suggests training larger models in lower precision can be compute-optimal according to the cost model in Section \ref{section:implications-pretraining}. Weights, activations, and attention are quantized, and all models are trained on the same data budget; details are in Appendix \ref{appdx:main-fig}.
  • Figure 2: Loss degradation from PTQ increases with data. The top row is loss after PTQ; the bottom row is loss degradation compared to the end of training, before PTQ. The top row is thus the gray line in each plot plus the corresponding value in the bottom row. Degradation grows with data; the bottom row is fitted with Equation \ref{eqn:degrade}. For $D/N$ sufficiently large (left), loss can increase with data. Even at lower $D/N$, where post-quant loss continues to decrease with data, the value of data is reduced compared to the baseline. $R^2 = 0.97$ over all fitted points (bottom row).
  • Figure 3: (Left) $N_\text{eff}/N$ from our final scaling law. Our fit of $N_\text{eff}(N, P_\text{w})$ in this section is the first step towards this (blue). Empirical (center) and predicted (right) IsoLoss contours illustrating the precision-parameter tradeoff. Y-axis is weight precision during quantized training. All runs plotted trained on $D = 13$B tokens. Predictions from a fitted version of Equation \ref{eqn:our-chinchilla}, darker lines correspond to lower loss.
  • Figure 4: Predicting final validation losses $L(N, D, P_\text{w})$ for various $N, D, P_\text{w}$ to test our proposed functional form. Points are experimental values, lines are predictions of a single parametric fit of the form in Equation \ref{eqn:our-chinchilla}. We train only two model sizes at 26B tokens due to compute constraints.
  • Figure 5: (Left) Predicted loss based on fitted values with Equation \ref{eqn: full}. (Center) Fitting $\gamma$ parameters jointly on sweeps with combinations of precisions vs (right) fitting them on "marginal" sweeps where only one model part is in low precision at a time. Outliers are those at extremely low precision whose training runs are sometimes unstable.
  • ...and 16 more figures