Toward generalizable learning of all (linear) first-order methods via memory augmented Transformers
Sanchayan Dutta, Suvrit Sra
TL;DR
This work demonstrates that memory-augmented Transformers (Memformers) can implement and learn the full class of linear first-order methods (LFOMs), including gradient descent, momentum methods, and conjugate gradient descent, by maintaining and updating memory across layers. The authors provide a theoretical framework showing how Memformers can realize LFOM iterations and CGD-like updates. They validate these claims with experiments on random linear regression tasks, where Memformers achieve competitive or superior performance across a range of preconditioning and architectural setups. They further improve robustness and adaptation through multi-head attention and a mixture-of-experts approach, which enables test-time adaptation to out-of-distribution (OOD) samples, and they show that LFOMs can themselves be treated as learnable meta-algorithms with generalization guarantees. Overall, the paper positions Memformers as versatile, learnable optimizers that capture a wide range of first-order methods and adapt to changing data distributions, with potential implications for accelerating optimization in large-scale models and for in-context learning.
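To make the LFOM class concrete, the following is a minimal NumPy sketch, not the authors' code. It assumes one common way of writing an LFOM with scalar coefficients, x_{k+1} = x_k + sum_{j=0..k} gamma[k][j] * grad f(x_j), so every step is a linear combination of all past gradients (matrix-valued coefficients would be a natural generalization). The helper names `lfom` and `heavy_ball`, the random least-squares objective, and the values of `eta`, `beta`, and `K` are illustrative assumptions. The sketch checks that gradient descent and heavy-ball momentum are both LFOMs: GD puts weight -eta on the newest gradient only, while unrolling momentum gives gamma[k][j] = -eta * beta**(k - j).

```python
import numpy as np

def lfom(grad, x0, coeffs):
    """Run a linear first-order method (assumed scalar-coefficient form).

    At step k, x_{k+1} = x_k + sum_j coeffs[k][j] * grad(x_j), where
    coeffs[k] holds one scalar per gradient seen so far (k + 1 of them).
    """
    x = np.asarray(x0, dtype=float)
    grads = []
    for c_k in coeffs:
        grads.append(grad(x))                      # store the gradient at x_k
        x = x + sum(c * g for c, g in zip(c_k, grads))
    return x

def heavy_ball(grad, x0, eta, beta, num_steps):
    """Classical heavy-ball momentum, written recursively for comparison."""
    x_prev = x = np.asarray(x0, dtype=float)
    for _ in range(num_steps):
        # The right-hand side uses the old x and x_prev before reassignment.
        x, x_prev = x - eta * grad(x) + beta * (x - x_prev), x
    return x

# Hypothetical random linear regression task: f(w) = 0.5 * mean((X w - y)^2).
rng = np.random.default_rng(0)
n, d = 20, 5
X = rng.standard_normal((n, d))
y = X @ rng.standard_normal(d)

def loss(w):
    return 0.5 * np.mean((X @ w - y) ** 2)

def grad(w):
    return X.T @ (X @ w - y) / n

eta, beta, K = 0.1, 0.5, 10
x0 = np.zeros(d)

# GD: weight -eta on the newest gradient, zero on all older ones.
gd_coeffs = [[0.0] * k + [-eta] for k in range(K)]
# Heavy-ball momentum unrolled into LFOM coefficients.
hb_coeffs = [[-eta * beta ** (k - j) for j in range(k + 1)] for k in range(K)]

x_gd = lfom(grad, x0, gd_coeffs)
x_hb = lfom(grad, x0, hb_coeffs)
print("loss at x0:", loss(x0))
print("GD as LFOM:", loss(x_gd))
print("momentum as LFOM:", loss(x_hb))
print("matches recursive heavy ball:", np.allclose(x_hb, heavy_ball(grad, x0, eta, beta, K)))
```

Storing every past gradient is what the "memory" buys: a method that keeps only the current iterate can express GD, but expressing arbitrary coefficients on older gradients requires keeping them around.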
Abstract
We show that memory-augmented Transformers can implement the entire class of linear first-order methods (LFOMs), a class that contains gradient descent (GD) and more advanced methods such as conjugate gradient descent (CGD), momentum methods, and all other variants that linearly combine past gradients. Building on prior work that studies how Transformers simulate GD, we provide theoretical and empirical evidence that memory-augmented Transformers can learn more advanced algorithms. We then take a first step toward turning the learned algorithms into practically usable methods by developing a mixture-of-experts (MoE) approach for test-time adaptation to out-of-distribution (OOD) samples. Lastly, we show that LFOMs can themselves be treated as learnable algorithms, whose parameters can be learned from data to attain strong performance.
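The last point, that LFOM parameters can be learned from data, can be illustrated with a deliberately simple sketch. This is not the authors' training procedure: it swaps in plain grid search over a two-parameter momentum-style LFOM (step size `eta`, decay `beta`) and fits them to a batch of random linear regression tasks. The task sizes, grids, step count, and helper names (`make_task`, `run_lfom`, `momentum_coeffs`, `avg_loss`) are all hypothetical. Because `beta = 0.0` is in the grid, the learned setting can do no worse on the training tasks than the best plain-GD step size from the same grid; held-out performance is reported but not guaranteed to improve.

```python
import itertools
import numpy as np

def make_task(rng, n=20, d=5):
    """Sample a hypothetical random linear regression task (X, y)."""
    X = rng.standard_normal((n, d))
    y = X @ rng.standard_normal(d)
    return X, y

def task_loss(X, y, w):
    return 0.5 * np.mean((X @ w - y) ** 2)

def run_lfom(X, y, coeffs):
    """Apply an LFOM from w = 0; coeffs[k][j] weights the gradient at step j."""
    n, d = X.shape
    w = np.zeros(d)
    grads = []
    for c_k in coeffs:
        grads.append(X.T @ (X @ w - y) / n)
        w = w + sum(c * g for c, g in zip(c_k, grads))
    return w

def momentum_coeffs(eta, beta, num_steps):
    # Momentum-style LFOM: weight -eta * beta**(k - j) on the gradient from step j.
    return [[-eta * beta ** (k - j) for j in range(k + 1)] for k in range(num_steps)]

def avg_loss(tasks, eta, beta, num_steps):
    coeffs = momentum_coeffs(eta, beta, num_steps)
    return float(np.mean([task_loss(X, y, run_lfom(X, y, coeffs)) for X, y in tasks]))

rng = np.random.default_rng(0)
train_tasks = [make_task(rng) for _ in range(32)]
test_tasks = [make_task(rng) for _ in range(32)]
K = 10
etas = [0.05, 0.1, 0.2, 0.3]
betas = [0.0, 0.3, 0.6, 0.9]

# "Learn" (eta, beta) by minimizing the average training loss over the grid.
best_eta, best_beta = min(itertools.product(etas, betas),
                          key=lambda p: avg_loss(train_tasks, p[0], p[1], K))
# Baseline: best plain GD (beta = 0) from the same step-size grid.
best_gd_eta = min(etas, key=lambda e: avg_loss(train_tasks, e, 0.0, K))

train_learned = avg_loss(train_tasks, best_eta, best_beta, K)
train_gd = avg_loss(train_tasks, best_gd_eta, 0.0, K)
test_learned = avg_loss(test_tasks, best_eta, best_beta, K)
test_gd = avg_loss(test_tasks, best_gd_eta, 0.0, K)
print(f"learned LFOM (eta={best_eta}, beta={best_beta}): train {train_learned:.4f}, test {test_learned:.4f}")
print(f"best plain GD (eta={best_gd_eta}): train {train_gd:.4f}, test {test_gd:.4f}")
```

A gradient-based fit of per-step coefficients would be the more natural way to scale this up; grid search is used here only because it keeps the example short and dependency-free.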
