Trainable Transformer in Transformer
Abhishek Panigrahi, Sadhika Malladi, Mengzhou Xia, Sanjeev Arora
TL;DR
Trainable Transformer in Transformer (TinT) is an efficient construction that lets a transformer simulate and implicitly fine-tune an auxiliary transformer (e.g., a 125M-parameter OPT model) internally during inference. TinT encodes the auxiliary model's weights as prefix embeddings and uses stacking, sharding, and efficient aggregation. With these techniques it performs the auxiliary model's forward passes and approximate gradient updates within a single inference pass, using fewer than 2B parameters. Empirical results on language modeling and in-context learning show that TinT achieves substantial gains over the base auxiliary model and approaches or matches a larger pre-trained model on several tasks, underscoring the potential for dynamic internal adaptation in large LMs. The work also provides a modular codebase and design principles for building efficient internal gradient-descent simulators, with important implications for interpretability and AI alignment.
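To make "encoding auxiliary weights as prefix embeddings" concrete, here is a toy sketch, not TinT's actual construction. It assumes a linear layer whose weight rows are packed ("stacked") several to a prefix token and then read back via dot products to compute y = Wx. The function names `encode_as_prefix` and `linear_from_prefix` and the parameter `rows_per_token` are hypothetical and exist only for this illustration.

```python
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def encode_as_prefix(weight: Matrix, rows_per_token: int) -> List[Vector]:
    """Pack the rows of `weight` into prefix embeddings (hypothetical toy).

    Each prefix embedding concatenates `rows_per_token` consecutive rows,
    so a d_out x d_in matrix needs ceil(d_out / rows_per_token) prefix
    tokens, each of width rows_per_token * d_in.
    """
    d_in = len(weight[0])
    prefix = []
    for start in range(0, len(weight), rows_per_token):
        block = weight[start:start + rows_per_token]
        # Zero-pad the final block so every prefix token has the same width.
        while len(block) < rows_per_token:
            block = block + [[0.0] * d_in]
        prefix.append([w for row in block for w in row])
    return prefix


def linear_from_prefix(prefix: List[Vector], x: Vector, d_out: int) -> Vector:
    """Recover y = W x by reading weight rows back out of the prefix tokens."""
    d_in = len(x)
    y = []
    for token in prefix:
        # Each stacked row contributes one output coordinate via a dot product.
        for r in range(len(token) // d_in):
            row = token[r * d_in:(r + 1) * d_in]
            y.append(sum(w * xi for w, xi in zip(row, x)))
    return y[:d_out]  # drop outputs produced by zero-padding rows
```

For example, packing a 3 x 2 matrix with two rows per token yields two prefix tokens, and `linear_from_prefix(encode_as_prefix([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], 2), [1.0, 1.0], 3)` gives `[3.0, 7.0, 11.0]`. Stacking trades fewer prefix tokens for wider embeddings. That trade-off is the kind of memory and width balance the TL;DR alludes to.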
Abstract
Recent works attribute the in-context learning (ICL) capability of large pre-trained language models to implicitly simulating and fine-tuning an internal model (e.g., a linear model or a 2-layer MLP) during inference. However, such constructions incur large memory overhead, which makes simulating more sophisticated internal models intractable. In this work, we propose an efficient construction, Transformer in Transformer (TinT for short), that allows a transformer to simulate and fine-tune complex models (e.g., pre-trained language models) internally during inference. In particular, we introduce novel approximation techniques that allow a TinT model with fewer than 2 billion parameters to simulate and fine-tune a 125-million-parameter transformer model within a single forward pass. TinT accommodates many common transformer variants, and its design ideas also improve the efficiency of past instantiations of simple models inside transformers. We conduct end-to-end experiments to validate TinT's internal fine-tuning procedure on various language modeling and downstream tasks. For example, even with a limited one-step budget, we observe that TinT simulating an OPT-125M model improves performance by 4-16% absolute on average compared to OPT-125M itself. These findings suggest that large pre-trained language models are capable of performing intricate subroutines. To facilitate further work, we include a modular and extensible codebase for TinT.
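The "one-step budget" refers to a single fine-tuning step on the in-context examples. The sketch below is only an explicit toy analogue of what TinT performs implicitly. It assumes a linear auxiliary model (the simple internal model from prior work) rather than a transformer and takes one full-batch gradient step on a mean-squared-error loss over the context before predicting on a query. All names (`mse_loss`, `one_step_finetune`, `predict`) and the learning rate are hypothetical choices for illustration.

```python
from typing import List, Tuple

# One in-context example: (input features, target value).
Example = Tuple[List[float], float]


def predict(w: List[float], x: List[float]) -> float:
    """Linear auxiliary model: y_hat = w . x (hypothetical stand-in)."""
    return sum(wi * xi for wi, xi in zip(w, x))


def mse_loss(w: List[float], data: List[Example]) -> float:
    """Mean of 0.5 * (w . x - y)^2 over the examples."""
    return sum(0.5 * (predict(w, x) - y) ** 2 for x, y in data) / len(data)


def mse_gradient(w: List[float], data: List[Example]) -> List[float]:
    """Gradient of `mse_loss` with respect to w."""
    grad = [0.0] * len(w)
    for x, y in data:
        err = predict(w, x) - y
        for i, xi in enumerate(x):
            grad[i] += err * xi / len(data)
    return grad


def one_step_finetune(w: List[float], context: List[Example], lr: float) -> List[float]:
    """Take exactly one gradient-descent step on the in-context examples."""
    g = mse_gradient(w, context)
    return [wi - lr * gi for wi, gi in zip(w, g)]
```

Usage: starting from `w0 = [0.0, 0.0]` with context `[([1.0, 0.0], 2.0), ([0.0, 1.0], -1.0)]` and `lr=1.0`, one step gives `w1 = [1.0, -0.5]`. The context loss drops from 1.25 to 0.3125, and the prediction on the query `[1.0, 1.0]` becomes 0.5. TinT's claim is that a transformer can carry out an analogous update for a full 125M-parameter auxiliary transformer inside its own forward pass.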
