A Trainable Optimizer

Ruiqi Wang, Diego Klabjan

TL;DR

This work introduces Trainable Optimizer (TO), a framework that co-trains the gradient estimator and model weights to realize variance reduction and SGD-like convergence. By modeling the gradient estimator as a pseudo-linear function $\widehat{G}_t = A_t w_t + b_t$ and updating $A_t,b_t$ via a squared-error loss, the authors prove $\mathcal{O}(1/t)$ convergence for strongly convex losses and a vanishing gradient-approximation variance at the same rate; non-convex guarantees yield a $\mathcal{O}(1/\log T)$ rate for the gradient error. To scale to large models, Diagonal TO and Rank-One TO reduce memory to $\mathcal{O}(d)$ while preserving convergence, with momentum emerging as a special case of Full-TO. Empirically, TO variants outperform ADAM on strongly convex problems and show faster convergence and better final performance on challenging non-convex tasks, including ResNet training and LLM fine-tuning, highlighting practical benefits for accelerated optimization.
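The summary says only that $A_t, b_t$ are updated via a squared-error loss. It does not give the exact rule. The following is a minimal NumPy sketch of one Full-TO step under three assumptions of ours:

- the estimator loss is $\tfrac12\|A_t w_t + b_t - g_t\|_2^2$, with $g_t$ the minibatch gradient;
- $(A_t, b_t)$ take one plain SGD step with rate `beta` before the estimate is used;
- the weights then move by `eta` along $\widehat{G}_t$.

The names `full_to_step` and `run_quadratic`, the step sizes, and the toy quadratic are hypothetical illustrations, not the authors' algorithm.

```python
import numpy as np

def full_to_step(w, A, b, g, eta, beta):
    """One hypothetical Full-TO step (assumed rule, not the paper's algorithm).

    Assumed estimator loss: 0.5 * ||A w + b - g||^2, reduced by one SGD step
    on (A, b); the weights then move along the refreshed estimate A w + b.
    """
    r = A @ w + b - g               # residual of the current estimator at w
    A = A - beta * np.outer(r, w)   # gradient of the squared error w.r.t. A is r w^T
    b = b - beta * r                # gradient of the squared error w.r.t. b is r
    g_hat = A @ w + b               # refreshed full-gradient estimate
    w = w - eta * g_hat             # SGD-like weight update driven by g_hat
    return w, A, b, g_hat

def run_quadratic(T=2000, eta=0.05, beta=0.1, sigma=0.01, seed=0):
    """Toy strongly convex problem: f(w) = 0.5 w^T H w - c^T w with noisy gradients."""
    rng = np.random.default_rng(seed)
    H = np.diag([1.0, 2.0, 3.0])    # mu = 1, L = 3
    c = np.ones(3)
    d = c.size
    w, A, b = np.zeros(d), np.zeros((d, d)), np.zeros(d)
    for _ in range(T):
        g = H @ w - c + sigma * rng.standard_normal(d)  # stochastic gradient
        w, A, b, _ = full_to_step(w, A, b, g, eta, beta)
    return w, np.linalg.solve(H, c)  # final iterate and exact minimiser

w_T, w_star = run_quadratic()
print("distance to minimiser:", np.linalg.norm(w_T - w_star))
```

Under these assumptions, the refreshed residual at $w_t$ equals the old residual scaled by $1-\beta(\|w_t\|_2^2+1)$. Choosing $\beta = 1/(\|w_t\|_2^2+1)$ therefore gives $\widehat{G}_t = g_t$, which is plain SGD. A smaller $\beta$ lets the estimate keep information from earlier steps.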

Abstract

The concept of learning to optimize involves utilizing a trainable optimization strategy rather than relying on manually designed full-gradient estimators such as ADAM. We present a framework that jointly trains the full-gradient estimator and the trainable weights of the model. Specifically, we prove that pseudo-linear TO (Trainable Optimizer), a linear approximation of the full gradient, matches SGD's convergence rate while effectively reducing variance. Pseudo-linear TO incurs negligible computational overhead, requiring only a few additional tensor multiplications. To further improve computational efficiency, we introduce two simplified variants of pseudo-linear TO. Experiments demonstrate that TO methods converge faster than benchmark algorithms (e.g., ADAM) in both strongly convex and non-convex settings, as well as when fine-tuning an LLM.
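The TL;DR names the two simplified variants Diagonal TO and Rank-One TO. The sketch below assumes the estimator loss $\tfrac12\|A_t w_t + b_t - g_t\|_2^2$ with one SGD step on its parameters. That rule is our assumption, since the summary does not state it. Under it, restricting $A_t$ to $\operatorname{diag}(a_t)$ or to $u_t v_t^\top$ keeps the optimizer state at $\mathcal{O}(d)$ numbers. Freezing $A_t = 0$ reduces the $b_t$ update to an exponential moving average of gradients. This is one way, under the assumed loss, to see momentum arise as a special case, though the paper's own derivation may differ. All function names here are hypothetical.

```python
import numpy as np

def diagonal_to_step(w, a, b, g, eta, beta):
    """Hypothetical Diagonal-TO step: A = diag(a), so state is a and b (O(d))."""
    r = a * w + b - g              # residual of a*w + b against the sampled gradient
    a = a - beta * r * w           # gradient of 0.5||a*w + b - g||^2 w.r.t. a
    b = b - beta * r               # gradient w.r.t. b
    g_hat = a * w + b
    return w - eta * g_hat, a, b, g_hat

def rank_one_to_step(w, u, v, b, g, eta, beta):
    """Hypothetical Rank-One-TO step: A = u v^T, so state is u, v, b (O(d)).

    u and v must not both start at zero: their gradients would then stay zero.
    """
    s = v @ w                      # scalar v^T w
    r = u * s + b - g              # residual of u (v^T w) + b against g
    u_new = u - beta * r * s       # gradient w.r.t. u is r * (v^T w)
    v_new = v - beta * (u @ r) * w # gradient w.r.t. v is (u^T r) w
    b = b - beta * r
    g_hat = u_new * (v_new @ w) + b
    return w - eta * g_hat, u_new, v_new, b

def bias_only_to_step(w, b, g, eta, beta):
    """TO with A frozen at 0: the estimate is just b."""
    r = b - g
    b = b - beta * r               # equals (1 - beta) * b + beta * g
    return w - eta * b, b

def ema_momentum_step(w, m, g, eta, beta):
    """Exponential-moving-average momentum, for comparison."""
    m = (1.0 - beta) * m + beta * g
    return w - eta * m, m
```

For the diagonal variant, the refreshed residual in coordinate $i$ is the old one scaled by $1-\beta(w_{t,i}^2+1)$. For the bias-only variant, $b_{t+1} = (1-\beta)\,b_t + \beta g_t$, which matches the EMA momentum buffer step for step.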

Paper Structure

This paper contains 23 sections, 6 theorems, 55 equations, 6 figures, 7 tables, 3 algorithms.

Key Result

Proposition 3.2

Given $\|w_t\|\leq D_w$, there exists $0<D_G<+\infty$ such that for all $t$, $\|g_t\|_2 \leq D_G$.
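The following worked bound illustrates Proposition 3.2 in one concrete case. It rests on an assumption of ours, not the paper's: $g_t$ is the gradient of an $\ell_2$-regularized logistic loss on a sample $(x_i, y_i)$ with $y_i \in \{-1,+1\}$ and regularization weight $\lambda \ge 0$. It uses the sigmoid $\sigma(z) = 1/(1+e^{-z}) \in (0,1)$.

```latex
\ell_i(w) = \log\bigl(1 + e^{-y_i x_i^\top w}\bigr) + \tfrac{\lambda}{2}\|w\|_2^2,
\qquad
\nabla \ell_i(w) = -\,\sigma\bigl(-y_i x_i^\top w\bigr)\, y_i x_i + \lambda w,

\|\nabla \ell_i(w_t)\|_2 \;\le\; \|x_i\|_2 + \lambda \|w_t\|_2
\;\le\; \max_i \|x_i\|_2 + \lambda D_w \;=:\; D_G .
```

With $\lambda = 0$, as in the logistic-regression runs of Figures 1, 4 and 5, this particular bound reduces to $\max_i \|x_i\|_2$ and does not need $D_w$ at all.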

Figures (6)

  • Figure 1: Training loss curves for logistic regression ($\lambda=0$) on different datasets
  • Figure 2: Training loss curves for non-convex tasks
  • Figure 3: Validation loss curve for Llama7b-alpaca task
  • Figure 4: Training loss curves for logistic regression tasks with $\lambda=0$. Standardization is based on minimum=0.96 and maximum=1.31 in the left figure, and minimum=2.18 and maximum=2.99 in the right figure.
  • Figure 5: Standard deviation of training loss for logistic regression tasks with $\lambda=0$. Standardization is based on minimum=$1.31\times10^{-3}$ and maximum=$4.86\times 10^{-2}$ in panel (a), minimum=$2.78\times 10^{-4}$ and maximum=$1.36\times 10^{-2}$ in panel (b), minimum=$1.00\times 10^{-2}$ and maximum=$2.11\times 10^{-2}$ in panel (c), and minimum=$1.48\times 10^{-4}$ and maximum=$6.48\times 10^{-3}$ in panel (d).

Theorems & Definitions (12)

  • Proposition 3.2
  • Proposition 3.3
  • Theorem 3.4
  • Theorem 3.5
  • Lemma 1.1
  • proof
  • Lemma 1.2
  • proof
  • proof
  • proof