Bypassing the Exponential Dependency: Looped Transformers Efficiently Learn In-context by Multi-step Gradient Descent

Bo Chen, Xiaoyu Li, Yingyu Liang, Zhenmei Shi, Zhao Song

TL;DR

This work studies how looped Transformers can realize in-context learning (ICL) without updating parameters, challenging prior claims that multi-step gradient descent requires exponentially many in-context examples. Focusing on linear looped Transformers for linear vector generation tasks, the authors prove that efficient multi-step gradient descent is achievable when the input data have a constant condition number $\kappa$, with the prediction error decaying as $|\alpha|\exp(-T/(2\kappa))$ over $T$ loops and requiring only $n = O(d)$ examples under Gaussian-like assumptions. They formalize a gradient-descent interpretation of the looped Transformer, derive exact gradient computations, and provide a tight convergence bound, complemented by preliminary experiments that validate the theory. The results imply a stronger intrinsic in-context learning capability in Transformers and offer practical guidance for designing more efficient inference mechanisms for large language models.

Abstract

In-context learning has been recognized as a key factor in the success of Large Language Models (LLMs). It refers to the model's ability to learn patterns on the fly from in-context examples provided in the prompt during inference. Previous studies have demonstrated that the Transformer architecture used in LLMs can implement a single-step gradient descent update by processing in-context examples in a single forward pass. Recent work has further shown that, during in-context learning, a looped Transformer can implement multi-step gradient descent updates in forward passes. However, their theoretical results require an exponential number of in-context examples, $n = \exp(\Omega(T))$, where $T$ is the number of loops or passes, to achieve a reasonably low error. In this paper, we study in-context learning with linear looped Transformers on linear vector generation tasks. We show that linear looped Transformers can implement multi-step gradient descent efficiently for in-context learning. Our results demonstrate that as long as the input data has a constant condition number, e.g., $n = O(d)$, the linear looped Transformers can achieve a small error by multi-step gradient descent during in-context learning. Furthermore, our preliminary experiments validate our theoretical analysis. Our findings reveal that the Transformer architecture possesses a stronger in-context learning capability than previously understood, offering new insights into the mechanisms behind LLMs and potentially guiding better design of efficient inference algorithms for LLMs.
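
The multi-step gradient descent that the abstract refers to can be illustrated with a minimal NumPy sketch. It is not the looped Transformer itself and not the authors' code: it runs plain gradient descent on a least-squares objective built from the in-context examples, under assumptions introduced here for illustration (Gaussian inputs, noiseless labels, zero initialization, step size $1/\lambda_{\max}(X^\top X)$). The function name `gd_in_context` and all variable names are hypothetical. Under these assumptions the iterate error contracts by at least a factor $1 - 1/\kappa$ per step, which is no larger than $\exp(-T/(2\kappa))$ after $T$ steps, the same form of rate quoted in the TL;DR.

```python
import numpy as np


def gd_in_context(X, y, num_loops):
    """Run num_loops gradient steps on 0.5 * ||X w - y||^2, starting from w = 0.

    Each step stands in for one loop; this is plain gradient descent,
    not a Transformer forward pass.
    """
    gram = X.T @ X
    # Step size 1 / lambda_max (an assumption of this sketch).
    step = 1.0 / np.linalg.eigvalsh(gram)[-1]
    w = np.zeros(X.shape[1])
    iterates = [w.copy()]
    for _ in range(num_loops):
        w = w - step * (gram @ w - X.T @ y)  # gradient of the squared loss
        iterates.append(w.copy())
    return iterates


rng = np.random.default_rng(0)
n, d, T = 64, 4, 30                    # n in-context examples, dimension d, T loops
X = rng.standard_normal((n, d))        # Gaussian inputs (assumption)
w_star = rng.standard_normal(d)        # hidden task vector (assumption)
y = X @ w_star                         # noiseless labels (assumption)

eigs = np.linalg.eigvalsh(X.T @ X)     # eigenvalues in ascending order
kappa = eigs[-1] / eigs[0]             # condition number of X^T X

iterates = gd_in_context(X, y, T)
errors = [np.linalg.norm(w - w_star) for w in iterates]
bounds = [np.exp(-t / (2 * kappa)) * np.linalg.norm(w_star) for t in range(T + 1)]

for t in (0, T // 2, T):
    print(f"loop {t:2d}: error={errors[t]:.3e}  exp(-t/(2*kappa)) bound={bounds[t]:.3e}")
```

The number of steps needed for a target accuracy depends on $\kappa$ rather than directly on $n$, which is why a constant condition number is the key quantity in the statement above.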

Paper Structure

This paper contains 21 sections, 13 theorems, 38 equations, 1 figure, 1 table.

Key Result

Theorem 1.1

Let $T$ be the number of loops, $n$ be the number of in-context examples, and $d$ be the feature dimension. Let $(X, y) \in \mathbb{R}^{n \times d} \times \mathbb{R}^{n}$ be the in-context examples. Let $\kappa$ be the condition number of $X^\top X$ (Definition def:condition). Then, given a query $\

Figures (1)

  • Figure 1: Convergence rate comparison for gradient descent in linear vector generation with a fixed dimension $d=4$, varying sample sizes $n \in \{ 16, 32, 64, 128 \}$, and their corresponding condition numbers $\kappa$. 'Emp' denotes the empirical error from the experiments, and 'Theory' denotes the theoretical bound in Theorem \ref{thm:main}. The $y$-axis is the logarithm of the normalized error and the $x$-axis is the number of loops $T$. Both empirical (solid lines) and theoretical (dashed lines) results are presented for each $n$. The plot shows that as the sample size $n$ increases, the condition number decreases and the convergence rate improves: larger $n$ gives a steeper slope and faster convergence to the optimal solution.
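
The relationship between $n$ and $\kappa$ described in the caption can be checked with a short sketch. It assumes standard Gaussian inputs and averages over random draws; the data distribution, the number of trials, and the helper name `mean_condition_number` are assumptions made here, not details taken from the paper's experiments.

```python
import numpy as np


def mean_condition_number(n, d, trials, rng):
    """Average condition number of X^T X over random Gaussian X of shape (n, d)."""
    kappas = []
    for _ in range(trials):
        X = rng.standard_normal((n, d))      # Gaussian inputs (assumption)
        eigs = np.linalg.eigvalsh(X.T @ X)   # eigenvalues in ascending order
        kappas.append(eigs[-1] / eigs[0])
    return float(np.mean(kappas))


rng = np.random.default_rng(0)
d = 4
sample_sizes = [16, 32, 64, 128]             # the values of n used in Figure 1
mean_kappas = {
    n: mean_condition_number(n, d, trials=200, rng=rng) for n in sample_sizes
}

for n, kappa in mean_kappas.items():
    # -1 / (2 * kappa) is the slope of log(exp(-T / (2 * kappa))) in T.
    print(f"n={n:4d}  mean kappa={kappa:.2f}  log-error slope={-1 / (2 * kappa):.3f}")
```

A smaller $\kappa$ makes the slope $-1/(2\kappa)$ of the log-error bound more negative, which matches the steeper lines the caption reports for larger $n$.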

Theorems & Definitions (33)

  • Theorem 1.1: Main result. Informal version of Theorem \ref{thm:main}
  • Definition 3.1: In-context prompt data
  • Definition 3.2: In-context prompt label
  • Definition 3.3: In-context task
  • Definition 3.4: Input data
  • Definition 3.5: Linear attention
  • Definition 3.6: Linear looped transformer
  • Remark 3.7
  • Definition 3.8: Linear regression
  • Definition 3.9
  • ...and 23 more