Solving Regularized Exp, Cosh and Sinh Regression Problems
Zhihang Li, Zhao Song, Tianyi Zhou
TL;DR
The paper addresses the efficiency of solving regularized exponential regression problems motivated by attention in large language models. By proving convexity and Lipschitz properties for the regularized loss with $f\in\{\exp,\cosh,\sinh\}$, it lays the foundation for a fast inexact Newton method. The authors propose a randomized, input-sparsity-time algorithm with iteration complexity $O(\log(\|x_0-x^*\|_2/\epsilon))$ and per-iteration cost $\widetilde{O}(\mathrm{nnz}(A)+d^{\omega})$, which converges with high probability to within error $\epsilon$ of $x^*$. This approach offers a theoretically grounded, scalable way to solve attention-inspired regression problems in sparse regimes, with potential impact on the efficiency of attention computation in practice.
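For reference, the gradient and Hessian of the regularized loss follow from the chain rule. This is a routine derivation added here, not quoted from the paper. It uses $\circ$ for the entrywise product and takes $(f', f'') = (\exp, \exp)$, $(\sinh, \cosh)$ or $(\cosh, \sinh)$ for $f = \exp$, $\cosh$ or $\sinh$ respectively:

```latex
L(x) = \tfrac12 \| f(Ax) - b \|_2^2 + \tfrac12 \| \mathrm{diag}(w) A x \|_2^2

\nabla L(x) = A^\top \big( (f(Ax) - b) \circ f'(Ax) + w \circ w \circ Ax \big)

\nabla^2 L(x) = A^\top D A, \qquad
D = \mathrm{diag}\big( f'(Ax) \circ f'(Ax) + (f(Ax) - b) \circ f''(Ax) + w \circ w \big)
```

Because the Hessian has the form $A^\top D A$ with $D$ diagonal, it is positive semidefinite whenever every diagonal entry of $D$ is nonnegative. The regularizer contributes $w \circ w$ to $D$, which is what can push it positive. For example, with $f = \exp$ the data term is $2e^{2u} - b e^{u} \ge -b^2/8$, so $w_i^2 > b_i^2/8$ suffices for the $i$-th entry. The same $A^\top D A$ form is what makes row sampling of $D^{1/2}A$ a natural way to approximate the Hessian.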
Abstract
In modern machine learning, attention computation is a fundamental task in training large language models such as Transformer, GPT-4 and ChatGPT. In this work, we study the exponential regression problem, which is inspired by the softmax/exp unit in the attention mechanism of large language models. The standard exponential regression problem is non-convex, so we study its regularized version, which is convex, and solve it in input-sparsity time with an approximate Newton method. Formally, we are given a matrix $A \in \mathbb{R}^{n \times d}$, vectors $b \in \mathbb{R}^n$ and $w \in \mathbb{R}^n$, and one of the functions $\exp$, $\cosh$ or $\sinh$, denoted $f$. The goal is to find the $x$ that minimizes $0.5 \| f(Ax) - b \|_2^2 + 0.5 \| \mathrm{diag}(w) A x \|_2^2$. The straightforward approach is the naive Newton method. Let $\mathrm{nnz}(A)$ denote the number of nonzero entries in $A$, let $\omega$ denote the exponent of matrix multiplication (currently $\omega \approx 2.373$), and let $\epsilon$ denote the accuracy. Exploiting input sparsity, we propose an algorithm that uses $O(\log(\|x_0 - x^*\|_2 / \epsilon))$ iterations and $\widetilde{O}(\mathrm{nnz}(A) + d^\omega)$ time per iteration to solve the problem.
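To make the problem concrete, here is a minimal NumPy sketch of the naive Newton method for $0.5 \| f(Ax) - b \|_2^2 + 0.5 \| \mathrm{diag}(w) A x \|_2^2$. It also includes an approximate variant that replaces the exact Hessian $A^\top D A$ with a row-sampled estimate. Everything here is an illustrative assumption rather than the paper's algorithm. The function names (`loss`, `grad_and_weights`, `hessian`, `newton`) are hypothetical. The sampler (rows of $D^{1/2}A$ drawn with probability proportional to their squared norms) is a simple stand-in for the paper's Hessian sparsification. The Armijo backtracking line search is a safeguard added here. The sketch also assumes every entry of $D$ is positive, which holds for the test data below.

```python
import numpy as np

# f together with its first and second derivatives.
FUNCS = {
    "exp":  (np.exp,  np.exp,  np.exp),
    "cosh": (np.cosh, np.sinh, np.cosh),
    "sinh": (np.sinh, np.cosh, np.sinh),
}

def loss(x, A, b, w, f):
    """0.5 * ||f(Ax) - b||^2 + 0.5 * ||diag(w) A x||^2."""
    u = A @ x
    return 0.5 * np.sum((FUNCS[f][0](u) - b) ** 2) + 0.5 * np.sum((w * u) ** 2)

def grad_and_weights(x, A, b, w, f):
    """Gradient, plus the diagonal D such that the Hessian equals A^T diag(D) A."""
    fn, d1, d2 = FUNCS[f]
    u = A @ x
    r = fn(u) - b                         # residual f(Ax) - b
    g = A.T @ (r * d1(u) + w**2 * u)
    D = d1(u) ** 2 + r * d2(u) + w**2
    return g, D

def hessian(A, D, rng=None, m=None):
    """Exact A^T diag(D) A, or an unbiased m-row sampled estimate if m is given."""
    B = np.sqrt(D)[:, None] * A           # assumes D > 0
    if m is None:
        return B.T @ B
    p = np.sum(B**2, axis=1)
    p = p / p.sum()                       # sample rows by squared row norm (hypothetical choice)
    idx = rng.choice(A.shape[0], size=m, p=p)
    R = B[idx] / np.sqrt(m * p[idx])[:, None]   # rescale so E[R^T R] = B^T B
    return R.T @ R

def newton(A, b, w, f="exp", m=None, seed=0, iters=100, tol=1e-10):
    """Newton's method (exact if m is None, row-sampled Hessian otherwise)."""
    rng = np.random.default_rng(seed)
    x = np.zeros(A.shape[1])
    for _ in range(iters):
        g, D = grad_and_weights(x, A, b, w, f)
        if np.linalg.norm(g) < tol:
            break
        step = np.linalg.solve(hessian(A, D, rng, m), g)   # search direction is -step
        t, L0, slope = 1.0, loss(x, A, b, w, f), g @ step
        # Armijo backtracking: a safeguard added here, not taken from the paper.
        while t > 1e-8 and loss(x - t * step, A, b, w, f) > L0 - 0.25 * t * slope:
            t *= 0.5
        x = x - t * step
    return x
```

For example, with `A` of shape `(200, 3)`, calling `newton(A, b, w, "exp")` and `newton(A, b, w, "exp", m=40)` should return nearly the same minimizer. The second call only ever solves a $d \times d$ system built from 40 sampled rows. Sampling each row with probability proportional to its squared norm and rescaling by $1/\sqrt{m p_i}$ keeps the Hessian estimate unbiased. Because the estimate stays positive definite, the line search still guarantees descent.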
