High-Accuracy Low-Precision Training

Christopher De Sa, Megan Leszczynski, Jian Zhang, Alana Marzoev, Christopher R. Aberger, Kunle Olukotun, Christopher Ré

TL;DR

The paper tackles the problem of performing high-accuracy training with low-precision arithmetic. It introduces HALP, a method that combines stochastic variance-reduced gradient (SVRG) with a novel bit-centering technique to suppress quantization noise, achieving linear convergence to high accuracy at a fixed bit-width. The authors provide theoretical guarantees and demonstrate CPU speedups of up to 4x over full-precision SVRG. Experiments cover multi-class logistic regression as well as CNN (ResNet) and LSTM deep learning tasks, with the deep learning runs simulated in TensorQuant. The work highlights conditioning as a key factor and offers insights for hardware design to support efficient, high-accuracy training on low-precision units.
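The role of conditioning can be explored with synthetic problems like the one in Figure 5, which are built from an SVD with random orthogonal factors. The sketch below is a hypothetical pure-Python version for a square $d \times d$ design matrix. Two choices are assumptions, not details taken from the paper: the singular values are spaced linearly, and conditioning is measured by the least-squares Hessian $A^T A$, whose condition number is $\sigma_{\max}^2/\sigma_{\min}^2$.

```python
import math
import random


def random_orthogonal(d, rng):
    """Random d x d orthogonal matrix: Gram-Schmidt on Gaussian vectors."""
    cols = []
    while len(cols) < d:
        v = [rng.gauss(0.0, 1.0) for _ in range(d)]
        for q in cols:  # remove components along the already accepted columns
            dot = sum(a * b for a, b in zip(v, q))
            v = [a - dot * b for a, b in zip(v, q)]
        norm = math.sqrt(sum(a * a for a in v))
        if norm > 1e-8:  # reject a (very unlikely) nearly dependent draw
            cols.append([a / norm for a in v])
    # Row-major matrix whose columns are the orthonormal vectors.
    return [[cols[j][i] for j in range(d)] for i in range(d)]


def conditioned_design(d, kappa, seed=0):
    """Hypothetical helper: A = U diag(s) V^T with cond(A^T A) = kappa (requires d >= 2)."""
    rng = random.Random(seed)
    U = random_orthogonal(d, rng)
    V = random_orthogonal(d, rng)
    # Eigenvalues of A^T A are s_k^2, so spread s linearly over [1/sqrt(kappa), 1].
    s_min = 1.0 / math.sqrt(kappa)
    s = [1.0 - (1.0 - s_min) * k / (d - 1) for k in range(d)]
    A = [[sum(U[i][k] * s[k] * V[j][k] for k in range(d)) for j in range(d)]
         for i in range(d)]
    return A, s, V
```

For example, `conditioned_design(5, 100.0)` uses singular values from 1 down to 0.1, so the Hessian eigenvalues span a factor of 100. Raising `kappa` stretches that spread, which is the knob Figure 5 turns.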

Abstract

Low-precision computation is often used to lower the time and energy cost of machine learning, and recently hardware accelerators have been developed to support it. Still, it has been used primarily for inference rather than training. Previous low-precision training algorithms suffered from a fundamental tradeoff: as the number of bits of precision is lowered, quantization noise is added to the model, which limits statistical accuracy. To address this issue, we describe a simple low-precision stochastic gradient descent variant called HALP. HALP converges at the same theoretical rate as full-precision algorithms despite the noise introduced by using low precision throughout execution. The key idea is to use SVRG to reduce gradient variance, and to combine this with a novel technique called bit centering to reduce quantization error. We show that on the CPU, HALP can run up to $4 \times$ faster than full-precision SVRG and can match its convergence trajectory. We implemented HALP in TensorQuant, and show that it exceeds the validation performance of plain low-precision SGD on two deep learning tasks.
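The bits-versus-noise tradeoff described above can be made concrete with a toy fixed-point quantizer. This is a minimal sketch under stated assumptions: a signed `bits`-bit format whose values are $\delta k$ for $|k| \le 2^{b-1}-1$, combined with unbiased stochastic rounding. Stochastic rounding is a common choice in low-precision SGD, but it may not match the paper's exact number format, and the function names are hypothetical. For a grid covering a fixed range $[-R, R]$, $\delta = R/(2^{b-1}-1)$, and stochastic rounding has variance at most $\delta^2/4$. Removing bits therefore makes the grid coarser and the noise larger.

```python
import math
import random


def quantize(x, delta, bits, rng):
    """Unbiased stochastic rounding onto the grid {delta * k : |k| <= 2**(bits-1) - 1}.

    Values outside the representable range are clipped, and so are no longer unbiased.
    """
    kmax = 2 ** (bits - 1) - 1
    scaled = x / delta
    k = math.floor(scaled)
    if rng.random() < scaled - k:  # round up with probability equal to the fractional part
        k += 1
    k = max(-kmax, min(kmax, k))
    return k * delta


def quantization_noise(x, radius, bits, trials=20000, seed=0):
    """Empirical mean and variance of quantize(x) when the grid spans [-radius, radius]."""
    rng = random.Random(seed)
    delta = radius / (2 ** (bits - 1) - 1)
    q = [quantize(x, delta, bits, rng) for _ in range(trials)]
    mean = sum(q) / trials
    var = sum((v - mean) ** 2 for v in q) / trials
    return mean, var, delta
```

With `radius = 1.0`, going from 8 bits to 4 bits grows `delta` from 1/127 to 1/7. The rounding variance at a fixed `x` grows with it, while the mean stays close to `x` because the rounding is unbiased.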

Paper Structure

This paper contains 17 sections, 4 theorems, 63 equations, 7 figures, 4 tables, 4 algorithms.

Key Result

Theorem 1

Suppose that we run the LP-SVRG algorithm under the above conditions, using option II for the epoch update. For any constant $0 < \gamma < 1$ (a parameter that controls how often we take full gradients), if we set the step size $\alpha$ and epoch length $T$ to be …, then the outer iterates of LP-SVRG converge at a linear rate to an accuracy limit: …
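To see what "converge to an accuracy limit" means in practice, here is a minimal LP-SVRG sketch on a least-squares objective. It is not the authors' implementation, and it rests on several assumptions. Every inner iterate is stochastically rounded onto a fixed `bits`-bit grid with spacing `delta`, while the anchor $\tilde{w}$ and its full gradient $\tilde{g}$ are kept in full precision. "Option II" is taken to mean what it means in the original SVRG algorithm: the next anchor is a uniformly random inner iterate. The names `lp_svrg`, `grad_i`, `full_grad` and `quantize` are hypothetical.

```python
import math
import random


def quantize(x, delta, bits, rng):
    """Unbiased stochastic rounding onto {delta * k : |k| <= 2**(bits-1) - 1}, with clipping."""
    kmax = 2 ** (bits - 1) - 1
    scaled = x / delta
    k = math.floor(scaled)
    if rng.random() < scaled - k:
        k += 1
    return max(-kmax, min(kmax, k)) * delta


def grad_i(w, a, b):
    """Gradient of the single-example loss 0.5 * (a . w - b)**2."""
    r = sum(ai * wi for ai, wi in zip(a, w)) - b
    return [r * ai for ai in a]


def full_grad(w, A, B):
    """Full-precision average of grad_i over all examples."""
    g = [0.0] * len(w)
    for a, b in zip(A, B):
        for j, gj in enumerate(grad_i(w, a, b)):
            g[j] += gj / len(A)
    return g


def lp_svrg(A, B, alpha, T, epochs, delta, bits, seed=0):
    """Hypothetical LP-SVRG sketch: inner iterates live on a fixed low-precision grid."""
    rng = random.Random(seed)
    w_tilde = [0.0] * len(A[0])
    for _ in range(epochs):
        g_tilde = full_grad(w_tilde, A, B)  # full gradient at the anchor
        w = [quantize(x, delta, bits, rng) for x in w_tilde]
        iterates = []
        for _ in range(T):
            iterates.append(w)
            i = rng.randrange(len(A))
            gi, gi_tilde = grad_i(w, A[i], B[i]), grad_i(w_tilde, A[i], B[i])
            # Variance-reduced step, then round the new iterate back onto the grid.
            w = [quantize(wj - alpha * (g - gt + G), delta, bits, rng)
                 for wj, g, gt, G in zip(w, gi, gi_tilde, g_tilde)]
        # Option II: the next anchor is a uniformly random inner iterate w_0, ..., w_{T-1}.
        w_tilde = rng.choice(iterates)
    return w_tilde
```

Every returned coordinate is a multiple of `delta`. Unless the optimum happens to lie on the grid, the error therefore cannot fall below the distance from the optimum to the nearest grid point, however many epochs are run. This is consistent with a fixed accuracy limit for LP-SVRG.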

Figures (7)

  • Figure 1: Linear regression on a synthetic dataset with 100 features and 1000 examples generated by scikit-learn's make_regression generator. The epoch length was set to $T = 2000$, twice the number of examples, and the learning rates $\alpha$ and scale factors $\delta$ were chosen using grid search for all algorithms. For all versions of SGD, $\alpha = 2.5 \times 10^{-6}$, and for all versions of SVRG, $\alpha = 5 \times 10^{-3}$. All LP 8-bit algorithms use $\delta = 0.7$ and all LP 16-bit algorithms use $\delta = 0.003$. All HALP algorithms use $\alpha = 5 \times 10^{-3}$ and $\mu = 3$.
  • Figure 2: A diagram of the bit scaling operation in HALP. As the algorithm converges, we are able to bound the solution within a smaller and smaller ball. Periodically, we re-center the points that our low-precision model can represent so they are centered on this ball, and we re-scale the points so that more of them are inside the ball. This decreases quantization error as we converge.
  • Figure 3: Training loss and validation metrics for an LSTM on character-level language modeling with the TinyShakespeare dataset (validation perplexity) and for ResNet on image recognition with the CIFAR10 dataset (validation accuracy). Training loss is smoothed for visualization purposes. The CIFAR10 accuracy curve is monotonic because we report the best value up to each number of iterations, which is standard for reporting validation accuracy.
  • Figure 4: Convergence of 8-bit low-precision algorithms on multi-class logistic regression training.
  • Figure 5: Additional linear regression experiment on a synthetic dataset generated from an SVD $U S V^T$, where $U$ and $V$ are random orthogonal matrices and the singular values $S$ are chosen so that the condition number of the resulting problem is $\kappa$. The gradient norm is measured after 50 epochs.
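The re-centering and re-scaling shown in Figure 2 can be sketched as follows. Only the offset $z = w - \tilde{w}$ is stored in low precision, and the grid spacing is chosen from the full gradient at each anchor. This is a hypothetical sketch, not the authors' implementation. It assumes $f$ is $\mu$-strongly convex, so that $\|\tilde{w} - w^*\| \le \|\nabla f(\tilde{w})\|/\mu$, and spends the $2^{b-1}-1$ positive levels on a box of that radius around $\tilde{w}$. It also updates the anchor with the last inner iterate. Reading the HALP hyperparameter $\mu$ from Figure 1 as such a strong-convexity estimate is an assumption the captions do not confirm. The function names are illustrative.

```python
import math
import random


def quantize(x, delta, bits, rng):
    """Unbiased stochastic rounding onto {delta * k : |k| <= 2**(bits-1) - 1}, with clipping."""
    kmax = 2 ** (bits - 1) - 1
    scaled = x / delta
    k = math.floor(scaled)
    if rng.random() < scaled - k:
        k += 1
    return max(-kmax, min(kmax, k)) * delta


def grad_i(w, a, b):
    """Gradient of the single-example loss 0.5 * (a . w - b)**2."""
    r = sum(ai * wi for ai, wi in zip(a, w)) - b
    return [r * ai for ai in a]


def full_grad(w, A, B):
    """Full-precision average of grad_i over all examples."""
    g = [0.0] * len(w)
    for a, b in zip(A, B):
        for j, gj in enumerate(grad_i(w, a, b)):
            g[j] += gj / len(A)
    return g


def halp(A, B, alpha, T, epochs, mu, bits, seed=0):
    """Hypothetical HALP-style sketch with bit centering around a full-precision anchor."""
    rng = random.Random(seed)
    w_tilde = [0.0] * len(A[0])  # full-precision anchor
    for _ in range(epochs):
        g_tilde = full_grad(w_tilde, A, B)
        gnorm = math.sqrt(sum(g * g for g in g_tilde))
        if gnorm == 0.0:
            break  # the anchor is already a stationary point
        # Re-scale: the optimum lies within gnorm / mu of w_tilde (mu-strong convexity),
        # so use all 2**(bits-1) - 1 positive levels to cover that radius.
        delta = (gnorm / mu) / (2 ** (bits - 1) - 1)
        z = [0.0] * len(w_tilde)  # low-precision offset: w = w_tilde + z
        for _ in range(T):
            i = rng.randrange(len(A))
            w = [wt + zj for wt, zj in zip(w_tilde, z)]
            gi, gi_tilde = grad_i(w, A[i], B[i]), grad_i(w_tilde, A[i], B[i])
            z = [quantize(zj - alpha * (g - gt + G), delta, bits, rng)
                 for zj, g, gt, G in zip(z, gi, gi_tilde, g_tilde)]
        # Re-center: fold the offset into the full-precision anchor.
        w_tilde = [wt + zj for wt, zj in zip(w_tilde, z)]
    return w_tilde
```

The spacing $\delta$ shrinks along with $\|\nabla f(\tilde{w})\|$, so the grid becomes finer as the anchor approaches the optimum. The same 8-bit budget can then keep reducing the error instead of stalling at a fixed grid spacing. Choosing `mu` below the true strong-convexity constant only enlarges the box, which trades some resolution for safety.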

Theorems & Definitions (8)

  • Theorem 1
  • Theorem 2
  • Lemma 1
  • Proof of Lemma 1
  • Lemma 2
  • Proof of Lemma 2
  • Proof of Theorem 1
  • Proof of Theorem 2