Breaking through the learning plateaus of in-context learning in Transformer

Jingwen Fu, Tao Yang, Yuwang Wang, Yan Lu, Nanning Zheng

TL;DR

This work investigates why Transformers exhibit learning plateaus in in-context learning and introduces a conceptual split of internal representations into a weights component and a context component. Using a controllable Shapes3D-based synthetic task, the authors show that plateau duration correlates with dysfunction in the weights component, and they demonstrate three methods—weights warm-up, mixed training, and an extra weights loss—to accelerate learning without increasing model size. The proposed framework is validated through probes measuring component quality and extended to NLP tasks, where weights-focused interventions continue to improve in-context learning. The findings suggest an eco-friendly path to endowing AI systems with robust in-context learning by directly enhancing the weights component rather than scaling up models.

Abstract

In-context learning, i.e., learning from context examples, is an impressive ability of Transformers. Training Transformers to acquire this in-context learning skill is computationally intensive because of learning plateaus: periods during training in which the model's in-context learning capability improves little or not at all. To study the mechanism behind these plateaus, we conceptually separate out a component of the model's internal representation that is affected exclusively by the model's weights. We call this the "weights component", and the remainder is identified as the "context component". Through meticulous, controlled experiments on synthetic tasks, we find that the persistence of learning plateaus correlates with compromised functionality of the weights component. Recognizing the impaired weights component as a fundamental driver of learning plateaus, we develop three strategies to expedite the training of Transformers. The effectiveness of these strategies is further confirmed on natural language processing tasks. In conclusion, our research demonstrates the feasibility of cultivating a powerful in-context learning ability within AI systems in an eco-friendly manner.
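The abstract defines the weights component only conceptually. A common way to check how much task-relevant information a representation carries is a linear probe, and the sketch below shows a generic ridge-regression probe in NumPy. It assumes you have already extracted representations (for instance, hidden states at the query position) together with ground-truth factor values; the function name, the `ridge` parameter and the toy data are hypothetical, and this is not the paper's own definition of the weights component score.

```python
import numpy as np


def linear_probe_accuracy(train_reps, train_labels, test_reps, test_labels,
                          num_classes, ridge=1e-3):
    """Hypothetical probe: ridge regression onto one-hot labels, scored by accuracy.

    train_reps / test_reps: arrays of shape (n, d) with extracted representations.
    train_labels / test_labels: integer arrays of shape (n,) with factor values.
    """
    def add_bias(x):
        # Append a constant column so the probe can learn an intercept.
        return np.hstack([x, np.ones((x.shape[0], 1))])

    x_train = add_bias(np.asarray(train_reps, dtype=float))
    targets = np.eye(num_classes)[np.asarray(train_labels)]  # one-hot targets
    # Closed-form ridge solution: W = (X^T X + ridge * I)^{-1} X^T Y
    gram = x_train.T @ x_train + ridge * np.eye(x_train.shape[1])
    weights = np.linalg.solve(gram, x_train.T @ targets)
    scores = add_bias(np.asarray(test_reps, dtype=float)) @ weights
    predictions = scores.argmax(axis=1)
    return float((predictions == np.asarray(test_labels)).mean())


# Toy data (assumption): representations that encode a 3-way factor almost linearly.
rng = np.random.default_rng(0)
labels = rng.integers(0, 3, size=200)
reps = 5.0 * np.eye(3)[labels] + 0.1 * rng.normal(size=(200, 3))
acc = linear_probe_accuracy(reps[:150], labels[:150], reps[150:], labels[150:],
                            num_classes=3)
print(acc)
```

A probe score near chance would suggest the representation carries little linearly decodable information about the factors; on the toy data above, where the label is encoded almost linearly, the accuracy should be close to 1.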

Paper Structure (45 sections, 5 theorems, 20 equations, 10 figures, 1 table)

Key Result

Proposition 3.2

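The probability ${\mathbb{P}}(y_p|{\mathbf{x}}_p,s_c)$ can be decomposed as:

$${\mathbb{P}}(y_p|{\mathbf{x}}_p,s_c)=\sum_{m}\sum_{e_h}\sum_{v_p}{\mathbb{P}}(y_p|v_p,m,e_h)\,{\mathbb{P}}(v_p|{\mathbf{x}}_p)\,{\mathbb{P}}(e_h|s_c,m)\,{\mathbb{P}}(m|s_c),$$

where ${\mathbb{P}}(v_p|{\mathbf{x}}_p)$ is weights-related information and ${\mathbb{P}}(e_h|s_c,m){\mathbb{P}}(m|s_c)$ is context-related information. ${\mathbb{P}}(y_p|v_p,m,e_h)$ reflects the properties of the task: ${\mathbb{P}}(y_p|v_p,m,e_h)=1$ if $m(v_p^{e_h})=y_p$, and ${\mathbb{P}}(y_p|v_p,m,e_h)=0$ otherwise.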
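To make Proposition 3.2 concrete, the sketch below evaluates the marginalization P(y_p|x_p,s_c) = Σ_m Σ_{e_h} Σ_{v_p} P(y_p|v_p,m,e_h) P(v_p|x_p) P(e_h|s_c,m) P(m|s_c) on a tiny discrete example in plain Python. It assumes all latent variables are discrete so the decomposition is a sum; the two-factor setup, the single `identity` mapping and every probability are made up for illustration, and only the structure of the sum follows the proposition. Because P(y_p|v_p,m,e_h) is an indicator, the predictive probability is the total mass of latent configurations that produce y_p.

```python
def predictive_prob(y_p, p_v_given_x, p_e_given_s_m, p_m_given_s, mappings):
    """P(y_p|x_p,s_c) = sum_m sum_e sum_v 1[m(v^e) = y_p] P(v|x_p) P(e|s_c,m) P(m|s_c).

    p_v_given_x:   dict {factor-value tuple v_p: P(v_p | x_p)}     (weights-related)
    p_e_given_s_m: dict {m: {factor index e_h: P(e_h | s_c, m)}}  (context-related)
    p_m_given_s:   dict {m: P(m | s_c)}                            (context-related)
    mappings:      dict {m: callable mapping a factor value to a label}
    """
    total = 0.0
    for m, p_m in p_m_given_s.items():
        for e_h, p_e in p_e_given_s_m[m].items():
            for v_p, p_v in p_v_given_x.items():
                # P(y_p | v_p, m, e_h) is 1 if m(v_p^{e_h}) == y_p and 0 otherwise.
                if mappings[m](v_p[e_h]) == y_p:
                    total += p_v * p_e * p_m
    return total


# Hypothetical toy setup: factors (color, shape) and a single identity mapping.
p_v = {("red", "cube"): 0.9, ("blue", "cube"): 0.1}  # imperfect recognition of x_p
p_e = {"identity": {0: 0.8, 1: 0.2}}                 # context mostly points to color
p_m = {"identity": 1.0}
maps = {"identity": lambda value: value}

for label in ("red", "blue", "cube"):
    print(label, round(predictive_prob(label, p_v, p_e, p_m, maps), 4))
```

Here the three probabilities are 0.72, 0.08 and 0.2, summing to 1. In this toy example, even perfect context inference (putting all mass on factor index 0) would only raise P("red") to 0.9, the mass that P(v_p|x_p) places on the correct factor values, which illustrates how a weak weights component caps in-context accuracy.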

Figures (10)

  • Figure 1: A: Examples of in-context learning tasks. Example (1) comes from alayrac2022flamingo; examples (2), (3) and (4) come from brown2020language. B: Illustration of learning plateaus and the transition pattern. We evaluate the in-context learning ability of the Pythia 13B model biderman2023pythia, trained on the Pile dataset gao2020pile, on the WordSelection task (details in the appendix) throughout training.
  • Figure 2: Synthetic task. The Transformer is required to predict the label of ${\mathbf{x}}_p$ given the context examples $s_c$. Images from the 3D Shapes dataset are synthesized from six factors, and which factor serves as the output is determined by the context. This example shows two sequences whose output factors are "object color" and "object shape", respectively.
  • Figure 3: Learning plateaus. A: We reproduce the learning plateaus and transition pattern in our synthetic task, similar to Fig. 1B. B: The length of the learning plateaus increases with task complexity, as measured by the entropy of ${\mathbb{P}}(m)$.
  • Figure 4: Weights component and learning plateaus. A: The weights component score increases under the $D_\text{fix} \Rightarrow D_\text{fix}$ setting, while it decreases under the $D_\text{rnd} \Rightarrow D_\text{rnd}$ setting. "Weights" and "Context" in the figure are short for weights comp. score and context comp. score, respectively. B: The weights component score after 50 epochs of training decreases as task complexity increases. The dashed green line indicates the weights component score at initialization. C: The weights component score at epoch 50 is negatively correlated with the length of the learning plateaus. The dashed green line indicates the weights component score at initialization.
  • Figure 5: Three methods are proposed to help overcome learning plateaus. A: Effectiveness of the warm-up method. Top: Using $D_{\text{fix}}$ as a warm-up for the Transformer significantly mitigates learning plateaus. The dashed line indicates the transition point from $D_{\text{fix}}$ to $D_{\text{rnd}}$. Bottom: We perform the transition $D_{\text{fix} \to \text{rnd}} \Rightarrow D_{\text{rnd}}$ at various switching points. The curve labeled "2" denotes switching from $D_{\text{fix}}$ to $D_{\text{rnd}}$ at epoch 2; the curve labeled "0" is the baseline, that is, $D_{\text{rnd}} \Rightarrow D_{\text{rnd}}$. The dashed lines mark the respective switching points. B: Combining $D_{\text{fix}}$ and $D_{\text{rnd}}$. Top: Mixed training substantially improves the weights component score during learning and eliminates learning plateaus. Bottom: Boosting the weights component can promote the development of in-context learning capabilities in smaller models. The dashed line depicts the $D_{\text{fix} \land \text{rnd}} \Rightarrow D_{\text{rnd}}$ configuration, while the solid line represents the $D_{\text{rnd}} \Rightarrow D_{\text{rnd}}$ setting. C: Extra loss. Top: Adding a weights loss significantly enhances learning, whereas adding a context loss has no noticeable impact. The baseline is $D_{\text{rnd}} \Rightarrow D_{\text{rnd}}$. Bottom: With the weights loss, the Transformer attains a high weights component score after 50 epochs of training. The green dashed line indicates the weights comp. score at initialization.
  • ...and 5 more figures
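Figure 3B measures task complexity by the entropy of ${\mathbb{P}}(m)$. Assuming m ranges over a finite set of mappings, this is the standard Shannon entropy; the sketch below computes it in bits (the base-2 logarithm is our choice, as the caption does not state a base), and the example distributions are hypothetical.

```python
import math


def entropy_bits(probs):
    """Shannon entropy H(P) = -sum_i p_i * log2(p_i), skipping zero-probability entries."""
    return sum(-p * math.log2(p) for p in probs if p > 0)


# Hypothetical distributions P(m) over mappings.
print(entropy_bits([1.0]))                 # a single fixed mapping
print(entropy_bits([0.25] * 4))            # four equally likely mappings
print(entropy_bits([0.7, 0.1, 0.1, 0.1]))  # four mappings, one dominant
```

For a uniform distribution over K mappings the entropy is log2 K, so adding equally likely mappings raises the complexity measure, while skewing P(m) toward one mapping lowers it.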
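The three strategies in Figure 5 amount to choices about which data distribution to sample and which loss to optimize. The sketch below expresses them in plain Python; `warmup_source`, `mixed_source` and `total_loss` are hypothetical helpers, the strings "fix" and "rnd" stand for $D_{\text{fix}}$ and $D_{\text{rnd}}$, per-batch mixing with a `fix_ratio` and the `weights_loss_coef` weighting are assumptions, and the weights loss itself is left as a caller-supplied number because its exact form is not given here.

```python
import random


def warmup_source(epoch, switch_epoch):
    """Weights warm-up (hypothetical helper): sample D_fix ("fix") for epochs
    before `switch_epoch`, then D_rnd ("rnd"). switch_epoch=0 is the baseline."""
    return "fix" if epoch < switch_epoch else "rnd"


def mixed_source(rng, fix_ratio=0.5):
    """Mixed training (hypothetical helper): draw each batch from D_fix with
    probability `fix_ratio`, otherwise from D_rnd. The ratio is an assumption."""
    return "fix" if rng.random() < fix_ratio else "rnd"


def total_loss(task_loss, weights_loss, weights_loss_coef=1.0):
    """Extra loss (hypothetical helper): task loss plus a weighted auxiliary
    weights loss; how `weights_loss` is computed is left to the caller."""
    return task_loss + weights_loss_coef * weights_loss


rng = random.Random(0)
print([warmup_source(epoch, switch_epoch=2) for epoch in range(4)])  # switch at epoch 2
print([mixed_source(rng) for _ in range(6)])
print(total_loss(task_loss=1.2, weights_loss=0.4, weights_loss_coef=0.5))
```

With `switch_epoch=0` the warm-up schedule never samples $D_{\text{fix}}$, which corresponds to the $D_{\text{rnd}} \Rightarrow D_{\text{rnd}}$ baseline labeled "0" in Figure 5A.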

Theorems & Definitions (12)

  • Definition 3.1
  • Proposition 3.2
  • Definition 3.3
  • Proposition 4.1
  • Proof
  • Definition 5.1
  • Definition 5.2
  • Proposition 5.3
  • Lemma 5.4
  • Proof
  • ...and 2 more