Linear Gradient Prediction with Control Variates

Kamil Ciosek, Nicolò Felicioni, Juan Elenter Litwin

TL;DR

Training neural networks is computationally intensive, largely because of backpropagation. The paper introduces Predicted Gradient Descent, which replaces full backward passes for part of each batch with cheap approximate (predicted) gradients and adds a control-variate correction so that the update remains an unbiased estimate of the true gradient. The gradient predictor is derived from Neural Tangent Kernel insights and yields fast, low-rank gradient approximations. The theoretical analysis establishes unbiasedness and characterizes the variance, including break-even conditions that relate gradient alignment to compute savings. Experiments with a Vision Transformer trained on CIFAR-10 show improved wall-clock efficiency over full-gradient training under the same time budget. The approach trades extra memory for less compute per iteration, suggests practical gains for large models under tight time constraints, and provides a principled framework for deciding when to recompute the predictor and for tracking gradient alignment.
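The alignment tracking and predictor recomputation mentioned above could take many forms; the sketch below shows one hypothetical form, not the paper's procedure. It assumes that on the micro-batch that receives full backward passes both the exact and the predicted gradient are available (as in a control-variate estimator), so their cosine similarity can be monitored without an extra backward pass. The class `AlignmentTracker`, its exponential-moving-average smoothing, and the `threshold` and `decay` parameters are all illustrative assumptions.

```python
import math

def cosine(u, v):
    # Cosine similarity between two flattened gradient vectors.
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu > 0 and nv > 0 else 0.0

class AlignmentTracker:
    """Hypothetical recompute policy: smooth the exact-vs-predicted cosine with an
    exponential moving average and flag recomputation when it falls below a threshold."""

    def __init__(self, threshold=0.9, decay=0.9):
        self.threshold = threshold
        self.decay = decay
        self.ema = None

    def update(self, exact_grad, predicted_grad):
        c = cosine(exact_grad, predicted_grad)
        self.ema = c if self.ema is None else self.decay * self.ema + (1.0 - self.decay) * c
        return self.ema

    def should_recompute(self):
        return self.ema is not None and self.ema < self.threshold

    def reset(self):
        # Call after refitting the predictor so stale alignment history is discarded.
        self.ema = None

# Toy usage: the predictor drifts away from the exact gradient over three steps.
tracker = AlignmentTracker(threshold=0.9, decay=0.5)
exact = [1.0, 0.0]
triggered = []
for step, predicted in enumerate([[1.0, 0.05], [1.0, 0.5], [0.5, 1.0]]):
    ema = tracker.update(exact, predicted)
    if tracker.should_recompute():
        triggered.append(step)
        print(f"step {step}: ema={ema:.3f} below threshold, recompute the predictor")
        tracker.reset()
```

The moving average keeps a single noisy micro-batch from triggering an expensive refit; the threshold trades predictor staleness against recomputation cost.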

Abstract

We propose a new way of training neural networks, with the goal of reducing training cost. Our method uses approximate predicted gradients instead of the full gradients that require an expensive backward pass. We derive a control-variate-based technique that ensures our updates are unbiased estimates of the true gradient. Moreover, we propose a novel way to derive a predictor for the gradient inspired by the theory of the Neural Tangent Kernel. We empirically show the efficacy of the technique on a vision transformer classification task.
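A minimal numerical sketch of the control-variate idea follows. It assumes a standard control-variate form, which may differ in detail from the paper's estimator: exact per-sample gradients $g$ are computed only on a small micro-batch $B_1$, a cheap predictor $\hat g$ is evaluated on $B_1$ and on a larger micro-batch $B_2$, and the update is $G = \frac{1}{|B_2|}\sum_{j\in B_2}\hat g_j + \frac{1}{|B_1|}\sum_{i\in B_1}(g_i - \hat g_i)$. The toy least-squares model and the helpers `exact_grads`, `predicted_grads`, and `cv_gradient` are hypothetical; the predictor is deliberately biased (stale parameters, shrunk by 0.8) to show that the correction term removes its bias on average.

```python
import numpy as np

def exact_grads(theta, X, y):
    # Per-sample gradients of the squared loss 0.5 * (x @ theta - y)**2 (one row per sample).
    residual = X @ theta - y
    return residual[:, None] * X

def predicted_grads(theta_stale, X, y, scale=0.8):
    # Hypothetical cheap predictor, biased on purpose: gradients at stale
    # parameters, shrunk by `scale`.
    residual = X @ theta_stale - y
    return scale * residual[:, None] * X

def cv_gradient(theta, theta_stale, X1, y1, X2, y2):
    # B1 = (X1, y1): small micro-batch with exact gradients (full backward pass).
    # B2 = (X2, y2): larger micro-batch that only uses the predictor.
    correction = (exact_grads(theta, X1, y1) - predicted_grads(theta_stale, X1, y1)).mean(axis=0)
    prediction = predicted_grads(theta_stale, X2, y2).mean(axis=0)
    return prediction + correction

rng = np.random.default_rng(0)
d = 5
theta_true = rng.normal(size=d)
theta = np.zeros(d)
theta_stale = theta + 0.3  # the predictor sees out-of-date parameters

def sample(n):
    # x ~ N(0, I), y = x @ theta_true + small noise, so grad F(theta) = theta - theta_true.
    X = rng.normal(size=(n, d))
    y = X @ theta_true + 0.1 * rng.normal(size=n)
    return X, y

cv_estimates, pred_only = [], []
for _ in range(10_000):
    X1, y1 = sample(8)   # 1/4 of a 32-sample batch gets exact gradients
    X2, y2 = sample(24)  # 3/4 of the batch uses predicted gradients
    cv_estimates.append(cv_gradient(theta, theta_stale, X1, y1, X2, y2))
    pred_only.append(predicted_grads(theta_stale, X2, y2).mean(axis=0))

true_grad = theta - theta_true
mean_G = np.mean(cv_estimates, axis=0)
mean_pred_only = np.mean(pred_only, axis=0)
print("control-variate bias:", np.abs(mean_G - true_grad).max())
print("predictor-only bias: ", np.abs(mean_pred_only - true_grad).max())
```

In this toy setting the control-variate bias shrinks toward zero as the number of draws grows, while the predictor-only bias stays near a fixed nonzero value. The price of the correction is extra variance, which depends on how closely $\hat g$ tracks $g$.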

Paper Structure

This paper contains 38 sections, 8 theorems, 53 equations, 2 figures, 2 algorithms.

Key Result

Lemma 0: Gradient unbiasedness

Assuming the two micro-batches consist of i.i.d. draws and are independent of each other, the estimator $G$ is unbiased: $\mathbb{E}[G]=\mu=\nabla F(\theta)$.
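The lemma follows from linearity of expectation. The derivation below assumes a standard control-variate form for $G$ (an assumption; the paper's definition may differ in detail): exact per-sample gradients $g(\theta;z)$ are used on micro-batch $B_1$, the predictor $\hat g(\theta;z)$ is fixed given $\theta$ (not fitted on the current micro-batches), every sample $z$ is drawn from the data distribution, and $F(\theta)=\mathbb{E}_z[\ell(\theta;z)]$ so that $\mathbb{E}_z[g(\theta;z)]=\nabla F(\theta)$.

$$
G \;=\; \frac{1}{|B_2|}\sum_{j\in B_2}\hat g(\theta;z_j)\;+\;\frac{1}{|B_1|}\sum_{i\in B_1}\bigl(g(\theta;z_i)-\hat g(\theta;z_i)\bigr)
$$

$$
\mathbb{E}[G] \;=\; \mathbb{E}_z[\hat g(\theta;z)] + \mathbb{E}_z[g(\theta;z)] - \mathbb{E}_z[\hat g(\theta;z)] \;=\; \mathbb{E}_z[g(\theta;z)] \;=\; \nabla F(\theta) \;=\; \mu
$$

The two $\mathbb{E}_z[\hat g]$ terms cancel because both micro-batches come from the same distribution. Independence between $B_1$ and $B_2$ is not needed for this step; it matters for the variance.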

Figures (2)

  • Figure 1: CIFAR-10 validation accuracy vs. wall-clock training time. GPR denotes gradient prediction: predicted gradients replace full backward passes for 3/4 of each batch. The baseline uses full backward passes for the whole batch. Shaded areas show standard errors over three random seeds per method.
  • Figure: Predicted Gradient Descent
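The compute saving behind Figure 1 can be estimated with back-of-the-envelope arithmetic. The sketch below uses a hypothetical cost model, not numbers from the paper: a forward pass costs 1 unit per sample, a backward pass costs `backward_ratio` units (2 is a common rule of thumb), every sample still needs a forward pass, and the predictor costs `predictor_cost` units per sample on both micro-batches (the exact micro-batch needs it for the control-variate correction). Memory overhead and predictor recomputation are ignored, and `relative_step_cost` is an illustrative helper.

```python
def relative_step_cost(frac_predicted, backward_ratio=2.0, predictor_cost=0.1):
    """Per-sample cost of a predicted-gradient step relative to a full-backward step.

    Hypothetical cost model, in units of one forward pass per sample:
      exact sample:     forward + backward + predictor (needed for the correction)
      predicted sample: forward + predictor
    """
    full = 1.0 + backward_ratio
    exact_part = (1.0 - frac_predicted) * (1.0 + backward_ratio + predictor_cost)
    predicted_part = frac_predicted * (1.0 + predictor_cost)
    return (exact_part + predicted_part) / full

# Predict 3/4 of each batch, as in the GPR runs, for a few predictor costs.
for c in (0.0, 0.1, 0.5):
    r = relative_step_cost(0.75, predictor_cost=c)
    print(f"predictor_cost={c}: step cost {r:.3f} of baseline, ~{1 / r:.2f}x steps per unit time")
```

Under these assumptions, predicting 3/4 of each batch at most halves per-step compute, and the saving shrinks as the predictor gets more expensive; whether fewer FLOPs per step turn into faster training also depends on the extra gradient variance.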

Theorems & Definitions (12)

  • Lemma 0: Gradient unbiasedness
  • Proposition 0: Exact variance; dependence on cosine
  • Theorem 1: Break-even alignment
  • Theorem 2: Break-even regime switch and $f^\star$
  • Lemma 2: Gradient unbiasedness
  • proof
  • Proposition 2: Exact variance; dependence on cosine
  • proof
  • Theorem 2: Break-even alignment
  • proof
  • ...and 2 more
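The variance and break-even results listed above can be illustrated with a simplified toy model; this is not the paper's Proposition or Theorem, and the numbers are illustrative only. Assumptions: per-sample gradient noise is standard Gaussian in each coordinate, the predictor is unbiased with the same noise variance as the exact gradient, $\rho$ is the correlation between the two noises (a stand-in for the cosine alignment used in the paper), and $G = \frac{1}{n_2}\sum_{B_2}\hat g + \frac{1}{n_1}\sum_{B_1}(g-\hat g)$ with independent micro-batches. Then $\operatorname{Var}(G) = \sigma^2/n_2 + 2(1-\rho)\sigma^2/n_1$ per coordinate, and $G$ beats a plain minibatch of $n_b$ exact gradients only when $\rho$ exceeds a break-even value. The helpers `cv_variance`, `break_even_rho`, and `simulate_variance` are hypothetical.

```python
import numpy as np

def cv_variance(rho, n1, n2, sigma2=1.0):
    # Per-coordinate Var(G) = Var(g_hat)/n2 + Var(g - g_hat)/n1,
    # where Var(g - g_hat) = sigma2 * (1 + 1 - 2 * rho) = 2 * (1 - rho) * sigma2.
    return sigma2 / n2 + 2.0 * (1.0 - rho) * sigma2 / n1

def break_even_rho(n1, n2, n_baseline):
    # Solve cv_variance(rho) == 1 / n_baseline for rho (sigma2 cancels).
    # A result above 1 means G can never match the baseline's variance.
    return 1.0 - (1.0 / n_baseline - 1.0 / n2) * n1 / 2.0

def simulate_variance(rho, n1, n2, d=200, trials=2000, seed=0):
    # Monte Carlo check of cv_variance; the mean gradient is set to zero
    # because it does not affect the variance.
    rng = np.random.default_rng(seed)
    scale = np.sqrt(1.0 - rho**2)
    est = np.empty((trials, d))
    for t in range(trials):
        xi1, eta1 = rng.normal(size=(2, n1, d))  # exact noise and extra predictor noise on B1
        xi2, eta2 = rng.normal(size=(2, n2, d))  # same on B2 (exact g is never used there)
        g1 = xi1
        g_hat1 = rho * xi1 + scale * eta1
        g_hat2 = rho * xi2 + scale * eta2
        est[t] = g_hat2.mean(axis=0) + (g1 - g_hat1).mean(axis=0)
    return est.var(axis=0).mean()

n1, n2, n_baseline = 8, 24, 17  # n_baseline: hypothetical equal-compute baseline size
rho_star = break_even_rho(n1, n2, n_baseline)
print(f"break-even rho = {rho_star:.3f}")
for rho in (0.5, 0.9, 0.99):
    print(rho, cv_variance(rho, n1, n2), simulate_variance(rho, n1, n2), 1 / n_baseline)
```

In this toy model the break-even alignment for $n_1=8$, $n_2=24$, $n_b=17$ is $1-\tfrac{7}{408}\cdot 4\approx 0.931$: below it the control-variate estimator is noisier than the equal-compute baseline, above it less noisy. The choice $n_b=17$ is illustrative (roughly 32 samples scaled by a step cost of about 0.53 under one hypothetical cost model), not a figure from the paper.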