Trainable Transformer in Transformer

Abhishek Panigrahi, Sadhika Malladi, Mengzhou Xia, Sanjeev Arora

TL;DR

Trainable Transformer in Transformer (TinT) presents a compact simulator that implicitly fine-tunes a complex auxiliary transformer (e.g., OPT-125M) inside a host transformer during inference. By encoding the auxiliary model's weights as prefix embeddings and using stacking, sharding, and efficient aggregation, TinT performs forward passes and approximate gradient updates that fine-tune the auxiliary model within a single inference pass, while itself using fewer than 2B parameters. Empirical results on language modeling and in-context learning show TinT achieving substantial gains over the base auxiliary model and approaching or matching a larger pre-trained model on several tasks, underscoring the potential for dynamic internal adaptation in large LMs. The work offers a modular codebase and design principles for building efficient internal gradient-descent simulators, with important implications for interpretability and AI alignment.

Abstract

Recent works attribute the capability of in-context learning (ICL) in large pre-trained language models to implicitly simulating and fine-tuning an internal model (e.g., linear or 2-layer MLP) during inference. However, such constructions require large memory overhead, which makes simulation of more sophisticated internal models intractable. In this work, we propose an efficient construction, Transformer in Transformer (TinT for short), that allows a transformer to simulate and fine-tune complex models (e.g., pre-trained language models) internally during inference. In particular, we introduce innovative approximation techniques that allow a TinT model with fewer than 2 billion parameters to simulate and fine-tune a 125-million-parameter transformer model within a single forward pass. TinT accommodates many common transformer variants, and its design ideas also improve the efficiency of past instantiations of simple models inside transformers. We conduct end-to-end experiments to validate the internal fine-tuning procedure of TinT on various language modeling and downstream tasks. For example, even with a limited one-step budget, we observe that TinT for an OPT-125M model improves performance by 4-16% absolute on average compared to OPT-125M. These findings suggest that large pre-trained language models are capable of performing intricate subroutines. To facilitate further work, a modular and extensible codebase for TinT is included.

Paper Structure (80 sections, 12 theorems, 80 equations, 9 figures, 6 tables)

Key Result

Theorem 1.1

Consider an auxiliary transformer with $L$ layers, embedding dimension $D_{\text{aux}}$, $H_{\text{aux}}$ attention heads, and a maximum sequence length of $T_{\text{aux}}$. Given a hyperparameter $S$ (see the section on stacking), TinT can perform an efficient forward pass (see the section on the linear forward module), compute … parameters, with constants $c_1, c_2, c_3 < 150$. The TinT model has $D_{\text{sim}} = SD_{\text{aux}}$ …
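To make the stacking hyperparameter $S$ concrete, here is a minimal sketch of the bookkeeping under our own assumption about the layout (not taken from the paper): each prefix embedding holds $S$ stacked rows of a $D_{\text{aux}} \times D_{\text{aux}}$ weight matrix, so its width is $S D_{\text{aux}}$ (consistent with the $D_{\text{sim}}$ expression in the theorem) and one such matrix occupies $D_{\text{aux}}/S$ prefix embeddings. The helper `stacking_layout` is hypothetical, and $S = 4$ is an arbitrary illustrative value; the hidden size 768 is that of OPT-125M.

```python
# Hypothetical helper illustrating stacking arithmetic; the layout is an
# assumption (S weight rows of length d_aux per prefix embedding), not the
# paper's exact construction.
def stacking_layout(d_aux: int, s: int) -> tuple[int, int]:
    """Return (d_sim, prefix embeddings needed for one d_aux x d_aux matrix)."""
    if d_aux % s != 0:
        raise ValueError("this sketch assumes S divides D_aux")
    d_sim = s * d_aux      # width of one prefix embedding holding S rows
    n_prefix = d_aux // s  # d_aux rows packed S at a time
    return d_sim, n_prefix


# OPT-125M has hidden size 768; S = 4 is an arbitrary illustrative choice.
print(stacking_layout(768, 4))  # (3072, 192)
```

Under this assumption, a larger $S$ shortens the prefix (the goal Figure 2 states for stacking) at the cost of a wider simulator embedding.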

Figures (9)

  • Figure 1: The overall structure of TinT (see the design section for an overview). Each forward, backward, and descent module is represented using combinations of linear, self-attention, layernorm, and activation layers. The input consists of prefix embeddings (Definition 2.1) that represent relevant auxiliary model parameters in each layer, followed by natural language input. A prefix mask separates the train and test segments of the input (see the section on input structure).
  • Figure 2: TinT simulates the forward pass of a linear layer with an $H_{\text{sim}}$-head attention layer ($H_{\text{sim}}=6$ here). We stack $S$ weights per prefix embedding to reduce the number of prefix embeddings required ($S=2$ here). We furthermore shard each weight and token embedding ${\bm{x}}_t$ into $S'$ shards and compute the inner products of each shard in parallel using $S \times S'$ attention heads ($S'=3$ here). See the section on stacking and sharding.
  • Figure 3: Different settings in few-shot learning ($k=3$) using TinT. The Single mode (left) treats each example as a training datapoint, and the auxiliary model is updated with a batch of inputs (see Definition 5.1). The Multi. mode (right) concatenates all examples to form a single input and uses batch size $1$ in Definition 5.1. For Label loss, only the underlined label words are used as the training signal, while full context loss includes all tokens.
  • Figure 4: TinT simulates the backward pass of a linear layer as an $H$-head attention layer ($H=6$ pictured), with the gradient of the loss w.r.t. the linear layer output ($\partial_{{\bm{y}}_t}$) as the query, the positional one-hot vector of prefix embeddings as the key, and the parameters of the auxiliary model stored in the prefix embeddings as the value. As in the Linear Forward module (Figure 2), we distribute the dot product computations across all attention heads by sharding the vectors into $S'$ ($S'=3$ here) parts. For illustration purposes, we omit the identity transformations for the query and value matrices and the permutation-based transformation for the key matrix.
  • Figure 5: TinT computes the parameter gradients for a linear layer as an $H$-head attention layer ($H=6$ pictured), with the gradient of the loss w.r.t. the linear layer output ($\partial_{{\bm{y}}_t}$) as the query, the positional one-hot vector of prefix embeddings as the key, and the input to the linear layer (${\bm{x}}_t$) as the value. The auxiliary model parameters in the prefix embeddings are then updated using a residual connection. As in the Linear Forward module (Figure 2), we distribute the dot product computations across all attention heads by sharding the vectors into $S'$ ($S'=3$ here) parts. For simplicity, we omit the identity transformations for the query and value matrices and the permutation-based transformation for the key matrix.
  • ...and 4 more figures
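The linear-layer modules in Figures 2, 4 and 5 use attention heads to evaluate standard quantities. The plain-Python sketch below computes those same quantities directly so they can be checked by hand: the output $W{\bm{x}}_t$ with each inner product split into $S'$ shards (Figure 2), the input gradient $W^\top \partial_{{\bm{y}}_t}$ (Figure 4), and the weight gradient $\sum_t \partial_{{\bm{y}}_t} {\bm{x}}_t^\top$ (Figure 5). It assumes a bias-free layer and uses hypothetical function names. It does not model attention heads or prefix embeddings, and it computes exact gradients, whereas the TL;DR describes TinT's updates as approximate.

```python
# Sketch (not TinT itself): the quantities behind the Linear Forward, Linear
# Backward, and parameter-gradient modules of Figures 2, 4 and 5.
# Assumptions: bias-free layer y_t = W x_t, W stored as a list of rows, and
# D_in divisible by the number of shards. All names are hypothetical.

def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def shard(v, n_shards):
    """Split v into n_shards contiguous pieces of equal length."""
    if len(v) % n_shards != 0:
        raise ValueError("vector length must be divisible by n_shards")
    k = len(v) // n_shards
    return [v[i * k:(i + 1) * k] for i in range(n_shards)]


def sharded_dot(u, v, n_shards):
    """<u, v> as a sum of per-shard partial inner products.

    Figure 2 assigns shards to separate attention heads; here we simply
    loop over them and add the partial results.
    """
    return sum(dot(a, b) for a, b in zip(shard(u, n_shards), shard(v, n_shards)))


def linear_forward(W, x, n_shards=1):
    # Output coordinate i is <w_i, x> for row w_i of W.
    return [sharded_dot(row, x, n_shards) for row in W]


def linear_backward(W, grad_y):
    # Gradient w.r.t. the layer input: W^T grad_y.
    return [sum(W[i][j] * grad_y[i] for i in range(len(W))) for j in range(len(W[0]))]


def weight_grad(xs, grad_ys):
    # Gradient w.r.t. W: sum over tokens t of grad_y_t x_t^T.
    return [[sum(gy[i] * x[j] for x, gy in zip(xs, grad_ys)) for j in range(len(xs[0]))]
            for i in range(len(grad_ys[0]))]


def descent_step(W, grad_W, lr):
    # Plain gradient step W <- W - lr * grad_W (Figure 5 applies the update
    # to the prefix embeddings through a residual connection).
    return [[w - lr * g for w, g in zip(w_row, g_row)] for w_row, g_row in zip(W, grad_W)]


W = [[1.0, 2.0, 0.0, -1.0], [0.5, 0.0, 1.0, 1.0]]  # D_out = 2, D_in = 4
x = [1.0, 1.0, 2.0, 3.0]
print(linear_forward(W, x, n_shards=2))  # [0.0, 5.5], same as n_shards=1
```

Sharding only regroups the additions, so in exact arithmetic the output is the same for any shard count that divides $D_{\text{in}}$; in this idealized view, spreading the shards across heads does not change the simulated result.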

Theorems & Definitions (63)

  • Theorem 1.1
  • Definition 2.1: Prefix Embeddings
  • Definition 3.1: Linear layer
  • Definition 4.1: Layer normalization
  • Definition 5.1: $N$-step Fine-Tuning
  • Definition 2.1: Auxiliary model softmax self-attention
  • Definition 3.1: TinT's self-attention with $H_{\text{sim}}$ heads
  • Theorem 3.2
  • Proof
  • Lemma 3.3
  • ...and 53 more
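Figure 3 and Definition 5.1 ($N$-step Fine-Tuning) distinguish how the $k$ demonstrations become training data for the auxiliary model. The minimal sketch below covers only the data side. It assumes toy word-level tokens and that each demonstration is an input followed by its label words. `build_training_data` and the token format are hypothetical, and the sketch does not perform the fine-tuning step itself.

```python
# Hypothetical sketch: turning k few-shot examples into training sequences for
# the Single vs. Multi. modes of Figure 3, with label-only or full-context
# loss masks. Toy word-level tokens; all names are ours.

Example = tuple[list[str], list[str]]  # (input tokens, label tokens)


def build_training_data(examples, mode, loss):
    """Return a list of (tokens, loss_mask) pairs.

    mode="single": one sequence per example (batch of k inputs).
    mode="multi":  all examples concatenated into one sequence (batch size 1).
    loss="label":  mask is 1 only on label tokens.
    loss="full":   mask is 1 on every token.
    """
    seqs = []
    for inp, label in examples:
        tokens = inp + label
        if loss == "label":
            mask = [0] * len(inp) + [1] * len(label)
        elif loss == "full":
            mask = [1] * len(tokens)
        else:
            raise ValueError(f"unknown loss: {loss!r}")
        seqs.append((tokens, mask))

    if mode == "single":
        return seqs
    if mode == "multi":
        tokens = [tok for toks, _ in seqs for tok in toks]
        mask = [m for _, ms in seqs for m in ms]
        return [(tokens, mask)]
    raise ValueError(f"unknown mode: {mode!r}")


examples: list[Example] = [
    (["great", "movie", "->"], ["positive"]),
    (["dull", "plot", "->"], ["negative"]),
    (["fine", "acting", "->"], ["positive"]),
]
print(len(build_training_data(examples, "single", "label")))  # 3
print(build_training_data(examples, "multi", "label")[0][1])  # [0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1]
```

In this sketch the mode and the loss are independent switches: in Multi. mode with label loss, the concatenated sequence still supervises only the label words.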