Improving Line Search Methods for Large Scale Neural Network Training

Philip Kenneweg, Tristan Kenneweg, Barbara Hammer

TL;DR

This work improves the Armijo line search by integrating the momentum term from ADAM in its search direction, enabling efficient large-scale training, a task that was previously prone to failure using Armijo line search methods.

Abstract

In recent studies, line search methods have shown significant improvements in the performance of traditional stochastic gradient descent techniques, eliminating the need for a specific learning rate schedule. In this paper, we identify existing issues in state-of-the-art line search methods, propose enhancements, and rigorously evaluate their effectiveness. We test these methods on larger datasets and more complex data domains than before. Specifically, we improve the Armijo line search by integrating the momentum term from ADAM in its search direction, enabling efficient large-scale training, a task that was previously prone to failure using Armijo line search methods. Our optimization approach outperforms both the previous Armijo implementation and tuned learning rate schedules for Adam. Our evaluation focuses on Transformers and CNNs in the domains of NLP and image data. Our work is publicly available as a Python package, which provides a hyperparameter-free PyTorch optimizer.
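To make the idea in the abstract concrete, the following is a minimal sketch of a backtracking Armijo line search along an Adam-style search direction. It assumes the textbook sufficient-decrease condition $f(w + \eta d) \le f(w) + c\,\eta\,\nabla f(w)^\top d$; the exact condition used in the paper may differ. The function names `adam_direction` and `armijo_step_size` and the constants `c`, `shrink`, and `eta0` are illustrative choices, not taken from the authors' package.

```python
import math


def adam_direction(grad, m, v, t, beta1=0.9, beta2=0.999, eps=1e-8):
    """Update Adam moment estimates and return (direction, m, v).

    The direction is the negative bias-corrected first moment divided by
    the square root of the bias-corrected second moment (standard Adam).
    """
    m = [beta1 * mi + (1 - beta1) * gi for mi, gi in zip(m, grad)]
    v = [beta2 * vi + (1 - beta2) * gi * gi for vi, gi in zip(v, grad)]
    m_hat = [mi / (1 - beta1 ** t) for mi in m]
    v_hat = [vi / (1 - beta2 ** t) for vi in v]
    direction = [-mh / (math.sqrt(vh) + eps) for mh, vh in zip(m_hat, v_hat)]
    return direction, m, v


def armijo_step_size(loss_fn, w, grad, direction,
                     eta0=1.0, c=0.1, shrink=0.5, max_backtracks=30):
    """Backtrack from eta0 until the (textbook) Armijo condition holds.

    Assumption: `direction` is a descent direction, i.e. grad . direction < 0.
    With momentum this is not guaranteed on every step; a real optimizer
    would need a fallback for that case.
    """
    f0 = loss_fn(w)
    # Directional derivative of the loss along the search direction.
    slope = sum(gi * di for gi, di in zip(grad, direction))
    eta = eta0
    for _ in range(max_backtracks):
        w_new = [wi + eta * di for wi, di in zip(w, direction)]
        if loss_fn(w_new) <= f0 + c * eta * slope:
            return eta
        eta *= shrink  # step too large: shrink and try again
    return eta  # best effort after max_backtracks reductions


if __name__ == "__main__":
    # Toy ill-conditioned quadratic, used only for illustration.
    loss = lambda w: 0.5 * (w[0] ** 2 + 10.0 * w[1] ** 2)
    w = [3.0, -2.0]
    g = [w[0], 10.0 * w[1]]  # analytic gradient of the toy loss
    d, m, v = adam_direction(g, [0.0, 0.0], [0.0, 0.0], t=1)
    print("accepted step size:", armijo_step_size(loss, w, g, d, eta0=8.0))
```

In a stochastic setting, `loss_fn` would be evaluated on the same mini-batch that produced `grad`, so each backtracking step costs one extra forward pass.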

Paper Structure

This paper contains 17 sections, 10 equations, 7 figures, and 2 tables.

Figures (7)

  • Figure 1: Step size $\eta_k$ for large scale GPT2 training. We started with a fixed linear warmup of the step size until step 400. Afterwards, ADAM + SLS determined the step size.
  • Figure 2: Loss decrease (y-axis) vs. step size (x-axis) on the QNLI dataset for a single batch. Note the logarithmic scaling of the x-axis. The red point indicates the step size selected by ALSALS. The green point indicates the optimum loss decrease on the single batch. The area above the black line indicates where Eq. \ref{eq:armijoadam} is satisfied.
  • Figure 3: Loss decrease (y-axis) vs. step size (x-axis) on QNLI training for the last 10 consecutive batches, with older runs fading. Note the logarithmic scaling of the x-axis. Red points indicate the step sizes selected by ALSALS. Green points indicate the optimum loss decrease on each single batch.
  • Figure 4: Loss decrease depicted as color on the QNLI dataset. Step $k$ is displayed on the y-axis and step size $\eta_k$ on the x-axis. Note the logarithmic scaling of the x-axis. The green line indicates the step size selected by ALSALS.
  • Figure 5: The top row displays the loss curves, while the bottom row presents the accuracy curves for the ResNet experiments on image datasets. Standard errors are indicated around each line, beginning from the second epoch. Accuracy was computed on the validation data, whereas loss was assessed on the training data.
  • ...and 2 more figures