Learning the Positions in CountSketch
Yi Li, Honghao Lin, Simin Liu, Ali Vakilian, David P. Woodruff
TL;DR
The paper shows that CountSketch can be substantially improved by learning both the positions and the values of its nonzero entries, yielding faster and more accurate solutions for low-rank approximation (LRA) and Hessian-based second-order optimization. It introduces a two-stage learning framework (first locate the nonzeros, then optimize their values) and proposes a greedy method plus two faster designs tailored to LRA and second-order optimization, each with data-dependent guarantees. Experiments across multiple datasets show large error reductions (up to roughly 70% over classical sketches and 30% over prior learned-sketch approaches) and substantial training-time savings, and the method remains effective in few-shot settings with very few training matrices. The result is application-specific learned sketches with provable guarantees that improve both speed and accuracy in large-scale linear algebra and optimization.
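To make "positions and values" concrete: a CountSketch matrix S of size m × n has exactly one nonzero per column, so it can be stored as two length-n arrays, the row index of each column's nonzero (its position) and the nonzero itself (its value). The Python below is a minimal sketch under that representation, not the authors' code; the names random_countsketch and apply_countsketch are hypothetical. Classical CountSketch draws positions uniformly at random with ±1 values; prior learned sketches train the values while keeping the random positions, and this paper additionally learns the positions.

```python
import random
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def random_countsketch(m: int, n: int, seed: int = 0) -> Tuple[List[int], List[float]]:
    """Classical CountSketch: column j of S gets a single nonzero, placed in a
    uniformly random row positions[j] with a random sign values[j]."""
    rng = random.Random(seed)
    positions = [rng.randrange(m) for _ in range(n)]
    values = [rng.choice((-1.0, 1.0)) for _ in range(n)]
    return positions, values


def apply_countsketch(positions: Sequence[int], values: Sequence[float],
                      m: int, A: Matrix) -> Matrix:
    """Return S @ A, where S is m x n with S[positions[j]][j] = values[j]
    and zeros elsewhere. Row j of A is scaled by values[j] and added into
    row positions[j] of the result, so each entry of A is read once."""
    d = len(A[0]) if A else 0
    SA = [[0.0] * d for _ in range(m)]
    for j, row in enumerate(A):
        target = SA[positions[j]]
        for k, a_jk in enumerate(row):
            target[k] += values[j] * a_jk
    return SA


if __name__ == "__main__":
    A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
    # Rows 0 and 2 of A collide in sketch row 0; row 1 lands in row 1 with sign -1.
    print(apply_countsketch([0, 1, 0], [1.0, -1.0, 1.0], m=2, A=A))
    # [[6.0, 8.0], [-3.0, -4.0]]
```

Because each row of A is touched once, computing SA takes time proportional to the size of A (its nonzeros, for a sparse representation), which is what makes CountSketch attractive for sketch-and-solve. Learning values amounts to replacing the ±1 entries of values with trained reals; learning positions changes which rows collide.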
Abstract
We consider sketching algorithms that first compress data by multiplying it with a random sketch matrix and then use the sketch to quickly solve an optimization problem, e.g., low-rank approximation or regression. In the learning-based sketching paradigm proposed by Indyk et al. (2019), the sketch matrix is obtained by choosing a random sparse matrix, e.g., a CountSketch, and then updating the values of its non-zero entries by running gradient descent on a training data set. Despite the growing body of work on this paradigm, a notable omission is that previous algorithms fixed the locations of the non-zero entries and learned only their values. In this work, we propose the first learning-based algorithms that also optimize the locations of the non-zero entries. Our first algorithm is greedy; its drawback is a slow training time. To address this, we propose faster approaches for learning a sketching matrix for low-rank approximation and for Hessian approximation in second-order optimization. The latter is useful for a range of constrained optimization problems, such as LASSO and matrix estimation with a nuclear norm constraint. Both approaches achieve good accuracy with a fast running time. Moreover, our experiments suggest that our algorithms can still reduce the error significantly even when only a very limited number of training matrices is available.
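The greedy idea of learning positions can be illustrated on a toy problem. The sketch below assumes a hypothetical training objective, sketch-and-solve least squares with a single unknown (fit b ≈ x·a), rather than the low-rank approximation or Hessian-approximation losses the paper actually targets; all function names are hypothetical. The values are held fixed, and for each column j in turn the code tries every row r for that column's nonzero and keeps the row with the lowest average training loss.

```python
import random
from typing import List, Sequence, Tuple

Vector = List[float]
Pair = Tuple[Vector, Vector]  # one training instance (a, b)


def sketch_vector(positions: Sequence[int], values: Sequence[float],
                  m: int, x: Sequence[float]) -> Vector:
    """Return S @ x for the CountSketch S encoded by (positions, values)."""
    out = [0.0] * m
    for j, x_j in enumerate(x):
        out[positions[j]] += values[j] * x_j
    return out


def dot(u: Sequence[float], v: Sequence[float]) -> float:
    return sum(p * q for p, q in zip(u, v))


def sketched_regression_loss(positions: Sequence[int], values: Sequence[float],
                             m: int, data: Sequence[Pair]) -> float:
    """Hypothetical training loss: for each pair (a, b), solve the sketched
    problem min_x ||S(a x - b)||^2 in closed form, then measure the residual
    ||a x_S - b||^2 on the original (unsketched) data. Returns the average."""
    total = 0.0
    for a, b in data:
        Sa = sketch_vector(positions, values, m, a)
        Sb = sketch_vector(positions, values, m, b)
        denom = dot(Sa, Sa)
        x_s = dot(Sa, Sb) / denom if denom > 0 else 0.0
        total += sum((x_s * a_i - b_i) ** 2 for a_i, b_i in zip(a, b))
    return total / len(data)


def greedy_positions(m: int, values: Sequence[float], data: Sequence[Pair],
                     init_positions: Sequence[int]) -> List[int]:
    """One greedy sweep: for each column j, try every row r for its nonzero
    (values fixed) and keep the row with the lowest training loss."""
    positions = list(init_positions)
    for j in range(len(positions)):
        best_row = positions[j]
        best_loss = sketched_regression_loss(positions, values, m, data)
        for r in range(m):
            positions[j] = r
            loss = sketched_regression_loss(positions, values, m, data)
            if loss < best_loss:  # move only on strict improvement
                best_row, best_loss = r, loss
        positions[j] = best_row
    return positions


if __name__ == "__main__":
    rng = random.Random(1)
    n, m = 8, 3
    data = [([rng.gauss(0, 1) for _ in range(n)],
             [rng.gauss(0, 1) for _ in range(n)]) for _ in range(5)]
    init = [rng.randrange(m) for _ in range(n)]
    values = [rng.choice((-1.0, 1.0)) for _ in range(n)]
    learned = greedy_positions(m, values, data, init)
    print(sketched_regression_loss(init, values, m, data),
          sketched_regression_loss(learned, values, m, data))
```

Since a column's row changes only on a strict improvement, one sweep never increases the training loss. Each sweep evaluates the loss n·m times, and every evaluation touches all training pairs, which is consistent with the abstract's remark that the greedy approach is slow to train. Gradient descent on the values could then be run with the learned positions held fixed.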
