Looped Transformers are Better at Learning Learning Algorithms

Liu Yang, Kangwook Lee, Robert Nowak, Dimitris Papailiopoulos

TL;DR

The paper addresses the gap between standard transformers and the iterative nature of classical learning algorithms by introducing a looped transformer, which applies the same parameter-shared block repeatedly so that its passes emulate a fixed-point iteration. With a training strategy that injects the input prompt at every loop iteration and computes the loss over a truncated window of loop iterations, the model achieves in-context learning performance competitive with a standard transformer while using less than 10% of its parameters. Across linear, sparse linear, decision-tree, and neural-network function classes, the looped transformer matches or surpasses the standard transformer and shows favorable sample efficiency, inductive biases, and robustness under certain out-of-distribution conditions. The work highlights the practical potential of looped architectures for efficient, iteration-aware in-context learning, and discusses mathematical implications, trade-offs, and future directions such as adaptive looping and generalization beyond the training distribution.
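The looping and truncated-loss ideas can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the authors' implementation: `looped_forward`, `truncated_loss`, and `toy_block` are hypothetical names, the shared block is a toy contraction rather than a transformer, input injection is assumed to add the prompt to the hidden state before each pass, and the loss is assumed to average over the last `b` of `T` loop iterations.

```python
import numpy as np


def looped_forward(block, prompt, num_loops, inject_input=True):
    """Apply one shared block num_loops times and return every intermediate output."""
    # With injection (assumed form): start from zeros and re-add the prompt before
    # each pass. Without injection: the prompt is seen only once, as the initial state.
    y = np.zeros_like(prompt) if inject_input else prompt.copy()
    outputs = []
    for _ in range(num_loops):
        y = block(y + prompt) if inject_input else block(y)
        outputs.append(y)
    return outputs


def truncated_loss(outputs, target, b):
    """Mean squared error averaged over only the last b loop iterations."""
    start = max(len(outputs) - b, 0)
    losses = [np.mean((out - target) ** 2) for out in outputs[start:]]
    return float(np.mean(losses))


def toy_block(z):
    """Hypothetical stand-in for a transformer block: a simple contraction."""
    return 0.5 * z


prompt = np.ones((3, 4))
with_inj = looped_forward(toy_block, prompt, num_loops=50)
without_inj = looped_forward(toy_block, prompt, num_loops=50, inject_input=False)
print(truncated_loss(with_inj, prompt, b=5), truncated_loss(without_inj, prompt, b=5))
```

With this toy block, injection drives the iterates toward a fixed point that still depends on the prompt, whereas without injection the prompt's information decays with every pass, loosely mirroring the degradation described for Figure 3.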

Abstract

Transformers have proven effective at solving data-fitting problems in context for a variety of (latent) models, as reported by Garg et al. However, the transformer architecture lacks an inherent iterative structure, which makes it challenging to emulate the iterative algorithms commonly employed in traditional machine learning. To address this, we propose the looped transformer architecture and an associated training methodology, with the aim of incorporating iterative characteristics into transformer architectures. Experimental results suggest that the looped transformer achieves performance comparable to the standard transformer on various data-fitting problems while using less than 10% of the parameter count.
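As an example of the kind of iterative algorithm the abstract refers to, gradient descent on the least-squares objective applies one fixed update map at every step, and its fixed point is the least-squares solution. The sketch below (function name hypothetical, data synthetic) checks that the iterates converge to the closed-form solution from `np.linalg.lstsq`; it illustrates the target behavior only and makes no claim about what the looped transformer learns internally.

```python
import numpy as np


def gradient_descent_lstsq(X, y, step_size, num_iters):
    """Gradient descent on the loss (0.5 / n) * ||X w - y||^2, starting from w = 0."""
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(num_iters):
        grad = X.T @ (X @ w - y) / n
        w = w - step_size * grad  # the same update map is applied at every iteration
    return w


rng = np.random.default_rng(0)
X = rng.standard_normal((40, 5))
w_true = rng.standard_normal(5)
y = X @ w_true  # noiseless linear targets

# Step size 1 / lambda_max(X^T X / n) keeps the iteration a contraction.
step = X.shape[0] / np.linalg.norm(X, 2) ** 2
w_gd = gradient_descent_lstsq(X, y, step_size=step, num_iters=2000)
w_ls = np.linalg.lstsq(X, y, rcond=None)[0]
print(np.max(np.abs(w_gd - w_ls)))
```

Running more iterations leaves `w_gd` essentially unchanged once it reaches the fixed point, which is the property the looped transformer is meant to reproduce across loop iterations.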

Paper Structure (48 sections, 5 equations, 17 figures, 2 tables)

Figures (17)

  • Figure 1: How can a transformer be trained to learn an iterative learning algorithm? Here we consider the task of training a transformer to solve linear regression in context. The prompt $({\bm{x}}_1, y_1, {\bm{x}}_2, y_2, \cdots, {\bm{x}}_k, y_k, {\bm{x}}_{\text{test}})$ is fed into a decoder transformer. The objective is to minimize the squared loss between the prediction $\hat{y}_{\text{test}}$ made from this prompt and the target value $f({\bm{x}}_{\text{test}})$. Garg et al. (2022) demonstrated that a decoder transformer can learn to solve linear regression, possibly by learning to approximate the least-squares solution. In this study, we aim to train a transformer to learn iterative learning algorithms, with the goal of matching the performance of standard transformers while using fewer parameters. To this end, we introduce the looped transformer architecture and its accompanying training methodology.
  • Figure 2: The looped transformer can emulate iterative learning algorithms, offering performance comparable to standard transformers with fewer parameters. We train a looped transformer to solve linear regression in context. (Left): Although trained with 30 loop iterations, the looped transformer reaches a stable fixed-point solution at inference, even beyond the number of loop iterations used in training. (Right): The trained looped transformer matches the performance of a standard 12-layer transformer and closely tracks the least-squares solver, while using only 1/12 of the standard transformer's parameters.
  • Figure 3: Test error of the looped transformer on linear regression, comparing models trained with the default input injection to models trained without it. Without input injection, performance deteriorates once the number of loop iterations exceeds the number used in training.
  • Figure 4: Evaluation of the looped transformer on in-context learning of linear functions when trained with different values of $b$ and $T$ ($b$ and $T$ are defined in the paper's training-objective equation). From left to right, the panels show models trained with $T = 5, 10, 15$; colors represent different values of $b$ (see legend). Each solid line shows how the looped transformer trained with a given $b$ performs as the number of loop iterations increases during inference, and the dashed line of the same color marks the value of $b$.
  • Figure 5: Performance of the transformer on linear functions with $d=10$ and $k=21$ when trained with different numbers of distinct prompts/functions.
  • ...and 12 more figures
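The in-context linear regression setup of Figure 1 can be sketched as follows, using $d=10$ and $k=21$ as in Figure 5. Assumptions: inputs and weights are drawn from a standard Gaussian, each scalar $y_i$ is zero-padded into a $d$-dimensional token, and `build_prompt` and `least_squares_prediction` are hypothetical helper names. The least-squares prediction plays the role of the reference solver that Figure 2 compares against.

```python
import numpy as np


def build_prompt(xs, ys, x_test):
    """Interleave (x_1, y_1, ..., x_k, y_k, x_test) into a (2k + 1, d) token array.

    Each scalar y_i is zero-padded to a d-dimensional token (an assumed encoding).
    """
    k, d = xs.shape
    tokens = np.zeros((2 * k + 1, d))
    tokens[0:2 * k:2] = xs     # even positions hold the inputs x_i
    tokens[1:2 * k:2, 0] = ys  # odd positions hold the labels y_i in coordinate 0
    tokens[2 * k] = x_test     # the final token is the query x_test
    return tokens


def least_squares_prediction(xs, ys, x_test):
    """Reference solver: fit w by least squares on the in-context examples, then predict."""
    w_hat = np.linalg.lstsq(xs, ys, rcond=None)[0]
    return float(x_test @ w_hat)


rng = np.random.default_rng(0)
d, k = 10, 21                   # dimensions used in Figure 5
w = rng.standard_normal(d)      # a random linear function f(x) = w^T x
xs = rng.standard_normal((k, d))
ys = xs @ w
x_test = rng.standard_normal(d)

prompt = build_prompt(xs, ys, x_test)
ls_pred = least_squares_prediction(xs, ys, x_test)
print(prompt.shape)             # (43, 10)
print(ls_pred, float(x_test @ w))
```

Because the data are noiseless and $k > d$, the least-squares prediction matches $f({\bm{x}}_{\text{test}})$ up to floating-point error; a trained model is judged by how closely its $\hat{y}_{\text{test}}$ approaches this baseline.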