Zero-th Order Algorithm for Softmax Attention Optimization
Yichuan Deng, Zhihang Li, Sridhar Mahadevan, Zhao Song
TL;DR
The work introduces a zeroth-order optimization scheme for softmax attention, allowing softmax-based attention mechanisms in large language models to be optimized without backpropagated gradients. It formalizes a softmax loss L_exp and proves smoothness and strong convexity under stated conditions, which makes a convergence analysis possible. Using SPSA (simultaneous perturbation stochastic approximation) gradient estimates within a gradient-descent framework, the paper derives per-iteration loss-decrease bounds and a global convergence theorem, with iteration complexity depending on rank, conditioning, and batch parameters. The results provide theoretical backing for efficient optimization of softmax attention in large-scale models, offering a path toward faster, gradient-free tuning of attention components.
Abstract
Large language models (LLMs) have brought about significant transformations in human society. Among the crucial computations in LLMs, the softmax unit is of great importance. It helps the model generate a probability distribution over potential subsequent words or phrases, given a sequence of input words. Using this distribution, the model selects the most probable next word or phrase. The softmax unit also plays a vital role in LLM training, as it facilitates learning from data through the adjustment of neural network weights and biases. As LLMs grow in size, computing the gradient becomes expensive. Zeroth-order methods, however, can approximate the gradient using only forward passes. In this paper, we present a zeroth-order algorithm specifically tailored to softmax optimization. We demonstrate the convergence of our algorithm, highlighting its effectiveness in efficiently estimating gradients for large-scale LLMs. By leveraging the zeroth-order method, our work contributes to the advancement of optimization techniques in the context of complex language models.
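To make the idea of estimating a gradient from forward passes alone concrete, the following is a minimal sketch of a two-point SPSA-style estimator plugged into plain gradient descent. It is an illustration under stated assumptions, not the paper's algorithm: the loss is assumed to take the form 0.5 * ||softmax(Ax) - b||_2^2, which this section does not spell out, and the function names, step size, perturbation scale, and Gaussian perturbation are all hypothetical choices made for the example.

```python
import numpy as np

def softmax(u):
    """Numerically stable softmax of a 1-D vector."""
    e = np.exp(u - u.max())  # softmax is invariant to shifting its input
    return e / e.sum()

def softmax_loss(x, A, b):
    """ASSUMED form of the softmax loss: 0.5 * ||softmax(A x) - b||_2^2."""
    return 0.5 * np.sum((softmax(A @ x) - b) ** 2)

def spsa_gradient(loss, x, eps, rng):
    """Two-point SPSA-style estimate using two loss evaluations, no backprop."""
    z = rng.standard_normal(x.shape)  # random perturbation direction (assumed Gaussian)
    # Central difference along z, projected back onto z.
    return (loss(x + eps * z) - loss(x - eps * z)) / (2.0 * eps) * z

def zeroth_order_descent(loss, x0, lr=0.1, eps=1e-3, steps=3000, seed=0):
    """Gradient descent in which the gradient is replaced by the SPSA estimate."""
    rng = np.random.default_rng(seed)
    x = np.array(x0, dtype=float)  # work on a float copy of the starting point
    for _ in range(steps):
        x -= lr * spsa_gradient(loss, x, eps, rng)
    return x
```

Each iteration costs two evaluations of the loss regardless of the dimension of x, which is the appeal when backpropagation is expensive. The price is a noisier update direction, so the step size typically has to be smaller than for exact gradient descent.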
