Learning the Positions in CountSketch

Yi Li, Honghao Lin, Simin Liu, Ali Vakilian, David P. Woodruff

TL;DR

The paper shows that CountSketch can be substantially improved by learning both the positions and values of nonzero entries, enabling faster and more accurate solutions for low-rank approximation (LRA) and Hessian-based optimization. It introduces a two-stage learning framework (locating nonzeros, then optimizing their values) and proposes a scalable greedy method plus two fast designs tailored to LRA and second-order optimization, each with data-dependent guarantees. Empirical results across multiple datasets demonstrate large reductions in error (up to roughly 70% over classical sketches and 30% over prior learned-sketch approaches) and substantial training-time savings, including in few-shot scenarios. This work advances data-dependent sketching by delivering application-specific sketches with provable guarantees that improve both speed and accuracy in large-scale linear algebra and optimization tasks.

Abstract

We consider sketching algorithms that first compress data by multiplication with a random sketch matrix, and then apply the sketch to quickly solve an optimization problem, e.g., low-rank approximation and regression. In the learning-based sketching paradigm proposed by Indyk et al. (2019), the sketch matrix is found by choosing a random sparse matrix, e.g., CountSketch, and then the values of its non-zero entries are updated by running gradient descent on a training data set. Despite the growing body of work on this paradigm, a noticeable omission is that the locations of the non-zero entries of previous algorithms were fixed, and only their values were learned. In this work, we propose the first learning-based algorithms that also optimize the locations of the non-zero entries. Our first proposed algorithm is based on a greedy algorithm. However, one drawback of the greedy algorithm is its slower training time. We fix this issue and propose approaches for learning a sketching matrix for both low-rank approximation and Hessian approximation for second-order optimization. The latter is helpful for a range of constrained optimization problems, such as LASSO and matrix estimation with a nuclear norm constraint. Both approaches achieve good accuracy with a fast running time. Moreover, our experiments suggest that our algorithm can still reduce the error significantly even if we only have a very limited number of training matrices.
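
To make the object being learned concrete, here is a minimal sketch of a CountSketch matrix stored as two arrays: one position (row index) and one value per column. This illustrates only the data structure and how it is applied to a matrix; it is not the authors' training procedure. The function names `random_countsketch` and `apply_sketch` are hypothetical, chosen for this example. The classical construction draws both arrays at random, earlier learned sketches keep `positions` fixed and tune `values`, and the approach described above also optimizes `positions`.

```python
import random


def random_countsketch(m, n, seed=0):
    """Classical CountSketch: each of the n columns of the m x n sketch S
    has exactly one nonzero, placed in a uniformly random row with a
    random sign. S is represented by (positions, values)."""
    rng = random.Random(seed)
    positions = [rng.randrange(m) for _ in range(n)]
    values = [rng.choice((-1.0, 1.0)) for _ in range(n)]
    return positions, values


def apply_sketch(positions, values, A, m):
    """Compute S @ A, where S[positions[i]][i] = values[i] and A is an
    n x d matrix given as a list of rows. Runs in time proportional to
    the number of entries of A, since each row of A is touched once."""
    d = len(A[0])
    SA = [[0.0] * d for _ in range(m)]
    for i, row in enumerate(A):
        target = SA[positions[i]]
        for j, a in enumerate(row):
            target[j] += values[i] * a
    return SA


# Hand-picked positions and values (assumed for illustration only).
A = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
positions = [0, 1, 0, 1]
values = [1.0, -1.0, 1.0, 1.0]
SA = apply_sketch(positions, values, A, m=2)
print(SA)  # [[6.0, 8.0], [4.0, 4.0]]
```

Storing the sketch this way keeps the two learnable components separate: changing an entry of `positions` moves a nonzero to a different row, while changing an entry of `values` rescales it in place.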

Paper Structure

This paper contains 44 sections, 28 theorems, 95 equations, 6 figures, 7 tables, 6 algorithms.

Key Result

Theorem 5.1

Let $S \in \mathbb{R}^{2m \times n}$ be given by concatenating the sketching matrices $S_1, S_2$ computed by Algorithm alg:new_alg with input $A$ induced by $\mathsf{Tr}$ and let $B \in \mathbb{R}^{n \times d}$. Then with probability at least $1 - \delta$, we have $\min_{\mathop{\mathrm{rank-\mathno

Figures (6)

  • Figure 8.1: Test error of LASSO on the Electric dataset.
  • Figure A.1: Total variation distance between train and test matrices. left: Logo, middle: friend, right: Hyper.
  • Figure B.1: Test error of LASSO on the Electric dataset.
  • Figure B.2: Test error of matrix estimation with a nuclear norm constraint on the Tunnel dataset.
  • Figure B.3: Test error of the subroutine in fast regression on the Electric dataset.
  • ...and 1 more figure

Theorems & Definitions (64)

  • Theorem 5.1
  • Theorem 6.1
  • Lemma 6.2
  • Definition C.1: Affine Embedding
  • Lemma C.2: clarksonwoodruff; Lemma 40
  • Lemma C.3
  • Lemma C.4: sarlos2006improved; clarksonwoodruff
  • Lemma C.5
  • proof
  • Lemma C.6: avron2016sharper; Lemma 27
  • ...and 54 more