Winner-Take-All Column Row Sampling for Memory Efficient Adaptation of Language Model

Zirui Liu, Guanchu Wang, Shaochen Zhong, Zhaozhuo Xu, Daochen Zha, Ruixiang Tang, Zhimeng Jiang, Kaixiong Zhou, Vipin Chaudhary, Shuai Xu, Xia Hu

TL;DR

This work proposes a new family of unbiased estimators called WTA-CRS, for matrix production with reduced variance, which only requires storing the sub-sampled activations for calculating the gradient in a stochastic manner.

Abstract

With the rapid growth in model size, fine-tuning large pre-trained language models has become increasingly difficult due to their extensive memory usage. Previous works usually focus on reducing the number of trainable parameters in the network. While the model parameters do contribute to memory usage, the primary memory bottleneck during training arises from storing feature maps, also known as activations, as they are crucial for gradient calculation. Notably, neural networks are usually trained using stochastic gradient descent. We argue that in stochastic optimization, models can handle noisy gradients as long as the gradient estimator is unbiased with reasonable variance. Following this motivation, we propose a new family of unbiased estimators called WTA-CRS, for matrix production with reduced variance, which only requires storing the sub-sampled activations for calculating the gradient. Our work provides both theoretical and experimental evidence that, in the context of tuning transformers, our proposed estimators exhibit lower variance compared to existing ones. By replacing the linear operation with our approximated one in transformers, we can achieve up to $2.7\times$ peak memory reduction with almost no accuracy drop and enable up to $6.4\times$ larger batch sizes. Under the same hardware, WTA-CRS enables better downstream task performance by applying larger models and/or faster training speed with larger batch sizes.

Paper Structure

This paper contains 34 sections, 4 theorems, 32 equations, 13 figures, 7 tables, 3 algorithms.

Key Result

Theorem 1

The estimator defined in Equation (eq: sum_and_sample) is an unbiased estimator for matrix production ${\bm{X}}{\bm{Y}}$, i.e., $\mathbb{E}_{j\sim \mathcal{P}^{\mathcal{D}\backslash\mathcal{C}}} [\sum_{c\in\mathcal{C}} f(c)p_c + (1 - \sum_{c\in\mathcal{C}}p_c)f(j)]={\bm{X}}{\bm{Y}}$.
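The unbiasedness claim in Theorem 1 can be checked numerically. The sketch below is an illustration, not the authors' implementation. It assumes definitions that this summary does not spell out: $f(i)$ is taken to be the importance-weighted rank-one term ${\bm{X}}_{:,i}{\bm{Y}}_{i,:}/p_i$, $p_i$ is taken proportional to $\|{\bm{X}}_{:,i}\|\,\|{\bm{Y}}_{i,:}\|$, and $\mathcal{C}$ is taken to be the indices with the largest $p_i$. All function names are hypothetical.

```python
import numpy as np

def split_pairs(X, Y, num_winners):
    """Return (p, winners, rest) over the column-row pairs of X @ Y."""
    # Assumption: p_i is proportional to ||X[:, i]|| * ||Y[i, :]||.
    scores = np.linalg.norm(X, axis=0) * np.linalg.norm(Y, axis=1)
    p = scores / scores.sum()
    order = np.argsort(-p)  # indices sorted by decreasing probability
    return p, order[:num_winners], order[num_winners:]

def f(X, Y, p, i):
    # Assumption: f(i) = X[:, i] Y[i, :] / p_i.
    return np.outer(X[:, i], Y[i, :]) / p[i]

def wta_crs_sample(X, Y, num_winners, rng):
    """Draw one sample of the estimator in Theorem 1."""
    p, winners, rest = split_pairs(X, Y, num_winners)
    mass = p[winners].sum()
    # Deterministic part: the "winner" pairs are always kept.
    deterministic = sum(f(X, Y, p, c) * p[c] for c in winners)
    # Stochastic part: one index j drawn from p renormalized over the rest.
    j = rng.choice(rest, p=p[rest] / p[rest].sum())
    return deterministic + (1.0 - mass) * f(X, Y, p, j)

def wta_crs_expectation(X, Y, num_winners):
    """Exact expectation of the estimator, summing over every possible j."""
    p, winners, rest = split_pairs(X, Y, num_winners)
    mass = p[winners].sum()
    deterministic = sum(f(X, Y, p, c) * p[c] for c in winners)
    q = p[rest] / (1.0 - mass)  # renormalized distribution over the rest
    stochastic = sum(q_j * (1.0 - mass) * f(X, Y, p, j) for q_j, j in zip(q, rest))
    return deterministic + stochastic
```

Under these assumptions, `wta_crs_expectation(X, Y, m)` equals `X @ Y` up to floating-point error for any number of winners `m` smaller than the number of column-row pairs: the renormalization factor $1/(1-\sum_{c\in\mathcal{C}}p_c)$ cancels the weight $(1-\sum_{c\in\mathcal{C}}p_c)$, leaving the plain sum over the remaining pairs.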

Figures (13)

  • Figure 1: Accuracy-memory trade-off of WTA-CRS and other memory-efficient tuning methods. Unless specially stated, we use the T5-Large in the figure.
  • Figure 2: The GPU memory usage breakdown for fine-tuning T5 glue, where the batch size $B$ is 64 and sequence length $S$ is 128 or 256.
  • Figure 3: The probability mass $\sum_{c\in\mathcal{C}}p_c$ versus $\frac{|\mathcal{C}|}{k}$ in Equation (eq: var_thresh) at $k=0.3|\mathcal{D}|$. Here we visualize the column-row index distribution of the query/key/value projection layers in the T5-base model, which is fine-tuned on the RTE dataset. More similar results can be found in Appendix (app: more_res_theo2).
  • Figure 4: The diagram of a single Transformer block. The shape of activations is annotated, where $B, S, D_{\text{model}}$, $N_{\text{head}}$, and $D_{\text{head}}$ are the batch size, sequence length, hidden size, number of attention heads, and head dimension, respectively. WTA-CRS can be applied to the operators in green; the activation maps of operators in blue can be losslessly compressed; and those in gray are not compressed in this paper. The idea of this figure is inspired by andoorveedu2022tempo.
  • Figure 5: The illustration of how to deploy WTA-CRS to linear layers. We only replace GEMM in Equation (eq: bwd_w) with its approximated version using WTA-CRS. The pseudocode is given in Appendix (app: implementation), Algorithm (algo: approx_q_linear).
  • ...and 8 more figures
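The Figure 5 caption says that only the weight-gradient GEMM of a linear layer is approximated. The following NumPy sketch illustrates that idea for a layer $Z = HW$, whose exact weight gradient is $H^\top \nabla Z$. It is not the paper's Algorithm: the function names, the budget parameters `k` and `num_winners`, and the choice to base sampling probabilities on activation row norms alone are all assumptions made for this illustration. Any distribution with strictly positive probabilities keeps the estimate unbiased.

```python
import numpy as np

def linear_forward(H, W, k, num_winners, rng):
    """Compute Z = H @ W and cache only k sub-sampled rows of H."""
    Z = H @ W
    # Assumption: probabilities use activation row norms only, because the
    # output gradient is not available during the forward pass.
    norms = np.linalg.norm(H, axis=1)
    p = norms / norms.sum()
    order = np.argsort(-p)
    winners, rest = order[:num_winners], order[num_winners:]
    mass = p[winners].sum()
    n_rand = k - num_winners
    # Sample the remaining budget i.i.d. from p renormalized over the rest.
    sampled = rng.choice(rest, size=n_rand, replace=True, p=p[rest] / p[rest].sum())
    idx = np.concatenate([winners, sampled])
    # Winners keep weight 1; each sampled row j gets (1 - mass) / (n_rand * p_j).
    scale = np.concatenate([np.ones(num_winners), (1.0 - mass) / (n_rand * p[sampled])])
    cache = (H[idx] * scale[:, None], idx, W)  # k rows stored instead of all rows
    return Z, cache

def linear_backward(dZ, cache):
    """Exact input gradient, approximated weight gradient."""
    H_sub, idx, W = cache
    dH = dZ @ W.T           # exact: does not need the stored activations
    dW = H_sub.T @ dZ[idx]  # estimate of H.T @ dZ from the k cached rows
    return dH, dW
```

In this sketch the forward output and the input gradient are exact, and the cache holds `k` rows of `H` rather than all of them, which is where the activation-memory saving would come from. `H` is assumed to be already flattened to shape `(B * S, D_model)`.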

Theorems & Definitions (7)

  • Theorem 1: Proof in Appendix (app: theory_unbias)
  • Theorem 2: Proof in Appendix (app: theory_var)
  • proof
  • Theorem 1: Proof in Appendix (app: theory_unbias)
  • proof
  • Theorem 2: Proof in Appendix (app: theory_var)
  • proof