FFT-based Dynamic Subspace Selection for Low-Rank Adaptive Optimization of Large Language Models
Ionut-Vlad Modoranu, Mher Safaryan, Erik Schultheis, Max Ryabinin, Artem Chumachenko, Dan Alistarh
TL;DR
This work tackles the memory and compute bottlenecks of adaptive gradient optimizers for large language models. It replaces costly SVD/QR-based projections with a fixed orthogonal matrix $Q$ derived from the Discrete Cosine Transform (DCT), combined with a dynamic column-selection strategy. It introduces two algorithms: Trion, which replaces power iteration in Dion with DCT-based column selection and Newton-Schulz orthogonalization, and DCT-AdamW, a low-rank AdamW variant that uses DCT projections with optional error feedback. Across pre-training and fine-tuning tasks, the approach matches or exceeds the performance of traditional low-rank methods while running faster and reducing memory usage by up to $25\%$, with rank-independent running time. Theoretical results show that norm-based column selection is contractive and that the DCT provides a natural linear approximation to the gradient eigenbasis. Together, these results support the method's practical effectiveness for scalable, memory-efficient LLM optimization.
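Newton-Schulz orthogonalization approximates the orthogonal polar factor $UV^\top$ of a matrix $X = U\Sigma V^\top$ using only matrix multiplications, so no explicit SVD or QR is needed. The NumPy code below is a minimal sketch of the classical cubic iteration. The function name `newton_schulz_orthogonalize` and the fixed step count are our own assumptions. The polynomial and number of steps used in Trion may differ, and faster-converging higher-order variants exist.

```python
import numpy as np


def newton_schulz_orthogonalize(X: np.ndarray, steps: int = 30, eps: float = 1e-12) -> np.ndarray:
    """Approximate the polar factor U @ Vt of X = U diag(s) Vt (hypothetical helper).

    Uses the classical cubic iteration Y <- 1.5 Y - 0.5 Y Y^T Y, which drives every
    nonzero singular value of Y toward 1 while keeping the singular vectors fixed.
    """
    # Dividing by the Frobenius norm (>= spectral norm) puts all singular values
    # in (0, 1], inside the iteration's convergence region (0, sqrt(3)).
    Y = X / (np.linalg.norm(X) + eps)
    transposed = Y.shape[0] > Y.shape[1]
    if transposed:
        Y = Y.T  # iterate on the wide orientation so the Gram matrix Y @ Y.T is small
    for _ in range(steps):
        Y = 1.5 * Y - 0.5 * (Y @ Y.T) @ Y
    return Y.T if transposed else Y


rng = np.random.default_rng(0)
X = rng.standard_normal((16, 6))  # e.g. a tall low-rank factor
O = newton_schulz_orthogonalize(X)
U, s, Vt = np.linalg.svd(X, full_matrices=False)  # reference answer via SVD
```

Each step costs only two matrix products, which maps well to accelerators. The trade-off is that the number of steps is fixed in advance rather than determined by a convergence test.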
Abstract
Low-rank optimization has emerged as a promising direction in training large language models (LLMs). By constraining learning to a lower-dimensional space, it improves running time and reduces the memory usage of adaptive optimizers. Prior work typically projects the gradients of linear layers using approaches based on Singular Value Decomposition (SVD) or QR decomposition. Applying these techniques individually to each layer of a large model is computationally expensive and incurs additional memory costs for storing the projection matrices. In this work, we propose a computationally efficient and conceptually simple two-step procedure that approximates SVD/QR-based gradient projections onto lower-dimensional spaces using a predefined orthogonal matrix of the Discrete Cosine Transform (DCT). We dynamically select columns from the DCT matrix based on their alignment with the gradient of each layer. The effective projection matrices are obtained with a single matrix multiplication by the DCT matrix in $O(n^3)$ time, followed by a lightweight sorting step that identifies the most relevant basis vectors. For large layers, the DCT can be computed with Makhoul's $N$-point algorithm based on the Fast Fourier Transform (FFT) in $O(n^2 \log n)$ time. Because the orthogonal bases are predefined, they are computed only once at the start of training. Our numerical experiments on both pre-training and fine-tuning tasks demonstrate the effectiveness of this dual strategy in approximating optimal low-rank projections. The resulting approach has rank-independent running time and matches the performance of costly SVD/QR-based methods, while achieving faster runtime and reducing memory usage by up to $25\%$ across different model sizes. Our code is available at https://github.com/IST-DASLab/ISTA-DASLab-Optimizers.
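The two-step selection (DCT product, then sort) can be sketched in NumPy as follows. This is a minimal illustration under our own assumptions, not the authors' implementation. The helper names `dct_basis`, `dct_rows_fft`, and `select_dct_columns` are hypothetical. Columns are scored by the L2 norm of the columns of $GQ$, whereas the abstract only says "alignment". The projection acts on the column dimension $n$ of an $m \times n$ gradient $G$. `dct_rows_fft` computes the same product $GQ$ using Makhoul's reordering and one length-$n$ FFT per row.

```python
import numpy as np


def dct_basis(n: int) -> np.ndarray:
    """Orthonormal DCT-II basis Q (n x n); column k is the k-th cosine vector.

    Q[i, k] = c_k * cos(pi * (2i + 1) * k / (2n)), with c_0 = sqrt(1/n) and
    c_k = sqrt(2/n) for k > 0, so that Q.T @ Q = I.
    """
    i = np.arange(n)[:, None]  # sample index (rows)
    k = np.arange(n)[None, :]  # frequency index (columns)
    Q = np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    Q[:, 0] *= np.sqrt(1.0 / n)
    Q[:, 1:] *= np.sqrt(2.0 / n)
    return Q


def dct_rows_fft(G: np.ndarray) -> np.ndarray:
    """Compute G @ Q (orthonormal DCT-II of every row of G) in O(m n log n)
    via Makhoul's reordering trick and one length-n FFT per row."""
    n = G.shape[1]
    # Even-indexed entries in order, followed by odd-indexed entries reversed.
    v = np.concatenate([G[:, ::2], G[:, 1::2][:, ::-1]], axis=1)
    V = np.fft.fft(v, axis=1)
    k = np.arange(n)
    X = np.real(np.exp(-1j * np.pi * k / (2 * n)) * V)  # unnormalized DCT-II
    scale = np.full(n, np.sqrt(2.0 / n))
    scale[0] = np.sqrt(1.0 / n)
    return X * scale


def select_dct_columns(G: np.ndarray, Q: np.ndarray, r: int, use_fft: bool = False):
    """Pick the r DCT columns best aligned with G (hypothetical helper).

    Scores column q_k by ||G q_k||_2 (an assumed choice of norm) and keeps the
    top-r; returns the n x r projection matrix and the selected column indices.
    """
    S = dct_rows_fft(G) if use_fft else G @ Q  # (m, n) DCT coefficients of G's rows
    scores = np.linalg.norm(S, axis=0)          # alignment of each column with G
    idx = np.argsort(scores)[::-1][:r]          # sort descending, keep top-r
    return Q[:, idx], idx


rng = np.random.default_rng(0)
m, n, r = 64, 32, 8
Q = dct_basis(n)                   # built once, reused for every layer of width n
G = rng.standard_normal((m, n))    # stand-in for one layer's gradient
Q_r, idx = select_dct_columns(G, Q, r)
G_low = G @ Q_r                    # (m, r) projected gradient for the low-rank state
G_back = G_low @ Q_r.T             # (m, n) mapped back to the full space
```

Because $Q$ depends only on the layer width, it can be built once and reused. Only the scoring product and the sort are repeated whenever the subspace is refreshed. The FFT path pays off for large $n$, where it replaces the $O(mn^2)$ matrix product with $O(mn \log n)$ work.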
