The Fine-Grained Complexity of Gradient Computation for Training Large Language Models

Josh Alman, Zhao Song

TL;DR

This work establishes a precise fine-grained complexity threshold for gradient computation in training large language models by analyzing the Approximate Attention Loss Gradient Computation problem $\mathsf{AAttLGC}(n,d,\epsilon)$. It proves that the threshold already known for the forward attention step also governs the backward gradient step: near-linear time algorithms exist when the magnitude bound $B$ on the matrix entries is small ($B = o(\sqrt{\log n})$), and, assuming $\mathsf{SETH}$, no truly subquadratic algorithm exists when it is large ($B = \omega(\sqrt{\log n})$). The upper bound uses a tensor trick and low-rank decompositions to compute the gradient in $n^{1+o(1)}$ time, while the lower bound follows from a reduction from Attention Computation. Together, the results identify the fundamental limits of efficient LLM training and point to controlling the magnitude of matrix entries in attention modules as the route to fast algorithms.

Abstract

Large language models (LLMs) have made fundamental contributions over the last few years. To train an LLM, one needs to alternately run "forward" computations and "backward" computations. The forward computation can be viewed as attention function evaluation, and the backward computation can be viewed as a gradient computation. In previous work by [Alman and Song, NeurIPS 2023], it was proved that the forward step can be performed in almost-linear time in certain parameter regimes, but that there is no truly sub-quadratic time algorithm in the remaining parameter regimes unless the popular hypothesis SETH is false. In this work, we show nearly identical results for the harder-seeming problem of computing the gradient of the loss function of a one-layer attention network, and thus for the entire process of LLM training. This completely characterizes the fine-grained complexity of every step of LLM training.

Paper Structure

This paper contains 36 sections, 21 theorems, 75 equations.

Key Result

Theorem 1.5

Assuming $\mathsf{SETH}$, there is no algorithm running in time $O(n^{2-q})$, for any $q>0$, for the $\mathsf{AAttLGC}(n, d = O(\log n), B = \omega(\sqrt{\log n}))$ problem (see Definition 1.4).
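
As a point of reference for the theorem above, here is a minimal sketch of the straightforward quadratic-time baseline. It assumes the one-layer attention loss has the form $L(X) = \frac{1}{2} \| D(X)^{-1} \exp(A_1 X A_2^\top / d) A_3 Y - E \|_F^2$ with $D(X) = \mathrm{diag}(\exp(A_1 X A_2^\top / d) \mathbf{1}_n)$, where $A_1, A_2, A_3, E$ are $n \times d$ and $X, Y$ are $d \times d$. That form, the function names `attention_loss` and `finite_difference_gradient`, and the use of central finite differences are illustrative assumptions and not the paper's algorithm. The sketch does not enforce any bound $B$ on the entries.

```python
import math


def attention_loss(A1, A2, A3, E, X, Y):
    """Naive evaluation of the assumed loss
    0.5 * || D^{-1} exp(A1 X A2^T / d) A3 Y - E ||_F^2.

    Matrices are lists of rows. The loop over all (i, j) pairs below
    materializes each row of the n x n attention matrix, which is the
    source of the quadratic running time.
    """
    n, d = len(A1), len(A1[0])
    # A1 X and A3 Y are both n x d and cost O(n d^2) to form.
    A1X = [[sum(A1[i][k] * X[k][j] for k in range(d)) for j in range(d)]
           for i in range(n)]
    A3Y = [[sum(A3[i][k] * Y[k][j] for k in range(d)) for j in range(d)]
           for i in range(n)]
    loss = 0.0
    for i in range(n):
        # Row i of exp(A1 X A2^T / d), computed entry by entry.
        row = [math.exp(sum(A1X[i][k] * A2[j][k] for k in range(d)) / d)
               for j in range(n)]
        norm = sum(row)  # i-th diagonal entry of D(X)
        for c in range(d):
            out = sum(row[j] * A3Y[j][c] for j in range(n)) / norm
            loss += 0.5 * (out - E[i][c]) ** 2
    return loss


def finite_difference_gradient(A1, A2, A3, E, X, Y, h=1e-5):
    """Approximate dL/dX with central differences (an illustrative
    stand-in for an exact gradient; costs 2 d^2 loss evaluations)."""
    d = len(X)
    grad = [[0.0] * d for _ in range(d)]
    for a in range(d):
        for b in range(d):
            X_plus = [r[:] for r in X]
            X_minus = [r[:] for r in X]
            X_plus[a][b] += h
            X_minus[a][b] -= h
            grad[a][b] = (attention_loss(A1, A2, A3, E, X_plus, Y)
                          - attention_loss(A1, A2, A3, E, X_minus, Y)) / (2 * h)
    return grad
```

Each loss evaluation touches all $n^2$ pairs of rows of $A_1 X$ and $A_2$. Under $\mathsf{SETH}$, the theorem rules out improving this to $O(n^{2-q})$ when $B = \omega(\sqrt{\log n})$, even for approximate gradients.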

Theorems & Definitions (52)

  • Definition 1.1: $\ell$-th layer forward computation
  • Definition 1.2: Attention optimization
  • Remark 1.3
  • Definition 1.4: Approximate Attention Loss Gradient Computation ($\mathsf{AAttLGC}(n,d,\epsilon)$)
  • Theorem 1.5: Main result, Lower bound, informal version of Theorem \ref{thm:mainlb:formal}
  • Theorem 1.6: Main result, Upper bound, informal version of Theorem \ref{thm:mainalg:formal}
  • Definition 3.1
  • Definition 3.4
  • Definition 3.5
  • Definition 3.6
  • ...and 42 more