Learning without training: The implicit dynamics of in-context learning

Benoit Dherin, Michael Munn, Hanna Mazzawi, Michael Wunder, Javier Gonzalvo

TL;DR

The paper analyzes how large language models can learn from in-context information without updating weights by introducing the contextual blocks framework, where a contextual layer paired with an MLP induces an exact rank-1 update to the MLP weights. It provides a formal theorem stating that context can be translated into a low-rank, weight-modifying operation, and shows an implicit gradient-descent-like dynamic as the prompt is consumed. Empirically, the authors validate the theory on a linear-function ICL task, demonstrating that the implicit weight updates reproduce the same outputs as explicit fine-tuning and converge as context accumulates. They further connect the mechanism to model-editing concepts and discuss implications for prompt engineering, context compression, and architecture design.

Abstract

One of the most striking features of Large Language Models (LLMs) is their ability to learn in-context. Namely, at inference time an LLM is able to learn new patterns without any additional weight update when these patterns are presented in the form of examples in the prompt, even if these patterns were not seen during training. The mechanisms through which this can happen are still largely unknown. In this work, we show that the stacking of a self-attention layer with an MLP allows the transformer block to implicitly modify the weights of the MLP layer according to the context. We argue through theory and experimentation that this simple mechanism may be the reason why LLMs can learn in-context and not only during training. Specifically, we show how a transformer block implicitly transforms a context into a low-rank weight update of its MLP layer.

Paper Structure

This paper contains 21 sections, 6 theorems, 63 equations, 6 figures.

Key Result

Theorem 2.2

Consider a contextual block $T_W=M_W\circ A$ as above, formed by a contextual layer $A$ composed with a neural network $M_W$ whose first fully-connected layer has weight matrix $W$. Given a context $C$ and an input $x \in C\backslash Y$, the effect of some portion $Y\subset C$ of the context on the output [...] where $\delta \! A_x(Y) := A(C,x) - A(C\backslash Y, x)$ is the context vector associated to $Y$. [...]
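
The statement above is cut off before its formula, so the following is only a minimal numerical sketch of the idea, not the paper's own code. It assumes that the first layer of $M_W$ is affine ($a \mapsto Wa + b$) and that $A(C\backslash Y, x) \neq 0$. Under those assumptions, requiring $(W + \Delta W)\,A(C\backslash Y, x) = W\,A(C, x)$ is satisfied by the rank-1 matrix $\Delta W = (W\,\delta A_x(Y))\,A(C\backslash Y, x)^\top / \|A(C\backslash Y, x)\|^2$, which is the update checked here. The toy attention layer, the two-layer MLP, and all function names are hypothetical stand-ins chosen for illustration.

```python
import numpy as np

def attention_output(context, x):
    """Toy contextual layer A(context, x): single-head softmax attention
    with identity projections, read out at the query token x."""
    tokens = np.vstack([context, x[None, :]])       # context tokens plus the query
    scores = tokens @ x / np.sqrt(x.size)           # query-key similarities
    weights = np.exp(scores - scores.max())
    weights /= weights.sum()
    return weights @ tokens                         # attention-weighted average

def mlp(W, b, V, a):
    """Toy M_W: affine first layer (W, b), ReLU, linear readout V."""
    return V @ np.maximum(W @ a + b, 0.0)

def implicit_update(W, a_full, a_reduced):
    """Rank-1 Delta W with (W + Delta W) @ a_reduced == W @ a_full (assumed form)."""
    delta_a = a_full - a_reduced                    # context vector delta A_x(Y)
    return np.outer(W @ delta_a, a_reduced) / (a_reduced @ a_reduced)

rng = np.random.default_rng(0)
d, h, n = 4, 8, 5
W = rng.normal(size=(h, d))
b = rng.normal(size=h)
V = rng.normal(size=(3, h))
C = rng.normal(size=(n, d))                         # context tokens
x = rng.normal(size=d)                              # query

a_full = attention_output(C, x)                     # A(C, x)
a_reduced = attention_output(C[:2], x)              # A(C \ Y, x), Y = last three tokens
dW = implicit_update(W, a_full, a_reduced)

out_with_context = mlp(W, b, V, a_full)             # block output with the full context
out_with_update = mlp(W + dW, b, V, a_reduced)      # block output with Y moved into W
max_gap = float(np.max(np.abs(out_with_context - out_with_update)))
rank = int(np.linalg.matrix_rank(dW))
```

The identity holds because only the first-layer pre-activation has to match; everything after it (the nonlinearity and the later layers) then sees identical inputs.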

Figures (6)

  • Figure 1: When taking $Y = C$ to be the full context and a query $x$, the corollary to Theorem 2.2 provides an explicit formula which effectively captures how the effect of the context $C$ is encoded as a weight transfer to the first-layer MLP weight $W$ via $\Delta_x W(C)$.
  • Figure 2: Train and validation loss curves. Here, the "Validation loss (computed via $\Delta W$)" refers to the loss computed using $T_{W + \Delta W}$; i.e., the trained model prediction given only $x_{\text{query}}$ but with MLP weights modified by $\Delta W$ as defined in Eq. \ref{equation:Delta_W}. Left: Training loss and both validation loss curves. Middle: Close-up of validation loss computed both ways; i.e., using $T_W(C, x)$ vs. $T_{W + \Delta_x W}(x)$. Right: Once trained, we sample 100 test tasks and for each point $(x_1, x_2) \in \mathbb{R}^{d=2}$ average the difference between $T_W$ and $T_{W + \Delta_x W}$. The two outputs agree on a wide range of both tasks and input values up to an order of $10^{-7}$.
  • Figure 3: Convergence of $(\Delta W)_i$. As more of the context is processed, the relative change in the weights $W$ converges to zero. For context length $i>2$, the plot above represents the average difference $\|(\Delta W)_{i+1} - (\Delta W)_{i}\|_2$ and the standard error over 100 separate trials.
  • Figure 4: Direct finetuning vs implicit weight update. Left: Both finetuning and implicit weight updates minimize the loss in similar ways. Right: The two forms of weight updates remain highly aligned with respect to the normalized Frobenius inner product.
  • Figure 5: Train and validation loss curves for multi-layer transformer with LayerNorm. Here, the "Validation loss (computed via $\Delta W$)" refers to the loss computed using $T_{W + \Delta W}$; i.e., the trained model prediction given only $x_{\text{query}}$ but with MLP weights modified by $\Delta W$ as defined in Eq. \ref{equation:Delta_W}. Left: Training loss and both validation loss curves. Middle: Close-up of validation loss computed both ways; i.e., using $T_W(C, x)$ vs. $T_{W + \Delta_x W}(x)$. Right: Once trained, we sample 100 test tasks $(C, x)$ and for each we perform a forward pass computing both $T_{\mathbf{W}, \mathbf{b}}(C,x)$ and $T_{\mathbf{W + \Delta W}, \mathbf{b + \Delta b}}(x)$. We report the mean and standard error of the L2-norm of the difference of the block outputs for each block in the multi-layer transformer. The block outputs agree with a high degree of precision, up to order $10^{-6}$.
  • ...and 1 more figure
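
The two quantities named in the captions for Figures 3 and 4 can be written down in a few lines. This is a hedged sketch rather than the authors' evaluation code: the function names are hypothetical, the normalized Frobenius inner product is taken to be $\langle A, B\rangle_F / (\|A\|_F \|B\|_F)$, and the matrix norm $\|\cdot\|_2$ in Figure 3 is assumed here to be the spectral norm (the captions do not say which matrix norm is meant).

```python
import numpy as np

def frobenius_alignment(A, B):
    """Normalized Frobenius inner product <A, B>_F / (||A||_F ||B||_F), in [-1, 1]."""
    return float(np.sum(A * B) / (np.linalg.norm(A) * np.linalg.norm(B)))

def successive_differences(updates):
    """||dW_{i+1} - dW_i||_2 for consecutive updates (spectral norm, by assumption)."""
    return [float(np.linalg.norm(nxt - cur, 2))
            for cur, nxt in zip(updates, updates[1:])]
```

An alignment near 1 means two weight updates point in nearly the same direction regardless of their scale, and a sequence of successive differences shrinking toward zero is what convergence of $(\Delta W)_i$ looks like under this metric.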

Theorems & Definitions (16)

  • Definition 2.1
  • Theorem 2.2
  • proof
  • Remark 2.3
  • Remark 2.4
  • Remark 2.5
  • Corollary 2.5.1
  • Remark 2.6
  • Corollary 2.6.1
  • Proposition 3.1
  • ...and 6 more