Torque-Aware Momentum

Pranshu Malviya, Goncalo Mordido, Aristide Baratin, Reza Babanezhad Harikandeh, Gintare Karolina Dziugaite, Razvan Pascanu, Sarath Chandar

TL;DR

Torque-Aware Momentum (TAM) introduces a damping mechanism that modulates the influence of gradients based on their alignment with previous momentum, stabilizing updates and promoting robust exploration of the loss landscape. The approach yields TAM and its AdaTAM variant, with a learning-rate transfer heuristic to map SGDM settings to TAM settings, and demonstrably improves generalization across image classification and large-language-model fine-tuning, especially under distribution shifts. Empirical results show TAM-based methods outperform classical momentum baselines and remain competitive with, or surpass, adaptive optimizers in many scenarios, while AdaTAMW extends these benefits to AdamW-like settings with modest runtime overhead. TAM also acts as an effective warm-up strategy, reducing loss barriers and enhancing mode connectivity early in training, and shows promise for online/continual learning contexts; future work includes broader non-stationary environments and reinforcement learning applications.

Abstract

Efficiently exploring complex loss landscapes is key to the performance of deep neural networks. While momentum-based optimizers are widely used in state-of-the-art setups, classical momentum can still struggle with large, misaligned gradients, leading to oscillations. To address this, we propose Torque-Aware Momentum (TAM), which introduces a damping factor based on the angle between the new gradients and previous momentum, stabilizing the update direction during training. Empirical results show that TAM, which can be combined with both SGD and Adam, enhances exploration, handles distribution shifts more effectively, and improves generalization performance across various tasks, including image classification and large language model fine-tuning, when compared to classical momentum-based optimizers.
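To make the damping idea concrete, here is a minimal NumPy sketch of a TAM-style SGD step. It assumes the damping factor is $(1 + \cos\alpha)/2$, where $\alpha$ is the angle between the previous momentum and the new gradient, and that the damped gradient is added to the decayed momentum. This formula, the zero-momentum convention, and the function name `tam_step` are illustrative assumptions, not the paper's exact update, which may differ.

```python
import numpy as np


def tam_step(params, grad, momentum, lr=0.1, beta=0.9, eps=1e-8):
    """One TAM-style SGD step (illustrative sketch, not the paper's exact rule).

    Assumed damping: d = (1 + cos(alpha)) / 2, with alpha the angle between
    the previous momentum and the new gradient. A gradient aligned with the
    momentum (cos = 1) passes through fully; an opposed one (cos = -1) is zeroed.
    """
    norm_prod = np.linalg.norm(momentum) * np.linalg.norm(grad)
    # Assumption: when the momentum (or gradient) is zero, e.g. on the very
    # first step, the angle is undefined, so treat it as fully aligned.
    cos_alpha = float(momentum @ grad) / norm_prod if norm_prod > eps else 1.0
    damping = (1.0 + cos_alpha) / 2.0
    # Damped gradient is accumulated into the decayed momentum.
    new_momentum = beta * momentum + damping * grad
    new_params = params - lr * new_momentum
    return new_params, new_momentum, damping
```

With $m_0 = (1, 0)$, an aligned gradient $(1, 0.1)$ passes through almost undamped, while an opposed gradient $(-1, 0.1)$ is nearly zeroed, so the new momentum keeps pointing along $m_0$. Classical momentum (damping fixed at 1) would instead give $0.9\,m_0 + g = (-0.1, 0.1)$, reversing the first coordinate; this is the kind of oscillation the damping is meant to suppress.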

Paper Structure

This paper contains 28 sections, 9 equations, 11 figures, 16 tables, 1 algorithm.

Figures (11)

  • Figure 1: Comparing momentum updates obtained using SGDM and TAM for a given SGD trajectory. While TAM yields more stable directions pointing toward a lower-loss basin, SGDM produces higher-magnitude updates that are susceptible to misaligned gradients.
  • Figure 2: TAM controls the update magnitude (red) based on the alignment between momentum and new gradients. The angle ($\alpha_1$, $\alpha_2$) between the previous momentum (green) and the new gradient (white) determines the magnitude of the update (red). When $g_1$ aligns well with $m_0$, the resulting momentum $m_1$ has a higher magnitude. In contrast, the misalignment between $g_2$ and $m_1$ results in a smaller-magnitude $m_2$.
  • Figure 3: Percentage improvement in the average scores of AdaTAMW compared to AdamW across different MTEB task categories for three types of models: BERT (left), DeBERTa (middle) and RoBERTa (right). The y-axis labels indicate the model size ({Base, Large}), the MTEB task category ($7$ in total), and the number of fine-tuning epochs ({3, 5, 10}), covering 42 configurations in total ($2 \times 7 \times 3$). Overall, AdaTAMW achieves similar or better performance than AdamW in at least 28 configurations across all three model types.
  • Figure 4: Percentage of the 56 MTEB datasets on which AdaTAMW performs better than, or at least similarly to, AdamW for various LLMs. Green indicates that AdaTAMW achieves similar or better performance, while red indicates worse performance. Except for BERT models fine-tuned for $10$ epochs and RoBERTa-large, AdaTAMW performs similarly or better on the majority of datasets.
  • Figure 5: Online accuracy of TAM compared with SGDM and SGD on the label-flipping benchmark, training MLPs with 2 hidden layers (first row) and 4 hidden layers (second row), after a hyperparameter search over effective learning rates, for (i) $40\%$, (ii) $80\%$, and (iii) $100\%$ label flipping. Although TAM performs similarly to SGDM for milder shifts ($40\%$), it tends to outperform SGDM when distribution shifts are more drastic ($80\%$ and $100\%$).
  • ...and 6 more figures
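The label-flipping benchmark in Figure 5 creates a distribution shift by corrupting a fraction of the labels. The caption does not spell out the exact protocol, so the following is only a minimal sketch under an assumed rule: a chosen fraction of examples is reassigned to a different, uniformly random class. The function name `flip_labels` and this reassignment rule are hypothetical.

```python
import numpy as np


def flip_labels(labels, fraction, num_classes, rng):
    """Reassign `fraction` of the labels to a different random class.

    Hypothetical protocol for illustration; requires num_classes >= 2.
    The input array is not modified.
    """
    labels = np.asarray(labels).copy()
    n_flip = int(round(fraction * len(labels)))
    # Pick which examples to corrupt, without repetition.
    idx = rng.choice(len(labels), size=n_flip, replace=False)
    # An offset in [1, num_classes - 1] taken modulo num_classes guarantees
    # that every flipped label differs from its original class.
    offsets = rng.integers(1, num_classes, size=n_flip)
    labels[idx] = (labels[idx] + offsets) % num_classes
    return labels
```

Under this assumed rule, `fraction=0.4`, `0.8`, and `1.0` correspond to the $40\%$, $80\%$, and $100\%$ settings, with larger fractions producing more drastic shifts between consecutive tasks.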