
Multi-Layer Transformers Gradient Can be Approximated in Almost Linear Time

Yingyu Liang, Zhizhou Sha, Zhenmei Shi, Zhao Song, Yufa Zhou

TL;DR

A novel fast approximation method can calculate the gradients in almost linear time $n^{1+o(1)}$, where $n$ is the input sequence length, while maintaining a polynomially small approximation error $1/\mathrm{poly}(n)$ across the entire model.

Abstract

The computational complexity of the self-attention mechanism in popular transformer architectures poses significant challenges for training and inference, and becomes a bottleneck for long inputs. Is it possible to significantly reduce the quadratic time complexity of computing the gradients in multi-layer transformer models? This paper proves that a novel fast approximation method can calculate the gradients in almost linear time $n^{1+o(1)}$, where $n$ is the input sequence length, while maintaining a polynomially small approximation error $1 / \mathrm{poly}(n)$ across the entire model. Our theory holds for general loss functions and for multi-layer transformer models that contain many practical sub-modules, such as residual connections, causal masks, and multi-head attention. By improving the efficiency of gradient computation, we hope that our theoretical results will facilitate more effective training and deployment of long-context language models.
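The abstract does not say how the quadratic cost is avoided. The sketch below is a generic illustration of the low-rank idea behind almost-linear attention. It is not the authors' algorithm, and it covers only the forward pass without a causal mask. The assumption is that $d$ is small and the entries of $Q$ and $K$ are bounded. Under that assumption, $\exp(\langle q, k\rangle)$ can be replaced by a truncated Taylor kernel $\langle \phi(q), \phi(k)\rangle = \sum_{j \le p} \langle q, k\rangle^j / j!$ whose feature dimension is $r = \sum_{j \le p} d^j$. The keys can then be aggregated once, so the output costs roughly $O(nrd)$ instead of $O(n^2 d)$. The function names, the degree $p = 4$, and the entry bound $0.3$ are all hypothetical choices made for illustration.

```python
import itertools
import math
import random


def taylor_features(x, degree):
    """Feature map with <phi(q), phi(k)> = sum_{j=0}^{degree} (q.k)^j / j!."""
    feats = []
    for j in range(degree + 1):
        scale = 1.0 / math.sqrt(math.factorial(j))
        # Every ordered degree-j monomial x[i1]*...*x[ij]: the entries of x's j-fold tensor power.
        for idx in itertools.product(range(len(x)), repeat=j):
            feats.append(scale * math.prod(x[i] for i in idx))
    return feats


def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def exact_attention(Q, K, V):
    """Row i = sum_j exp(q_i.k_j) v_j / sum_j exp(q_i.k_j); O(n^2 d) time."""
    out = []
    for q in Q:
        w = [math.exp(dot(q, k)) for k in K]
        total = sum(w)
        out.append([sum(wj * v[c] for wj, v in zip(w, V)) / total for c in range(len(V[0]))])
    return out


def approx_attention(Q, K, V, degree=4):
    """Same formula with exp(q.k) replaced by a truncated Taylor kernel; about O(n r d) time."""
    # Even-degree Taylor polynomials of exp are strictly positive, so denominators stay > 0.
    assert degree % 2 == 0
    dv = len(V[0])
    phi_k = [taylor_features(k, degree) for k in K]
    r = len(phi_k[0])
    # Aggregate all keys once: S = sum_j phi(k_j) v_j^T (r x dv) and z = sum_j phi(k_j).
    S = [[sum(p[a] * v[c] for p, v in zip(phi_k, V)) for c in range(dv)] for a in range(r)]
    z = [sum(p[a] for p in phi_k) for a in range(r)]
    out = []
    for q in Q:
        phi_q = taylor_features(q, degree)
        denom = dot(phi_q, z)
        out.append([sum(phi_q[a] * S[a][c] for a in range(r)) / denom for c in range(dv)])
    return out


# Usage: small bounded entries (an assumption) keep the Taylor truncation error tiny.
random.seed(0)
n, d = 8, 3


def rand_matrix(rows, cols, scale=0.3):
    return [[random.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]


Q, K, V = rand_matrix(n, d), rand_matrix(n, d), rand_matrix(n, d)
exact = exact_attention(Q, K, V)
approx = approx_attention(Q, K, V)
err = max(abs(a - b) for ra, rb in zip(exact, approx) for a, b in zip(ra, rb))
print(f"feature dim r = {len(taylor_features(Q[0], 4))}, max abs error = {err:.2e}")
```

The feature dimension grows like $d^p$, which is why a small hidden dimension matters for this kind of approximation. The trade-off between the degree $p$ and the accuracy depends on how large $|\langle q, k\rangle|$ can be.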

Paper Structure

This paper contains 84 sections, 64 theorems, 196 equations, and 2 algorithms.

Key Result

Theorem 1.4

Let $n$ be the number of tokens and $d$ the hidden dimension. Assume $d = O(\log n)$, that every matrix entry can be represented with $O(\log n)$ bits, and that the number of layers is $m = n^{o(1)}$. Then there exists an algorithm (Algorithm alg:multi_layer_grad_descent) that can compute the gradient in almost linear time $n^{1+o(1)}$, with a $1/\mathrm{poly}(n)$ approximation error across the entire model.
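The assumption $m = n^{o(1)}$ is what keeps the multi-layer cost almost linear. The following is a back-of-the-envelope check, not the paper's proof. Suppose each of the $m$ layers is handled in $n^{1+o(1)}$ time, matching the almost-linear bound stated in the abstract. Since $d = O(\log n) = n^{o(1)}$ as well, the accounting gives:

```latex
\begin{align*}
  \text{total time}
    &= m \cdot n^{1+o(1)}
     = n^{o(1)} \cdot n^{1+o(1)}
     = n^{1+o(1)}, \\
  \text{input size}
    &= n \cdot d
     = n \cdot O(\log n)
     = n^{1+o(1)}.
\end{align*}
```

By contrast, materializing an $n \times n$ attention matrix in every layer already costs at least $m \cdot n^2 \ge n^2$. With $m = \Theta(n^{\epsilon})$ for a constant $\epsilon > 0$, the same accounting would only give $n^{1+\epsilon+o(1)}$.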

Theorems & Definitions (138)

  • Definition 1.1: Softmax
  • Definition 1.2: Self-attention module
  • Definition 1.3: Multi-layer transformer
  • Theorem 1.4: Main result, informal version of Theorem \ref{['thm:main_result']}
  • Definition 3.1: Loss function $L(X)$
  • Remark 3.2
  • Definition 3.3: Intermediate variables $T_i$
  • Lemma 3.4: Closed form of gradient components, informal version of Lemma lem:grad_components_close_form
  • Theorem 4.1: Single-layer gradient approximation
  • proof
  • ...and 128 more
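The list above names the paper's core objects (Definitions 1.1–1.3) without reproducing them. The following is a minimal sketch of their common textbook forms, and the paper's exact parameterization may differ. It assumes a row-wise softmax, a single head $\mathrm{Softmax}(X W_Q (X W_K)^\top)\, X W_V$, an optional causal mask, and a residual connection around each layer. The usual $1/\sqrt{d}$ scaling is omitted because it can be folded into $W_Q$. All function names are hypothetical. The final lines check the causal-mask property: changing the last token must not change the first token's output.

```python
import math
import random


def matmul(A, B):
    """Dense matrix product of nested lists."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)] for row in A]


def softmax(z):
    """Numerically stable softmax; entries equal to -inf receive probability 0."""
    m = max(z)
    e = [math.exp(x - m) for x in z]
    total = sum(e)
    return [x / total for x in e]


def self_attention(X, W_Q, W_K, W_V, causal=True):
    """Single head: row-wise Softmax(X W_Q (X W_K)^T) X W_V, optionally causally masked."""
    Q, K, V = matmul(X, W_Q), matmul(X, W_K), matmul(X, W_V)
    K_T = [list(col) for col in zip(*K)]
    scores = matmul(Q, K_T)  # n x n score matrix: the quadratic cost in n
    if causal:
        # Token i may only attend to tokens j <= i.
        scores = [[s if j <= i else -math.inf for j, s in enumerate(row)]
                  for i, row in enumerate(scores)]
    P = [softmax(row) for row in scores]
    return matmul(P, V)


def transformer(X, layers, causal=True):
    """m stacked attention layers, each with a residual connection: X <- X + Attn(X)."""
    for W_Q, W_K, W_V in layers:
        A = self_attention(X, W_Q, W_K, W_V, causal)
        X = [[x + a for x, a in zip(rx, ra)] for rx, ra in zip(X, A)]
    return X


random.seed(1)
n, d, m = 5, 4, 3


def rand_matrix(rows, cols):
    return [[random.gauss(0.0, 0.5) for _ in range(cols)] for _ in range(rows)]


X = rand_matrix(n, d)
layers = [(rand_matrix(d, d), rand_matrix(d, d), rand_matrix(d, d)) for _ in range(m)]
Y = transformer(X, layers)

# Causality check: perturbing the last token leaves the first token's output unchanged.
X_perturbed = [row[:] for row in X]
X_perturbed[-1] = [v + 1.0 for v in X_perturbed[-1]]
Y_perturbed = transformer(X_perturbed, layers)
print(Y[0] == Y_perturbed[0])
```

The explicit `scores` matrix is the $n \times n$ object whose cost the paper's approximation is designed to avoid.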