Learning to (Learn at Test Time)

Yu Sun, Xinhao Li, Karan Dalal, Chloe Hsu, Sanmi Koyejo, Carlos Guestrin, Xiaolong Wang, Tatsunori Hashimoto, Xinlei Chen

TL;DR

The paper reframes supervised learning as a bi-level learning problem. An inner loop performs test-time training with a self-supervised reconstruction task on each instance, and an outer loop learns that self-supervised objective so that it improves the final prediction on the main task. The paper establishes theoretical and empirical connections: with a linear inner-loop learner the approach is equivalent to linear attention, with a kernel-based learner it is equivalent to self-attention, and with a neural-network learner it can surpass linear-attention baselines under resource constraints. In ImageNet experiments, MTTT-Linear (the linear inner-loop variant) closely tracks linear attention, while MTTT-MLP (the MLP inner-loop variant) offers gains at the cost of higher FLOPs. In experiments on raw pixels, inner loops that use SGD achieve notable improvements over ViT baselines. The work suggests a scalable, meta-learning-driven path to more capable architectures, especially when memory and compute permit neural-network inner loops, and outlines future directions for improving applicability and efficiency.
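
The claimed equivalence between a linear inner-loop learner and linear attention can be checked numerically. The sketch below is a minimal illustration under stated assumptions, not the paper's code: the learner $W$ starts at $W_0 = 0$, the inner-loop loss is $\frac{1}{2}\sum_t \lVert W k_t - v_t \rVert^2$, and the inner loop takes one full-batch gradient step with learning rate $\eta$ (`lr`). The names `keys`, `values`, and `query` are hypothetical stand-ins for the projected views of each token, and all function names are illustrative. Under these assumptions $W_1 = \eta \sum_t v_t k_t^\top$, so the prediction $W_1 q = \eta \sum_t v_t (k_t^\top q)$ is exactly unnormalized linear attention.

```python
import random

# Hypothetical helper names; a sketch, not the paper's implementation.

def matvec(W, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]

def ttt_linear_one_step(keys, values, query, lr):
    """Linear inner-loop learner: start from W0 = 0, take one full-batch
    gradient step on L(W) = 1/2 * sum_t ||W k_t - v_t||^2, then predict
    on the query with the updated weights W1."""
    d_out, d_in = len(values[0]), len(keys[0])
    W0 = [[0.0] * d_in for _ in range(d_out)]
    grad = [[0.0] * d_in for _ in range(d_out)]
    for k, v in zip(keys, values):
        residual = [p - t for p, t in zip(matvec(W0, k), v)]
        for i in range(d_out):          # dL/dW accumulates residual * k^T
            for j in range(d_in):
                grad[i][j] += residual[i] * k[j]
    W1 = [[W0[i][j] - lr * grad[i][j] for j in range(d_in)] for i in range(d_out)]
    return matvec(W1, query)

def linear_attention(keys, values, query, scale):
    """Unnormalized linear attention: scale * sum_t v_t * (k_t . query)."""
    out = [0.0] * len(values[0])
    for k, v in zip(keys, values):
        score = sum(a * b for a, b in zip(k, query))
        for i in range(len(out)):
            out[i] += scale * score * v[i]
    return out

rng = random.Random(0)
d, n = 4, 6
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
query = [rng.gauss(0, 1) for _ in range(d)]

ttt_out = ttt_linear_one_step(keys, values, query, lr=0.1)
attn_out = linear_attention(keys, values, query, scale=0.1)
# The two outputs differ only by floating-point rounding.
print(max(abs(x - y) for x, y in zip(ttt_out, attn_out)))
```

The zero initialization is what makes the match exact: with $W_0 = 0$ every residual is just $-v_t$, so the gradient step writes the key-value outer products directly into $W_1$.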

Abstract

We reformulate the problem of supervised learning as learning to learn with two nested loops (i.e., learning problems). The inner loop learns on each individual instance with self-supervision before final prediction. The outer loop learns the self-supervised task used by the inner loop, such that its final prediction improves. Our inner loop turns out to be equivalent to linear attention when the inner-loop learner is only a linear model, and to self-attention when it is a kernel estimator. For practical comparison with linear or self-attention layers, we replace each of them in a transformer with an inner loop, so our outer loop is equivalent to training the architecture. When each inner-loop learner is a neural network, our approach vastly outperforms transformers with linear attention on ImageNet from 224 × 224 raw pixels in both accuracy and FLOPs, while (regular) transformers cannot run.
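
The statement that a kernel-estimator inner loop is equivalent to self-attention can be illustrated with a small numerical check. This is a minimal sketch under assumptions: it uses the Nadaraya-Watson kernel estimator with the exponential kernel $\kappa(q, k) = \exp(q^\top k / \sqrt{d})$, a hypothetical choice made here for illustration since this summary does not state the paper's exact kernel. Normalizing the kernel weights over all stored keys is exactly a softmax over the scaled scores, so the estimator's kernel-weighted average of the values coincides with single-query softmax attention. Function names are illustrative.

```python
import math
import random

# Hypothetical helper names; a sketch, not the paper's implementation.

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def nadaraya_watson(keys, values, query, kernel):
    """Kernel estimator: kernel-weighted average of the stored values."""
    weights = [kernel(query, k) for k in keys]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]

def softmax_attention(keys, values, query):
    """Single-query scaled dot-product (softmax) attention."""
    scale = 1.0 / math.sqrt(len(query))
    scores = [scale * dot(query, k) for k in keys]
    m = max(scores)                       # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    z = sum(exps)
    return [sum(e / z * v[i] for e, v in zip(exps, values))
            for i in range(len(values[0]))]

def exp_kernel(q, k):
    """Assumed kernel choice: exp(q . k / sqrt(d))."""
    return math.exp(dot(q, k) / math.sqrt(len(q)))

rng = random.Random(1)
d, n = 4, 6
keys = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
values = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
query = [rng.gauss(0, 1) for _ in range(d)]

nw_out = nadaraya_watson(keys, values, query, exp_kernel)
sa_out = softmax_attention(keys, values, query)
# The two outputs differ only by floating-point rounding.
print(max(abs(x - y) for x, y in zip(nw_out, sa_out)))
```

Unlike the linear case, nothing is trained here: a kernel estimator simply stores the key-value pairs, which is why this variant inherits the quadratic cost of self-attention.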

Paper Structure

This paper contains 30 sections, 22 equations, 4 figures, and 4 tables.

Figures (4)

  • Figure 1: More inner-loop steps improve accuracy up to $T=4$ (left). Behavior of inner-loop loss mirrors regular (non-meta) learning (right).
  • Figure 2: Illustration of Decoder Layer Norm (LN), presented in the experiments section. This diagram shows the first half of a transformer block, omitting the second half, which does not contain an attention layer. The input embedding is $z$ and the output is $\hat{z}$. The identity mapping is at the top, and the residual is learned at the bottom. The dotted line from $W_0$ to $W_1$ represents an inner-loop gradient step. Here we use $T=1$, i.e., only one step in the inner loop, so the final prediction is made with $W_1$. Standard design: only use the blue LN. The output of decoder $g$, in this case $z_0$, is expected to reconstruct $x$, causing a "type mismatch". Our design: also use the red LN. Now $\hat{x}$ is expected to reconstruct $x$, and both are outputs of LN.
  • Figure 3: Inner-loop loss across the 12 TTT layers. Behavior across layers is roughly the same as in Figure 1. Method: MTTT-MLP performing full-batch gradient descent in the inner loop, $T=4$. Setting: ImageNet from patches (see the corresponding subsection).
  • Figure 4: Inner-loop loss across the 12 TTT layers. Behavior across layers is roughly the same as in Figure 1. Method: MTTT-MLP performing stochastic gradient descent in the inner loop, $T=4$. Setting: ImageNet from pixels (see the corresponding subsection).
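
The captions above describe inner loops that take $T$ gradient steps, using either full-batch gradient descent or stochastic gradient descent, and track the inner-loop loss. The following is a toy sketch of such a loop, not the paper's MTTT-MLP. It assumes a hypothetical two-layer tanh MLP learner (hidden size 8), a denoising reconstruction task (predict $x$ from a noisy copy of $x$), a fixed learning rate, and no decoder layer norm. All function names are illustrative.

```python
import math
import random

# Hypothetical toy inner loop; a sketch, not the paper's MTTT-MLP.

def init_mlp(d_in, d_hidden, d_out, rng, scale=0.3):
    W1 = [[rng.gauss(0, scale) for _ in range(d_in)] for _ in range(d_hidden)]
    W2 = [[rng.gauss(0, scale) for _ in range(d_hidden)] for _ in range(d_out)]
    return W1, W2

def forward(params, k):
    W1, W2 = params
    h = [math.tanh(sum(w * x for w, x in zip(row, k))) for row in W1]
    out = [sum(w * x for w, x in zip(row, h)) for row in W2]
    return h, out

def loss_and_grads(params, batch):
    """Mean of 1/2 * ||MLP(k) - v||^2 over (k, v) pairs, with gradients."""
    W1, W2 = params
    g1 = [[0.0] * len(W1[0]) for _ in W1]
    g2 = [[0.0] * len(W2[0]) for _ in W2]
    total, n = 0.0, len(batch)
    for k, v in batch:
        h, out = forward(params, k)
        r = [o - t for o, t in zip(out, v)]           # residual
        total += 0.5 * sum(x * x for x in r)
        for i, ri in enumerate(r):                     # dL/dW2 = r h^T
            for j, hj in enumerate(h):
                g2[i][j] += ri * hj / n
        for j, hj in enumerate(h):                     # backprop through tanh
            dh = sum(W2[i][j] * r[i] for i in range(len(r)))
            dpre = dh * (1.0 - hj * hj)
            for m, km in enumerate(k):
                g1[j][m] += dpre * km / n
    return total / n, (g1, g2)

def gradient_step(params, grads, lr):
    new_params = []
    for W, G in zip(params, grads):
        new_params.append([[w - lr * g for w, g in zip(wr, gr)]
                           for wr, gr in zip(W, G)])
    return tuple(new_params)

def inner_loop(params, pairs, T, lr, batch_size=None):
    """Run T inner-loop steps. batch_size=None means full-batch GD;
    otherwise cycle through fixed minibatches (a simple SGD stand-in).
    Returns final params and the full-set loss before/after each step."""
    losses = [loss_and_grads(params, pairs)[0]]
    for step in range(T):
        if batch_size is None:
            batch = pairs
        else:
            start = (step * batch_size) % len(pairs)
            batch = pairs[start:start + batch_size]
        _, grads = loss_and_grads(params, batch)
        params = gradient_step(params, grads, lr)
        losses.append(loss_and_grads(params, pairs)[0])
    return params, losses

rng = random.Random(0)
d, n = 4, 16
xs = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(n)]
# Assumed reconstruction task: input is a noisy view of x, target is x.
pairs = [([xi + 0.3 * rng.gauss(0, 1) for xi in x], x) for x in xs]
params = init_mlp(d, 8, d, rng)
_, gd_losses = inner_loop(params, pairs, T=4, lr=0.05)
_, sgd_losses = inner_loop(params, pairs, T=4, lr=0.05, batch_size=4)
print([round(loss, 4) for loss in gd_losses])
print([round(loss, 4) for loss in sgd_losses])
```

Each run returns $T + 1$ loss values measured on all pairs, which is the kind of per-step curve the captions describe for each TTT layer. Full-batch gradient descent uses every pair at each step, while the minibatch variant touches only a slice of the pairs per step, trading a noisier loss curve for cheaper steps.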