General-Purpose In-Context Learning by Meta-Learning Transformers

Louis Kirsch, James Harrison, Jascha Sohl-Dickstein, Luke Metz

TL;DR

This work tackles the problem of obtaining a truly general-purpose in-context learning algorithm by meta-training a black-box Transformer on a richly augmented task distribution. The General-Purpose In-Context Learner (GPICL) framework demonstrates that, given sufficient task diversity and memory capacity, a model transitions from memorizing tasks, to identifying tasks, and finally to general learning-to-learn that generalizes to unseen datasets. Key contributions include an analysis of how memory (state-size) bottlenecks constrain meta-learning, the observation of these algorithmic transitions, and practical interventions (batch size, meta-optimizer adjustments, curricula) that improve meta-training and meta-generalization. The findings have practical implications for data-driven meta-learning, cross-domain generalization, and potential improvements to in-context learning in large language models.
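The "richly augmented task distribution" can be pictured with a small sketch. The figure captions describe tasks as random transformations of MNIST or FashionMNIST; here we assume those transformations are a random Gaussian linear projection of the flattened input plus a random permutation of the class labels. `make_task` and the toy data are hypothetical and stand in for real image datasets.

```python
import random
from typing import Callable, List, Tuple

Example = Tuple[List[float], int]


def make_task(num_features: int, num_classes: int,
              seed: int) -> Callable[[List[float], int], Example]:
    """Hypothetical task sampler (assumption): random linear projection of the
    inputs plus a random permutation of the class labels."""
    rng = random.Random(seed)
    # Gaussian entries with variance 1/num_features keep the input norm roughly fixed
    std = num_features ** -0.5
    proj = [[rng.gauss(0.0, std) for _ in range(num_features)]
            for _ in range(num_features)]
    perm = list(range(num_classes))
    rng.shuffle(perm)

    def transform(x: List[float], y: int) -> Example:
        # Project the flattened input and relabel its class
        x_new = [sum(w * xi for w, xi in zip(row, x)) for row in proj]
        return x_new, perm[y]

    return transform


# Toy base dataset: 4-dim inputs, 3 classes (a stand-in for flattened MNIST)
base: List[Example] = [([1.0, 0.0, 0.0, 0.0], 0),
                       ([0.0, 1.0, 0.0, 0.0], 1),
                       ([0.0, 0.0, 1.0, 0.0], 2)]

# Each seed yields a new task built from the same base data
tasks = [make_task(num_features=4, num_classes=3, seed=s) for s in range(3)]
augmented = [[task(x, y) for x, y in base] for task in tasks]
```

Because every seed changes the input-to-label mapping, a model trained on enough such tasks can no longer rely on memorizing individual tasks, which is the regime the paper associates with general learning-to-learn.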

Abstract

Modern machine learning requires system designers to specify aspects of the learning pipeline, such as losses, architectures, and optimizers. Meta-learning, or learning-to-learn, instead aims to learn those aspects, and promises to unlock greater capabilities with less manual effort. One particularly ambitious goal of meta-learning is to train general-purpose in-context learning algorithms from scratch, using only black-box models with minimal inductive bias. Such a model takes in training data, and produces test-set predictions across a wide range of problems, without any explicit definition of an inference model, training loss, or optimization algorithm. In this paper we show that Transformers and other black-box models can be meta-trained to act as general-purpose in-context learners. We characterize transitions between algorithms that generalize, algorithms that memorize, and algorithms that fail to meta-train at all, induced by changes in model size, number of tasks, and meta-optimization. We further show that the capabilities of meta-trained algorithms are bottlenecked by the accessible state size (memory) determining the next prediction, unlike standard models which are thought to be bottlenecked by parameter count. Finally, we propose practical interventions such as biasing the training distribution that improve the meta-training and meta-generalization of general-purpose in-context learning algorithms.
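The claim that accessible state size, rather than parameter count, bottlenecks meta-trained learners can be made concrete with back-of-the-envelope arithmetic. The sketch below uses our own simplified accounting (an assumption, not necessarily the paper's): an LSTM's state is its hidden and cell vectors, while a Transformer can attend to cached keys and values for every context position in every layer. All function names are hypothetical, and the Transformer parameter count uses the common 12·L·d² approximation.

```python
def lstm_state_size(hidden_size: int) -> int:
    # Hidden vector h plus cell vector c
    return 2 * hidden_size


def lstm_param_count(input_size: int, hidden_size: int) -> int:
    # Four gates, each with input weights, recurrent weights, and one bias vector
    return 4 * (hidden_size * (input_size + hidden_size) + hidden_size)


def transformer_state_size(num_layers: int, model_dim: int, context_len: int) -> int:
    # Keys and values cached at every layer for every position in the context
    return 2 * num_layers * model_dim * context_len


def transformer_param_count(num_layers: int, model_dim: int) -> int:
    # Approximation: 4*d^2 (attention projections) + 8*d^2 (MLP with 4x width) per layer;
    # biases, layer norms, and embeddings are ignored
    return 12 * num_layers * model_dim ** 2


if __name__ == "__main__":
    # A wide LSTM can have more parameters than a small Transformer...
    print("LSTM params:", lstm_param_count(input_size=32, hidden_size=1024))
    print("Transformer params:", transformer_param_count(num_layers=4, model_dim=256))
    # ...yet a far smaller state to condition its next prediction on
    print("LSTM state:", lstm_state_size(hidden_size=1024))
    print("Transformer state:", transformer_state_size(num_layers=4, model_dim=256, context_len=100))
```

With these settings the LSTM has about 4.3M parameters but only a 2,048-number state, while the Transformer has about 3.1M parameters and can access 204,800 cached numbers when making its next prediction.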

Paper Structure

This paper contains 74 sections, 2 equations, 23 figures, 2 tables, and 1 algorithm.

Figures (23)

  • Figure 1: Our General-Purpose In-Context Learner (GPICL) is based on a vanilla Transformer, which is trained to make predictions for queries $x'$ given any prefix of a dataset $D := \{x_i, y_i\}_{i=1}^{N_D}$, as in the paper's multi-task meta-training objective.
  • Figure 2: GPICL is able to generalize to unseen tasks. Each cell is a separate meta-training run. (a) An MLP classifier trained in a multi-task fashion across various numbers of tasks (generated from MNIST) and network sizes can fit a number of tasks that grows linearly with its capacity. (b) A sequence model (here the GPICL Transformer) that observes a dataset $D$ of inputs and labels transitions, as model size increases, to generalizing to a seemingly unbounded number of tasks. This is achieved by switching from a memorization solution to a learning solution that (c) generalizes to unseen tasks. This generalization does not occur with the MLP.
  • Figure 3: GPICL learns from examples at test time, and generalizes to unseen tasks and datasets. We meta-trained the Transformer on a set of tasks defined by random transformations of either MNIST (blue) or FashionMNIST (orange). We then meta-test on unseen tasks, and seen (ab) or unseen (ba) datasets. The plot shows the accuracy averaged across multiple runs at each inner step, with shading indicating $95\%$ confidence intervals. The increase in performance at each step suggests we have learned a learning algorithm.
  • Figure 4: Transformers exhibit three different phases of meta-learned behavior. (1) When training on a small number of tasks, the tasks are memorized. (2) Tasks from the training distribution are identified, as evidenced by an increase in performance within each sequence. (3) When training across many tasks, we discover a learning algorithm that generalizes to unseen tasks and unseen datasets.
  • Figure 5: The state size (accessible memory) of an architecture most strongly predicts its performance as a general-purpose learning algorithm. (a) A large state is crucial for learning-to-learn to emerge. (b) The parameter count correlates less well with learning capabilities.
  • ...and 18 more figures
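Figure 1's setup, predicting the label of a query from any prefix of $D$, can be sketched by turning a dataset into a single sequence. We assume (hypothetically) that each step feeds the model $x_i$ concatenated with a one-hot encoding of the previous label $y_{i-1}$, so a causal model's prediction for $x_i$ depends only on the earlier pairs and $x_i$ itself; the meta-training loss then averages the negative log-likelihood over all steps. `build_sequence` and `meta_loss` are illustrative names, and the model that produces the probabilities is left out.

```python
import math
from typing import List, Sequence, Tuple


def one_hot(label: int, num_classes: int) -> List[float]:
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


def build_sequence(dataset: Sequence[Tuple[List[float], int]],
                   num_classes: int) -> Tuple[List[List[float]], List[int]]:
    """Step i receives [x_i, one_hot(y_{i-1})] (zeros at the first step) and targets y_i.

    With causal attention, the prediction at step i sees x_0..x_i and y_0..y_{i-1},
    but never y_i itself."""
    inputs, targets = [], []
    prev_label = [0.0] * num_classes
    for x, y in dataset:
        inputs.append(list(x) + prev_label)
        targets.append(y)
        prev_label = one_hot(y, num_classes)
    return inputs, targets


def meta_loss(probs: Sequence[Sequence[float]], targets: Sequence[int]) -> float:
    # Mean negative log-likelihood over all N_D steps, i.e. every prefix length 0..N_D-1
    return -sum(math.log(p[t]) for p, t in zip(probs, targets)) / len(targets)


dataset = [([0.5, 1.0], 1), ([0.2, 0.3], 0), ([0.9, 0.1], 2)]
inputs, targets = build_sequence(dataset, num_classes=3)
# A model that predicts uniformly over 3 classes scores log(3) at every step
uniform = [[1 / 3] * 3 for _ in targets]
loss = meta_loss(uniform, targets)
```

Averaging over every step trains the model on all prefix lengths at once, so the same network must make sensible predictions with no examples, a few examples, or the full dataset in context.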