A Trainable Optimizer
Ruiqi Wang, Diego Klabjan
TL;DR
This work introduces the Trainable Optimizer (TO), a framework that co-trains a gradient estimator together with the model weights to achieve variance reduction and SGD-like convergence. The gradient estimator is modeled as a pseudo-linear function $\widehat{G}_t = A_t w_t + b_t$, and $A_t, b_t$ are updated via a squared-error loss. Under this model, the authors prove $\mathcal{O}(1/t)$ convergence for strongly convex losses, with the variance of the gradient approximation vanishing at the same rate. In the non-convex setting, the gradient-approximation error decays at a $\mathcal{O}(1/\log T)$ rate. To scale to large models, Diagonal TO and Rank-One TO reduce memory to $\mathcal{O}(d)$ while preserving convergence, and momentum emerges as a special case of Full-TO. Empirically, TO variants outperform ADAM on strongly convex problems. On challenging non-convex tasks, including ResNet training and LLM fine-tuning, they converge faster and reach better final performance, which indicates practical value for accelerating optimization.
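One simple way to see how momentum can arise from the pseudo-linear estimator is the following illustration. It is our own sketch under stated assumptions and not necessarily the derivation used in the paper. Freeze $A_t \equiv 0$, let $g_t$ denote the stochastic gradient at $w_t$, and update $b_t$ by one gradient step of a hypothetical size $\beta \in (0,1)$ on the squared error $\tfrac12\lVert b_t - g_t\rVert^2$. With these choices, the estimator becomes an exponential moving average of past stochastic gradients. The resulting weight update is the familiar momentum update with learning rate $\eta$:

```latex
\begin{aligned}
\ell_t(b) &= \tfrac12 \lVert b - g_t \rVert^2, \qquad \nabla_b \ell_t(b) = b - g_t,\\
b_{t+1} &= b_t - \beta\,(b_t - g_t) = (1-\beta)\,b_t + \beta\, g_t,\\
w_{t+1} &= w_t - \eta\, b_{t+1},\\
b_{t+1} &= (1-\beta)^{t+1} b_0 + \beta \sum_{s=0}^{t} (1-\beta)^{t-s} g_s .
\end{aligned}
```

The last line unrolls the recursion. Each past gradient $g_s$ receives weight $\beta(1-\beta)^{t-s}$, and a nonzero $A_t$ adds a state-dependent correction $A_t w_t$ on top of this average.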
Abstract
Learning to optimize uses a trainable optimization strategy instead of relying on manually designed full-gradient estimators such as ADAM. We present a framework that jointly trains the full-gradient estimator and the trainable weights of the model. Specifically, we prove that pseudo-linear TO (Trainable Optimizer), a linear approximation of the full gradient, matches the convergence rate of SGD while effectively reducing variance. Pseudo-linear TO incurs negligible computational overhead, requiring only a few additional tensor multiplications. To further improve computational efficiency, we introduce two simplified variants of pseudo-linear TO. Experiments demonstrate that TO methods converge faster than benchmark algorithms (e.g., ADAM) in both strongly convex and non-convex settings, as well as in fine-tuning an LLM.
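The following pure-Python sketch makes the jointly trained pseudo-linear estimator $\widehat{G}_t = A_t w_t + b_t$ concrete. It relies on several assumptions that are ours and not taken from the paper:

- It uses a synthetic least-squares objective, chosen because its full gradient is exactly linear in $w$.
- The minibatch gradient $g_t$ serves as the regression target.
- The estimator takes a normalized (NLMS-style) gradient step on $\tfrac12\lVert A_t w_t + b_t - g_t\rVert^2$.
- The helper names `make_problem`, `minibatch_grad`, and `train_to` are hypothetical.

The `diagonal=True` branch keeps only the diagonal of $A_t$, which mirrors the $\mathcal{O}(d)$ memory of Diagonal TO. In the full branch, the per-step overhead is one matrix-vector product and one rank-one update. This is an illustration, not the authors' implementation.

```python
import random


def dot(u, v):
    """Inner product of two equal-length lists."""
    return sum(ui * vi for ui, vi in zip(u, v))


def make_problem(n=200, d=5, noise=0.1, seed=0):
    """Hypothetical least-squares task: y = X @ w_true + Gaussian noise."""
    rng = random.Random(seed)
    w_true = [rng.gauss(0.0, 1.0) for _ in range(d)]
    X = [[rng.gauss(0.0, 1.0) for _ in range(d)] for _ in range(n)]
    y = [dot(x, w_true) + noise * rng.gauss(0.0, 1.0) for x in X]
    return X, y


def full_loss(X, y, w):
    """f(w) = (1/2n) * sum_i (x_i . w - y_i)^2; its gradient is linear in w."""
    return sum((dot(x, w) - yi) ** 2 for x, yi in zip(X, y)) / (2 * len(X))


def minibatch_grad(X, y, w, idx):
    """Stochastic gradient g_t of f on the rows listed in idx."""
    g = [0.0] * len(w)
    for i in idx:
        err = dot(X[i], w) - y[i]
        for j in range(len(w)):
            g[j] += err * X[i][j] / len(idx)
    return g


def train_to(X, y, steps=500, lr=0.1, mu=0.2, batch=10, diagonal=False, seed=1):
    """Co-train the estimator (A, b) and the weights w (sketch, not the paper's code)."""
    rng = random.Random(seed)
    d = len(X[0])
    w = [0.0] * d
    b = [0.0] * d
    # Full TO stores a d x d matrix A; the diagonal variant stores only diag(A).
    A = [0.0] * d if diagonal else [[0.0] * d for _ in range(d)]
    for _ in range(steps):
        g = minibatch_grad(X, y, w, rng.sample(range(len(X)), batch))
        if diagonal:
            # G_hat_j = a_j * w_j + b_j: one two-parameter regression per coordinate.
            for j in range(d):
                r = A[j] * w[j] + b[j] - g[j]
                denom = w[j] ** 2 + 1.0
                A[j] -= mu * r * w[j] / denom
                b[j] -= mu * r / denom
            g_hat = [A[j] * w[j] + b[j] for j in range(d)]
        else:
            # Residual of G_hat = A w + b against g_t, then a normalized gradient
            # step on 0.5 * ||A w + b - g_t||^2 (a rank-one update of A).
            r = [dot(A[j], w) + b[j] - g[j] for j in range(d)]
            denom = dot(w, w) + 1.0
            for j in range(d):
                for k in range(d):
                    A[j][k] -= mu * r[j] * w[k] / denom
                b[j] -= mu * r[j] / denom
            g_hat = [dot(A[j], w) + b[j] for j in range(d)]
        # The model step uses the trained estimator, not the raw minibatch gradient.
        w = [wj - lr * gj for wj, gj in zip(w, g_hat)]
    return w


if __name__ == "__main__":
    X, y = make_problem()
    start = full_loss(X, y, [0.0] * len(X[0]))
    for diagonal in (False, True):
        w = train_to(X, y, diagonal=diagonal)
        print("diagonal" if diagonal else "full", full_loss(X, y, w) / start)
```

We chose the normalized estimator step here for stability. After each update, the residual $A_t w_t + b_t - g_t$ at the current point shrinks by exactly a factor $(1-\mu)$. As a result, $\widehat{G}_t$ blends the previous prediction with the new minibatch gradient, much like momentum, while $A_t$ extrapolates how the gradient changes as $w$ moves.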
