PyLO: Towards Accessible Learned Optimizers in PyTorch

Paul Janson, Benjamin Therien, Quentin Anthony, Xiaolong Huang, Abhinav Moudgil, Eugene Belilovsky

TL;DR

PyLO addresses the practical barriers to deploying learned optimizers (LOs) by delivering a PyTorch-friendly, CUDA-accelerated library with HuggingFace Hub integration and a decoupled design that separates meta-training from deployment. Its CUDA implementation, built on two fused kernels, sharply reduces LO overhead and scales to large models, enabling competitive performance on Vision Transformer and GPT-2 pretraining tasks. Benchmarks show substantial speedups over naive implementations and favorable comparisons to JAX baselines, while distributed optimizer steps further reduce overhead. The work demonstrates real-world applicability and provides a foundation for community-driven advancement of learned optimizers in modern ML systems.

Abstract

Learned optimizers have been an active research topic over the past decade, with increasing progress toward practical, general-purpose optimizers that can serve as drop-in replacements for widely used methods like Adam. However, recent advances -- such as VeLO, which was meta-trained for 4000 TPU-months -- remain largely inaccessible to the broader community, in part due to their reliance on JAX and the absence of user-friendly packages for applying the optimizers after meta-training. To address this gap, we introduce PyLO, a PyTorch-based library that brings learned optimizers to the broader machine learning community through familiar, widely adopted workflows. Unlike prior work focused on synthetic or convex tasks, our emphasis is on applying learned optimization to real-world large-scale pre-training tasks. Our release includes a CUDA-accelerated version of the small_fc_lopt learned optimizer architecture of Metz et al. (2022a), delivering substantial speedups -- from 39.36 to 205.59 samples/sec throughput for training ViT B/16 with batch size 32. PyLO also allows us to easily combine learned optimizers with existing optimization tools such as learning rate schedules and weight decay. When doing so, we find that learned optimizers can benefit substantially from these tools. Our code is available at https://github.com/Belilovsky-Lab/pylo
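To put the throughput figures quoted in the abstract in perspective, the short calculation below derives the implied speedup and the implied wall-clock time per training step. It is only arithmetic on the two reported numbers; it assumes that 39.36 samples/sec refers to the non-accelerated implementation and that throughput covers the full training step (forward, backward, and optimizer). The variable names are illustrative.

```python
# Throughput figures quoted in the abstract (ViT B/16, batch size 32).
baseline_sps = 39.36   # samples/sec, assumed to be the non-accelerated implementation
cuda_sps = 205.59      # samples/sec with the CUDA-accelerated optimizer
batch_size = 32

# Ratio of throughputs gives the end-to-end speedup.
speedup = cuda_sps / baseline_sps

# Seconds per training step implied by each throughput.
baseline_step_s = batch_size / baseline_sps
cuda_step_s = batch_size / cuda_sps

print(f"speedup: {speedup:.2f}x")
print(f"step time: {baseline_step_s:.3f} s -> {cuda_step_s:.3f} s")
```

Under these assumptions the reported numbers correspond to a bit more than a fivefold increase in end-to-end training throughput.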

Paper Structure

This paper contains 24 sections, 1 equation, 17 figures, 11 tables.

Figures (17)

  • Figure 1: Training step timing breakdown for ViT-B/16 on a single A100 GPU. We measure the forward, backward, and optimizer times per training step. We observe that the CUDA-accelerated learned optimizer steps (cyan, green) show substantial improvements over the naive implementations (yellow, red). In all cases, as the batch size is increased, the relative overhead of the optimizer shrinks.
  • Figure 2: PyLO simplifies the integration of learned optimizers into standard machine learning workflows. By addressing key usability challenges, PyLO provides seamless access to meta-learned optimization techniques through three core features: (1) automatic weight loading from Hugging Face Hub, (2) a familiar PyTorch-style optimizer interface, and (3) accelerated CUDA kernel support. The library bridges the gap between advanced meta-learning research (google/learned_optimization; Metz et al., 2022) and practical machine-learning applications, enabling researchers and practitioners to easily leverage state-of-the-art learned optimization techniques in PyTorch.
  • Figure 3: Overview of the Learned Optimizer (LO) update mechanism. The figure depicts the computation pipeline used by learned optimizers such as small_fc_lopt and VeLO. Model parameters ($\mathbf{\theta}$), gradients ($\mathbf{g}$), momentum ($\mathbf{m}$), and second-moment accumulators ($\mathbf{v}$, together with the row and column factors of $\mathbf{v}$) are combined within construct_features() to form a rich set of derivative features, including elementwise interactions such as $\mathbf{g} \times \text{momentum}$ and $\text{momentum} \times \text{factors}$ (see the appendix for a complete description). These features are reduced over the parameter dimensions ($m \times n$) in compute_squared_average() to obtain normalization statistics. In feature_normalization(), features are normalized before being processed by the Learned Optimizer, which predicts an update direction and magnitude in apply_lo(), and this predicted update is then applied to yield the updated parameters.
  • Figure 4: Comparison of CUDA kernel execution characteristics for the learned optimizers small_fc_lopt and VeLO, shown for both naïve and fused implementations applied to a single layer MLP (shape [1000x1000]) with no bias. (Left) The total cumulative kernel execution time (blue) and the number of kernel launches (orange) reveal that the naïve implementations spend substantially more time executing a far greater number of small kernels, due to performing each arithmetic or reduction step as an independent operation. In contrast, the fused CUDA versions combine these operations into fewer, more computationally dense kernels, reducing both launch overhead and memory bandwidth pressure. (Right) Breakdown of cumulative kernel time by kernel type for each optimizer shows that the fused versions (bolded) consolidate many repetitive operations into specialized fused kernels. This pattern is consistent across both VeLO and small_fc_lopt, demonstrating that kernel fusion is broadly beneficial for improving learned optimizer efficiency.
  • Figure 5: Step Time Scaling of Learned Optimizers. We present a comparison of optimizer step time between our custom CUDA implementation and the original JAX versions of small_fc_lopt and VeLO, evaluated during the training of GPT-2 style transformer models across a range of model sizes. The results show that our CUDA implementation not only achieves substantially lower step times but also maintains this advantage as model size increases, enabling more efficient scaling to larger architectures.
  • ...and 12 more figures
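
The pipeline described in the Figure 3 caption can be sketched in plain Python. This is a minimal illustration, not PyLO's implementation: it reuses the four stage names from the caption, but the feature set (six features per parameter), the accumulator decay rates, the tiny ReLU MLP, and the step_mult / exp_mult scaling constants are all assumptions chosen for readability. It also works on a flat parameter list and omits the factored row and column accumulators.

```python
import math
import random

EPS = 1e-8

def construct_features(theta, g, m, v):
    """Per-parameter features, including elementwise products (assumed subset)."""
    feats = []
    for t, gi, mi, vi in zip(theta, g, m, v):
        rsqrt_v = 1.0 / math.sqrt(vi + EPS)
        feats.append([t, gi, mi, rsqrt_v, gi * mi, mi * rsqrt_v])
    return feats

def compute_squared_average(feats):
    """Mean of squares of each feature, reduced over all parameters."""
    n = len(feats)
    return [sum(row[j] ** 2 for row in feats) / n for j in range(len(feats[0]))]

def feature_normalization(feats, sq_avg):
    """Scale each feature by the inverse root-mean-square of its column."""
    scale = [1.0 / math.sqrt(s + EPS) for s in sq_avg]
    return [[x * s for x, s in zip(row, scale)] for row in feats]

def apply_lo(feats, w1, b1, w2, b2):
    """Small per-parameter MLP that outputs (direction, magnitude)."""
    updates = []
    for row in feats:
        hidden = [max(0.0, sum(x * w for x, w in zip(row, ws)) + b)
                  for ws, b in zip(w1, b1)]
        direction, magnitude = [sum(h * w for h, w in zip(hidden, ws)) + b
                                for ws, b in zip(w2, b2)]
        updates.append((direction, magnitude))
    return updates

def lo_step(theta, g, m, v, mlp, beta1=0.9, beta2=0.999,
            step_mult=1e-3, exp_mult=1e-3):
    """One learned-optimizer step; decay rates and multipliers are assumed values."""
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, g)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, g)]
    feats = construct_features(theta, g, m, v)
    normed = feature_normalization(feats, compute_squared_average(feats))
    out = apply_lo(normed, *mlp)
    # Update = direction scaled by an exponential of the predicted magnitude.
    new_theta = [t - step_mult * d * math.exp(exp_mult * mag)
                 for t, (d, mag) in zip(theta, out)]
    return new_theta, m, v

# Usage with random stand-ins for parameters, gradients, and MLP weights.
random.seed(0)
n, n_feat, n_hidden = 8, 6, 4
theta = [random.gauss(0, 1) for _ in range(n)]
g = [random.gauss(0, 1) for _ in range(n)]
m, v = [0.0] * n, [0.0] * n
mlp = (
    [[random.gauss(0, 0.5) for _ in range(n_feat)] for _ in range(n_hidden)],
    [0.0] * n_hidden,
    [[random.gauss(0, 0.5) for _ in range(n_hidden)] for _ in range(2)],
    [0.0, 0.0],
)
new_theta, m, v = lo_step(theta, g, m, v, mlp)
```

Each stage in this sketch makes its own pass over the parameters, so it produces many small operations per step. This is the pattern that the caption of Figure 4 attributes to the naive implementations, and that the fused CUDA kernels consolidate.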