CausalLM is not optimal for in-context learning

Nan Ding, Tomer Levinboim, Jialin Wu, Sebastian Goodman, Radu Soricut

TL;DR

The paper compares prefixLM and causalLM for in-context learning (ICL), both theoretically and empirically. Under a multi-layer linear-self-attention (LSA) construction, prefixLM converges to the least-squares optimum, while causalLM follows per-position online-gradient-descent dynamics whose stationary points may remain suboptimal. Both architectures converge at a linear rate; the difference lies in the quality of the stationary point they reach. The authors validate the theory with synthetic experiments on LSA-transformers and standard transformers, and with language and multimodal models (T5X, PaLM2, PaLI-X), consistently observing better ICL performance for prefixLM. The results give a principled explanation for the empirical advantage of prefix-style attention in ICL and suggest practical implications for pretraining and model design.

Abstract

Recent empirical evidence indicates that transformer-based in-context learning performs better when using a prefix language model (prefixLM), in which in-context samples can all attend to each other, than when using a causal language model (causalLM), whose auto-regressive attention prohibits in-context samples from attending to future samples. While this result is intuitive, it is not understood from a theoretical perspective. In this paper we take a theoretical approach and analyze the convergence behavior of prefixLM and causalLM under a certain parameter construction. Our analysis shows that both LM types converge to their stationary points at a linear rate, but that while prefixLM converges to the optimal solution of linear regression, the convergence dynamics of causalLM follow those of an online gradient descent algorithm, which is not guaranteed to be optimal even as the number of samples grows infinitely. We supplement our theoretical claims with empirical experiments over synthetic and real tasks, using various types of transformers. Our experiments verify that causalLM consistently underperforms prefixLM in all settings.
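The contrast the abstract draws between reaching the least-squares optimum and following online gradient descent can be illustrated with a small sketch. This is a deliberately simplified illustration, not the paper's LSA construction: it assumes scalar regression, uses full-batch gradient descent as a stand-in for the prefixLM behavior and single-pass online gradient descent as a stand-in for the causalLM behavior, and all names, step sizes, and data settings (`make_data`, `eta`, `noise`, and so on) are hypothetical.

```python
import random

def make_data(n, w_true=2.0, noise=0.5, seed=0):
    """Synthetic 1-D regression data: y = w_true * x + Gaussian noise (assumed setup)."""
    rng = random.Random(seed)
    xs = [rng.uniform(-1.0, 1.0) for _ in range(n)]
    ys = [w_true * x + rng.gauss(0.0, noise) for x in xs]
    return xs, ys

def mse(w, xs, ys):
    """Mean squared error of the scalar weight w on the given samples."""
    return sum((y - w * x) ** 2 for x, y in zip(xs, ys)) / len(xs)

def least_squares(xs, ys):
    """Closed-form minimizer of the MSE for a scalar weight."""
    return sum(x * y for x, y in zip(xs, ys)) / sum(x * x for x in xs)

def batch_gd(xs, ys, eta=0.5, steps=300):
    """Full-batch gradient descent: every step uses all samples."""
    w = 0.0
    n = len(xs)
    for _ in range(steps):
        # Average residual-weighted input, i.e. the descent direction.
        step = sum((y - w * x) * x for x, y in zip(xs, ys)) / n
        w += eta * step
    return w

def online_gd(xs, ys, eta=0.5):
    """Single-pass online gradient descent: one sample per step, in order."""
    w = 0.0
    for x, y in zip(xs, ys):
        w += eta * (y - w * x) * x
    return w

xs, ys = make_data(40)
w_ls = least_squares(xs, ys)
w_batch = batch_gd(xs, ys)
w_online = online_gd(xs, ys)

print("least squares:", w_ls, "MSE:", mse(w_ls, xs, ys))
print("batch GD     :", w_batch, "MSE:", mse(w_batch, xs, ys))
print("online GD    :", w_online, "MSE:", mse(w_online, xs, ys))
```

In this toy setting the batch iterate approaches the closed-form least-squares weight, while the online iterate ends wherever its last few noisy updates leave it, so its training MSE can only match or exceed the least-squares MSE. The paper's per-position, multi-layer causalLM dynamics are more involved than this single pass.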

Paper Structure

This paper contains 39 sections, 8 theorems, 63 equations, 6 figures, 9 tables.

Key Result

Proposition 1

For a multi-layer LSA satisfying the construction (eq:constructed_lsa) and with $\mathbf{w}^{(0)} = 0$, if its input $\mathbf{Z}$ is formatted as in (eq:icl_input), then its $l$-th layer output is $\mathbf{z}_j^{(l)} = (\ldots$

Figures (6)

  • Figure 1: The inputs/outputs of a multi-layer in-context learner. $\mathbf{x}_j$ and $\mathbf{x}_{query}$ are omitted since they are unchanged.
  • Figure 2: Left/Middle: the MSE on in-context examples and query examples of multi-layer LSA-based prefixLM/causalLM-ICLs with 40 in-context training examples. Right: the query MSE of causalLM-ICL's stationary points (per the proposition labeled prop:causal-icl) using up to 300 in-context examples.
  • Figure 3: The test query errors of the 24-layer SL-transformer-based prefixLM/causalLM-ICLs on linear regression (left), non-linear regression (middle), and multiclass classification (right).
  • Figure 4: The test error at the stationary point of the causalLM2-ICL with up to 300 in-context examples.
  • Figure 5: Illustration of the attention mask. Green arrows represent attention between in-context examples. The dashed arrows apply only to prefixLM. Red arrows represent attention from queries to in-context examples. The query examples should not attend to themselves because the inputs do not contain labels.
  • ...and 1 more figure
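
The attention pattern described in the Figure 5 caption can be sketched as a boolean mask. This is a minimal sketch under stated assumptions, not the authors' implementation: the function name `build_icl_mask` and its parameters are hypothetical, in-context examples are assumed to be allowed to attend to themselves, and queries are assumed to attend to no query position at all (the caption only states that they do not attend to themselves).

```python
def build_icl_mask(n_context, n_query, prefix):
    """Return mask[i][j] == True when position i may attend to position j.

    Positions 0..n_context-1 are in-context examples; the remaining
    n_query positions are queries. Layout and self-attention choices
    are assumptions made for illustration.
    """
    total = n_context + n_query
    mask = [[False] * total for _ in range(total)]
    for i in range(total):
        for j in range(n_context):  # only in-context examples are attended to
            if i >= n_context:
                mask[i][j] = True       # query -> every in-context example
            elif prefix:
                mask[i][j] = True       # prefixLM: in-context examples see each other
            else:
                mask[i][j] = j <= i     # causalLM: no attention to future examples
    return mask

prefix_mask = build_icl_mask(n_context=3, n_query=1, prefix=True)
causal_mask = build_icl_mask(n_context=3, n_query=1, prefix=False)

for name, m in (("prefixLM", prefix_mask), ("causalLM", causal_mask)):
    print(name)
    for row in m:
        print("  " + " ".join("1" if allowed else "0" for allowed in row))
```

The two masks differ only in the upper triangle of the in-context block, which corresponds to the dashed arrows that the caption says apply only to prefixLM; the query rows are identical in both.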

Theorems & Definitions (8)

  • Proposition 1
  • Proposition 2
  • Proposition 3
  • Proposition 4
  • Proposition 5
  • Proposition 6
  • Proposition 7
  • Proposition 8