A Meta-Learning Perspective on Transformers for Causal Language Modeling

Xinbo Wu, Lav R. Varshney

TL;DR

We establish a meta-learning view of the Transformer architecture trained for causal language modeling by explicating an inner optimization process within the Transformer, and we discover a special characteristic of the norms of learned token representations in Transformer-based causal language models.

Abstract

The Transformer architecture has become prominent in developing large causal language models. However, the mechanisms behind its capabilities are not well understood. Focusing on the training process, we establish a meta-learning view of the Transformer architecture trained for the causal language modeling task by explicating an inner optimization process within the Transformer. Further, within this inner optimization, we discover and theoretically analyze a special characteristic of the norms of learned token representations in Transformer-based causal language models. Our analysis is supported by experiments in various settings.

Paper Structure (26 sections, 1 theorem, 21 equations, 11 figures, 7 tables)

Key Result

Proposition 1

Let $W \in \mathbb{R}^{d_{in} \times d_{out}}$ be a linear transformation matrix, and let $x \in \mathbb{R}^{d_{in}}$ and $y \in \mathbb{R}^{d_{out}}$ be the input and output of the linear transformation. The Gram matrix of $W$ can be eigendecomposed as $W^\mathsf{T}W = U\ldots$
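Only the setup of Proposition 1 survives here, so the following is a minimal sketch of that setup alone, not of the proposition's claim. It eigendecomposes the Gram matrix of a random $W$ (hypothetical sizes) and checks a standard identity: writing $W^\mathsf{T}W = U\,\mathrm{diag}(\lambda)\,U^\mathsf{T}$ with eigenvector columns $u_i$ (our notation, which may differ from the paper's), any $v \in \mathbb{R}^{d_{out}}$ satisfies $\|Wv\|^2 = v^\mathsf{T}W^\mathsf{T}Wv = \sum_i \lambda_i (u_i^\mathsf{T}v)^2$. The function names are hypothetical.

```python
import numpy as np


def gram_eigendecomposition(W):
    """Eigendecompose the Gram matrix W^T W as U @ diag(lam) @ U.T.

    W^T W is symmetric positive semidefinite, so np.linalg.eigh applies:
    it returns real eigenvalues in ascending order and orthonormal
    eigenvectors as the columns of U.
    """
    gram = W.T @ W
    lam, U = np.linalg.eigh(gram)
    return gram, lam, U


def squared_norm_two_ways(W, v):
    """Return ||W v||^2 computed directly and as sum_i lam_i * (u_i^T v)^2."""
    _, lam, U = gram_eigendecomposition(W)
    direct = float(np.sum((W @ v) ** 2))
    spectral = float(np.sum(lam * (U.T @ v) ** 2))
    return direct, spectral


# Hypothetical sizes and a random W, for illustration only.
rng = np.random.default_rng(0)
d_in, d_out = 6, 4
W = rng.standard_normal((d_in, d_out))
gram, lam, U = gram_eigendecomposition(W)

print(np.allclose(U @ np.diag(lam) @ U.T, gram))  # True: reconstruction
print(np.allclose(U.T @ U, np.eye(d_out)))        # True: U is orthogonal
print(bool(np.all(lam >= -1e-10)))                # True: PSD up to round-off

v = rng.standard_normal(d_out)
direct, spectral = squared_norm_two_ways(W, v)
print(np.isclose(direct, spectral))               # True
```

`np.linalg.eigh` is used rather than `np.linalg.eig` because the Gram matrix is symmetric, which guarantees real eigenvalues and an orthonormal $U$. The eigenvalues also equal the squared singular values of $W$.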

Figures (11)

  • Figure 1: Clustering analysis on the validation set across different layers throughout training. Different columns correspond to different ground-truth labels: seed, next token, and their combination (Seed-Next-Token). The legend shows layers. Only F1 scores are presented here; results on other evaluation metrics and other data splits are in the Appendix (Figure fig:learning_dynamics_all). Each dot represents a data point.
  • Figure 2: Bi-level optimization process within a Transformer-based CLM. (a) Inner optimization losses across layers and training epochs, computed according to the inner loss equation (inner_loss_equation) and aggregated over training examples. (b) Outer optimization losses throughout training; the outer process explicitly optimizes for the CLM task, and the outer loss is identical to the model's training loss.
  • Figure 3: Attention over the history throughout training, evaluated using seed identity as the ground truth. We compute the attention from an unseen instance (the last token of each sequence) in the validation set to all historical instances, and calculate the percentage of the top 10 and last 10 attended instances that share its label. The final measurement is aggregated across heads, token positions, and instances, and is reported per layer. The same analysis based on other ground truths is in the Appendix (fig:att_lout=1).
  • Figure 4: Inner optimization loss across layers on the Wikitext-103 dataset: the x-axis is the layer number and the y-axis is the inner loss averaged over samples. The mean values of different samples are shown as circles.
  • Figure 5: Visualizations of current-token vector representations for different models on the Wikitext-103 dataset.
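The Figure 1 caption reports an F1 score for clustering against ground-truth labels but does not say how that F1 is defined or which clustering algorithm is used. As an assumption, the sketch below uses pair-counting F1, one common external clustering metric: a pair of points is a true positive when it shares both a cluster and a ground-truth label. The function name `pair_counting_f1` is hypothetical, and this is not necessarily the paper's metric.

```python
from itertools import combinations


def pair_counting_f1(true_labels, cluster_ids):
    """Pair-counting F1 between predicted clusters and ground-truth labels.

    Over all unordered pairs of points:
      TP: same cluster and same label
      FP: same cluster, different label
      FN: different cluster, same label
    Labels may be any values comparable with ==, e.g. tuples such as
    (seed, next_token) for a combined ground truth.
    """
    if len(true_labels) != len(cluster_ids):
        raise ValueError("true_labels and cluster_ids must have the same length")
    tp = fp = fn = 0
    for i, j in combinations(range(len(true_labels)), 2):
        same_cluster = cluster_ids[i] == cluster_ids[j]
        same_label = true_labels[i] == true_labels[j]
        if same_cluster and same_label:
            tp += 1
        elif same_cluster:
            fp += 1
        elif same_label:
            fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


print(pair_counting_f1([0, 0, 1, 1], [5, 5, 7, 7]))  # 1.0 (cluster ids need not match label values)
print(pair_counting_f1([0, 0, 1, 1], [0, 0, 0, 0]))  # 0.5 (everything in one cluster)
```

Because the score depends only on pairs, it is invariant to how clusters are numbered. The loop is quadratic in the number of points, which is fine for a sketch but slow for large validation sets.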
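The measurement described in the Figure 3 caption can be sketched as follows. This is a minimal sketch under stated assumptions: "last 10 attended" is read as the 10 historical instances with the lowest attention weights, and aggregation is taken to be a plain mean over (head, token position, instance) records within one layer. The helper names `same_label_rates` and `aggregate_layer` are hypothetical.

```python
import numpy as np


def same_label_rates(attn, history_labels, query_label, k=10):
    """Fraction of the k most- and k least-attended history instances whose
    label matches the query's label.

    attn: 1-D attention weights from one query token (e.g., the last token)
        to each historical instance.
    history_labels: 1-D labels (e.g., seed ids) of the historical instances.
    query_label: ground-truth label of the query instance.
    """
    attn = np.asarray(attn, dtype=float)
    history_labels = np.asarray(history_labels)
    if attn.ndim != 1 or attn.shape != history_labels.shape or attn.size == 0:
        raise ValueError("attn and history_labels must be non-empty 1-D arrays of equal length")
    k = min(k, attn.size)
    order = np.argsort(attn, kind="stable")  # indices by ascending attention
    top_k = order[::-1][:k]  # most-attended instances
    last_k = order[:k]       # least-attended instances (assumed reading of "last")
    top_rate = float(np.mean(history_labels[top_k] == query_label))
    last_rate = float(np.mean(history_labels[last_k] == query_label))
    return top_rate, last_rate


def aggregate_layer(records, k=10):
    """Mean (top_rate, last_rate) over one layer's records.

    Each record is (attn, history_labels, query_label) for one combination of
    head, token position, and instance; a plain mean is an assumption.
    """
    rates = np.array([same_label_rates(a, h, q, k) for a, h, q in records])
    return rates.mean(axis=0)


# Toy example with k=2: the two largest weights go to label-0 items,
# the two smallest weights go to label-1 items.
attn = np.array([0.05, 0.4, 0.1, 0.3, 0.15])
labels = np.array([1, 0, 1, 0, 2])
print(same_label_rates(attn, labels, query_label=0, k=2))  # (1.0, 0.0)
```

Multiplying the returned fractions by 100 gives percentages. A high top-k rate together with a low last-k rate would indicate that attention concentrates on same-label history.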

Theorems & Definitions (4)

  • Conjecture 8.1
  • Proposition 1
  • Conjecture B.1
  • Proof 1