Improving the Straight-Through Estimator with Zeroth-Order Information
Ningfeng Yang, Tor M. Aamodt
TL;DR
This paper addresses the training of quantized neural networks by combining the computational efficiency of the Straight-Through Estimator (STE) with the theoretical grounding of zeroth-order (ZO) methods. It introduces FOGZO, a first-order-guided zeroth-order gradient estimator that mixes a biased STE signal with an unbiased ZO gradient, controlled by a decaying mixing parameter. The approach improves accuracy over STE on multiple architectures (e.g., DeiT, ResNet) and tasks (classification, language modeling with LLaMA) while reducing training-time cost relative to fully unbiased ZO methods. The work also shows compatibility with state-of-the-art QAT methods, offers principled guidance for hyperparameter selection via smoothing mappings, and describes practical strategies for reducing training-time overhead. Overall, FOGZO provides a scalable, more accurate alternative for quantized pre-training and fine-tuning by using the biased STE gradient to guide zeroth-order estimation.
Abstract
We study the problem of training neural networks with quantized parameters. Learning low-precision quantized parameters with gradients computed via the Straight-Through Estimator (STE) can be challenging. While the STE enables back-propagation, which is a first-order method, recent works have explored the use of zeroth-order (ZO) gradient descent for fine-tuning. We note that the STE provides high-quality but biased gradients, whereas ZO gradients are unbiased but can be expensive to compute. We thus propose First-Order-Guided Zeroth-Order Gradient Descent (FOGZO), which reduces STE bias while requiring less computation than ZO methods. Empirically, we show that FOGZO improves the tradeoff between quality and training time in Quantization-Aware Pre-Training. Specifically, versus STE at the same number of iterations, we show a 1-8\% accuracy improvement for DeiT Tiny/Small, a 1-2\% accuracy improvement for ResNet 18/50, and a 1-22 point perplexity improvement for LLaMA models with up to 0.3 billion parameters. For the same loss, FOGZO yields a 796$\times$ reduction in computation versus n-SPSA for a 2-layer MLP on MNIST. Code is available at https://github.com/1733116199/fogzo.
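
To make the idea of guiding a zeroth-order estimate with an STE gradient concrete, the following NumPy sketch may help. It is not the authors' implementation, and the estimator defined in the paper may differ. Everything here is an assumption made for illustration: the toy quantized linear model, the uniform quantizer and its `step`, the function names, the way the STE direction is mixed with a random orthogonal direction, the choice to let `beta` weight the STE component, and the linear decay schedule.

```python
import numpy as np

def quantize(w, step=0.25):
    """Uniform quantizer: snap each weight to the nearest multiple of `step`."""
    return np.round(w / step) * step

def loss(w, X, y, step=0.25):
    """Mean squared error of a linear model evaluated with quantized weights."""
    residual = X @ quantize(w, step) - y
    return float(np.mean(residual ** 2))

def ste_grad(w, X, y, step=0.25):
    """STE gradient: differentiate w.r.t. the quantized weights and treat the
    quantizer's Jacobian as the identity (cheap, but biased)."""
    residual = X @ quantize(w, step) - y
    return 2.0 * X.T @ residual / len(y)

def guided_direction(g_ste, beta, rng):
    """Unit-norm probe direction (hypothetical mixing rule).
    beta=1 probes along the STE gradient; beta=0 probes along a random
    direction orthogonal to it."""
    v = g_ste / (np.linalg.norm(g_ste) + 1e-12)
    xi = rng.standard_normal(v.shape)
    xi -= (xi @ v) * v                      # remove the component along v
    xi /= np.linalg.norm(xi) + 1e-12
    return np.sqrt(beta) * v + np.sqrt(1.0 - beta) * xi

def guided_zo_grad(w, X, y, beta, eps, rng, step=0.25):
    """Two-point finite-difference estimate along the guided direction.
    The loss is piecewise constant in w, so eps should be comparable to
    the quantization step for the difference to be informative."""
    u = guided_direction(ste_grad(w, X, y, step), beta, rng)
    slope = (loss(w + eps * u, X, y, step) - loss(w - eps * u, X, y, step)) / (2.0 * eps)
    return slope * u

def beta_schedule(t, total_steps, beta0=0.9):
    """Hypothetical linear decay of the mixing parameter over training."""
    return beta0 * max(0.0, 1.0 - t / total_steps)
```

A training loop under these assumptions would call `guided_zo_grad(w, X, y, beta_schedule(t, T), eps, rng)` at step `t` and apply an ordinary gradient-descent update. Each estimate costs two forward passes for the finite difference plus one STE gradient evaluation for the guide direction.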
