Efficient Backpropagation with Variance-Controlled Adaptive Sampling

Ziteng Wang, Jianfei Chen, Jun Zhu

TL;DR

This work tackles the high computational cost of backpropagation by introducing Variance-Controlled Adaptive Sampling (VCAS), which builds an unbiased approximate stochastic gradient through fine-grained data- and token-level sampling during backpropagation. By decomposing and actively controlling the additional variance introduced by sampling, VCAS preserves convergence and training dynamics while reducing BP FLOPs by up to 73.87% and total training FLOPs by up to 49.58%. It combines activation-gradient sampling with leverage-score-based weight-gradient sampling and learns adaptive sample ratios via a variance-budget framework, with separate controls for activation and weight variance. Across vision and language tasks, VCAS achieves final loss and accuracy comparable to exact training at substantially lower cost, and it is robust across hyperparameters and architectures (e.g., BERT, ViT, CNNs).
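To make the idea of an unbiased sampled gradient concrete, here is a minimal NumPy sketch of data-level sampling. It is an illustration under stated assumptions, not the VCAS implementation: the function name `sample_rows_unbiased`, the choice of keep probabilities proportional to per-datum gradient norms, and the toy gradient are all hypothetical. Each row is kept with probability `p_i` and divided by `p_i`, so the expectation of the sparsified gradient equals the exact one.

```python
import numpy as np

def sample_rows_unbiased(grad, ratio, rng):
    """Randomly drop rows of `grad` and rescale kept rows so that the
    result equals `grad` in expectation.

    grad  : (n, d) array, one activation-gradient row per datum
    ratio : target fraction of rows to keep, in (0, 1]
    Returns the sparsified gradient and the keep probabilities.
    """
    n = grad.shape[0]
    norms = np.linalg.norm(grad, axis=1)
    total = norms.sum()
    if total == 0.0:
        # Nothing to sample: the gradient is already all zeros.
        return grad.copy(), np.ones(n)
    # Assumption: keep probability proportional to the row norm, capped at 1.
    # The expected number of kept rows is then at most ratio * n.
    p = np.minimum(1.0, ratio * n * norms / total)
    keep = rng.random(n) < p
    # Dividing kept rows by p_i gives E[keep_i / p_i] = 1, hence unbiasedness.
    scale = np.zeros(n)
    scale[keep] = 1.0 / p[keep]
    return grad * scale[:, None], p

rng = np.random.default_rng(0)
# Toy gradient with one large row, a few medium rows, and several tiny rows.
row_scale = np.array([5, 1, 1, 1, 0.1, 0.1, 0.1, 0.1])
grad = rng.normal(size=(8, 4)) * row_scale[:, None]

sparse, p = sample_rows_unbiased(grad, ratio=0.5, rng=rng)
# Averaging many independent samples should approach the exact gradient.
mean_est = np.mean(
    [sample_rows_unbiased(grad, 0.5, rng)[0] for _ in range(20000)], axis=0
)
print("max abs deviation of averaged estimate:", np.abs(mean_est - grad).max())
```

Rows with small norms are dropped most often, yet the rescaling keeps the estimator unbiased; the price is extra variance, which is the quantity VCAS is described as controlling.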

Abstract

Sampling-based algorithms, which eliminate "unimportant" computations during forward and/or back propagation (BP), offer potential solutions to accelerate neural network training. However, since sampling introduces approximations to training, such algorithms may not consistently maintain accuracy across various tasks. In this work, we introduce a variance-controlled adaptive sampling (VCAS) method designed to accelerate BP. VCAS computes an unbiased stochastic gradient with fine-grained layerwise importance sampling in the data dimension for activation gradient calculation and leverage score sampling in the token dimension for weight gradient calculation. To preserve accuracy, we control the additional variance by learning the sample ratio jointly with model parameters during training. We assessed VCAS on multiple fine-tuning and pre-training tasks in both vision and natural language domains. On all the tasks, VCAS preserves the original training loss trajectory and validation accuracy with up to a 73.87% reduction in BP FLOPs and a 49.58% reduction in the FLOPs of the whole training process. The implementation is available at https://github.com/thu-ml/VCAS.
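The token-dimension sampling for the weight gradient can be sketched in the same spirit. The exact weight gradient of a linear layer is a sum of per-token outer products, so a subset of tokens can be kept and reweighted. The sketch below assumes sampling scores equal to the product of the output-gradient norm and the input norm of each token, a common choice for sampled matrix multiplication; whether this matches the paper's leverage score definition exactly is an assumption, and `sampled_weight_grad` is a hypothetical name.

```python
import numpy as np

def sampled_weight_grad(grad_out, inputs, ratio, rng):
    """Unbiased estimate of the weight gradient grad_out.T @ inputs.

    grad_out : (T, d_out) gradient w.r.t. the layer output, one row per token
    inputs   : (T, d_in) layer inputs, one row per token
    ratio    : target fraction of tokens to keep, in (0, 1]
    """
    T = grad_out.shape[0]
    # Assumed score: ||g_i|| * ||z_i|| for token i.
    scores = np.linalg.norm(grad_out, axis=1) * np.linalg.norm(inputs, axis=1)
    total = scores.sum()
    if total == 0.0:
        return np.zeros((grad_out.shape[1], inputs.shape[1]))
    # Keep probability proportional to the score, capped at 1.
    q = np.minimum(1.0, ratio * T * scores / total)
    keep = rng.random(T) < q
    # Reweight kept tokens by 1 / q_i so the sum stays unbiased.
    w = 1.0 / q[keep]
    return (grad_out[keep] * w[:, None]).T @ inputs[keep]

rng = np.random.default_rng(0)
g = rng.normal(size=(16, 3))   # toy output gradients for 16 tokens
z = rng.normal(size=(16, 5))   # toy layer inputs for the same tokens
exact = g.T @ z

est = sampled_weight_grad(g, z, ratio=0.5, rng=rng)
mean_est = np.mean(
    [sampled_weight_grad(g, z, 0.5, rng) for _ in range(20000)], axis=0
)
print("max abs deviation of averaged estimate:", np.abs(mean_est - exact).max())
```

Only the kept tokens enter the matrix product, which is where the FLOPs saving comes from; a single estimate is noisy, while the average over many draws approaches the exact gradient.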

Paper Structure (29 sections, 19 equations, 11 figures, 9 tables, 1 algorithm)

Figures (11)

  • Figure 1: VCAS mirrors the convergence trajectory of exact training with a FLOPs reduction of 41.56%. Other methods such as SB (Jiang et al., 2019) and UB (Katharopoulos & Fleuret, 2018) fail at a similar FLOPs reduction.
  • Figure 2: The computation diagram of backpropagation with VCAS in every layer. Light blue squares represent small gradient entries and orange squares large ones. White squares are discarded by sampling. The upper row computes the activation gradient and the lower row the weight gradient. Please refer to the sampling section of the paper for the notation.
  • Figure 3: Gradient distribution over different layers and iterations of BERT-base finetuning on SST2 (6315 iterations in total). The normalized gradient norm of each datum is shown in the heatmaps. Black solid lines mark the 95th percentile. Data above the lines are likely to be discarded by VCAS.
  • Figure 4: FLOPs reduction ratio of VCAS vs. sampling activation or weight solely with equal variance.
  • Figure 5: Gradient variance of different methods.
  • ...and 6 more figures