Linear Gradient Prediction with Control Variates
Kamil Ciosek, Nicolò Felicioni, Juan Elenter Litwin
TL;DR
Training neural networks is computationally expensive, primarily because of backpropagation. The paper introduces Predicted Gradient Descent, which uses cheap approximate gradients together with a control-variate correction so that updates remain unbiased estimates of the true gradient. The gradient predictor is derived from Neural Tangent Kernel insights and enables fast, low-rank gradient approximations. Theoretical analysis shows unbiasedness and variance control, with break-even conditions that relate gradient alignment to compute savings. Empirical results on a Vision Transformer trained on CIFAR-10 show improved wall-clock efficiency over full-gradient training under the same budget. The approach trades extra memory for reduced compute per iteration, shows practical gains for large models under tight time constraints, and provides a principled framework for predictor recomputation and alignment tracking.
Abstract
We propose a new way of training neural networks, with the goal of reducing training cost. Our method uses approximate predicted gradients instead of the full gradients that require an expensive backward pass. We derive a control-variate-based technique that ensures our updates are unbiased estimates of the true gradient. Moreover, we propose a novel way to derive a predictor for the gradient inspired by the theory of the Neural Tangent Kernel. We empirically show the efficacy of the technique on a vision transformer classification task.
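The debiasing idea can be illustrated with a generic control-variate estimator on a toy least-squares problem. This is a minimal sketch under stated assumptions, not the paper's exact estimator: the predicted gradients here come from hypothetical stale weights (`w_stale`) rather than the NTK-inspired predictor, and the names `per_example_grads` and `debiased_gradient` are illustrative. The estimate is the mean predicted gradient over the whole batch plus the mean prediction error measured on a small subset where the true gradient is computed. Averaging over every possible subset recovers the full gradient exactly, which is what unbiasedness means in this setting.

```python
import itertools
import numpy as np

def per_example_grads(w, X, y):
    """Per-example gradients of 0.5 * (x @ w - y)**2, shape (n, d)."""
    residual = X @ w - y
    return residual[:, None] * X

def debiased_gradient(true_grads_subset, predicted_grads, subset_idx):
    """Mean predicted gradient plus a correction measured on a small subset."""
    correction = true_grads_subset - predicted_grads[subset_idx]
    return predicted_grads.mean(axis=0) + correction.mean(axis=0)

rng = np.random.default_rng(0)
n, d, k = 6, 3, 2  # batch size, parameter count, control-subset size
X = rng.normal(size=(n, d))
y = rng.normal(size=n)
w = rng.normal(size=d)
# Assumption: stale weights stand in for a cheap, biased gradient predictor.
w_stale = w + 0.5 * rng.normal(size=d)

true_grads = per_example_grads(w, X, y)             # expensive in a real network
predicted_grads = per_example_grads(w_stale, X, y)  # cheap stand-in prediction

full_gradient = true_grads.mean(axis=0)
prediction_bias = predicted_grads.mean(axis=0) - full_gradient

# Enumerate every size-k control subset; the average of the estimates is the
# exact expectation under uniform sampling without replacement.
estimates = np.array([
    debiased_gradient(true_grads[list(s)], predicted_grads, list(s))
    for s in itertools.combinations(range(n), k)
])
mean_estimate = estimates.mean(axis=0)

print("bias of raw prediction:", prediction_bias)
print("bias of debiased estimate:", mean_estimate - full_gradient)
```

In this sketch only `k` of the `n` true gradients are needed per estimate. The variance of the estimate depends on how well the predicted gradients track the true ones, which is consistent with the role the summary assigns to gradient alignment.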
