Torque-Aware Momentum
Pranshu Malviya, Goncalo Mordido, Aristide Baratin, Reza Babanezhad Harikandeh, Gintare Karolina Dziugaite, Razvan Pascanu, Sarath Chandar
TL;DR
Torque-Aware Momentum (TAM) adds a damping mechanism that scales the influence of each new gradient by its alignment with the previous momentum, which stabilizes updates and promotes robust exploration of the loss landscape. The paper presents TAM, its AdaTAM variant, and a learning-rate transfer heuristic that maps SGDM settings to TAM settings. These methods improve generalization in image classification and large-language-model fine-tuning, especially under distribution shifts. Empirically, TAM-based methods outperform classical momentum baselines and match or surpass adaptive optimizers in many scenarios, and AdaTAMW extends these benefits to AdamW-like settings with modest runtime overhead. TAM also works as a warm-up strategy, reducing loss barriers and enhancing mode connectivity early in training, and shows promise for online and continual learning. Future work includes broader non-stationary environments and reinforcement learning applications.
Abstract
Efficiently exploring complex loss landscapes is key to the performance of deep neural networks. While momentum-based optimizers are widely used in state-of-the-art setups, classical momentum can still struggle with large, misaligned gradients, leading to oscillations. To address this, we propose Torque-Aware Momentum (TAM), which introduces a damping factor based on the angle between the new gradients and previous momentum, stabilizing the update direction during training. Empirical results show that TAM, which can be combined with both SGD and Adam, enhances exploration, handles distribution shifts more effectively, and improves generalization performance across various tasks, including image classification and large language model fine-tuning, when compared to classical momentum-based optimizers.
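To make the idea of angle-based damping concrete, the following is a minimal sketch in plain Python. It is an illustration under stated assumptions, not the authors' exact update rule: the abstract says only that the damping factor depends on the angle between the new gradient and the previous momentum. The specific form `(1 + cos) / 2`, the function names `cosine` and `damped_momentum_step`, and the default values of `lr` and `beta` are all hypothetical choices made for this example.

```python
import math


def cosine(u, v, eps=1e-12):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    # eps avoids division by zero when either vector is all zeros.
    return dot / (norm_u * norm_v + eps)


def damped_momentum_step(params, grad, momentum, lr=0.1, beta=0.9):
    """One momentum step whose gradient term is scaled by alignment.

    Assumption (not taken from the source): the damping factor is
    (1 + cos) / 2, which maps a cosine in [-1, 1] to a weight in [0, 1].
    """
    damping = (1.0 + cosine(momentum, grad)) / 2.0
    # An aligned gradient is added almost in full; an opposed one is suppressed.
    new_momentum = [beta * m + damping * g for m, g in zip(momentum, grad)]
    new_params = [p - lr * m for p, m in zip(params, new_momentum)]
    return new_params, new_momentum, damping
```

Under this assumed form, a gradient that points the same way as the momentum contributes at nearly full strength, an orthogonal gradient contributes at half strength, and a gradient that directly opposes the momentum is almost entirely ignored. This is one simple way to limit the oscillations that large, misaligned gradients cause in classical momentum.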
