
On the Complexity of Learning Sparse Functions with Statistical and Gradient Queries

Nirmit Joshi, Theodor Misiakiewicz, Nathan Srebro

TL;DR

The paper provides evidence that Differentiable Learning Queries (DLQ) capture learning with (stochastic) gradient descent, by showing that DLQ correctly describes the complexity of learning with a two-layer neural network in the mean-field regime and under linear scaling.

Abstract

The goal of this paper is to investigate the complexity of gradient algorithms when learning sparse functions (juntas). We introduce a type of Statistical Queries ($\mathsf{SQ}$), which we call Differentiable Learning Queries ($\mathsf{DLQ}$), to model gradient queries on a specified loss with respect to an arbitrary model. We provide a tight characterization of the query complexity of $\mathsf{DLQ}$ for learning the support of a sparse function over generic product distributions. This complexity crucially depends on the loss function. For the squared loss, $\mathsf{DLQ}$ matches the complexity of Correlation Statistical Queries ($\mathsf{CSQ}$), which can be much worse than $\mathsf{SQ}$. But for other simple loss functions, including the $\ell_1$ loss, $\mathsf{DLQ}$ always achieves the same complexity as $\mathsf{SQ}$. We also provide evidence that $\mathsf{DLQ}$ can indeed capture learning with (stochastic) gradient descent by showing it correctly describes the complexity of learning with a two-layer neural network in the mean-field regime and linear scaling.
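The loss dependence can be illustrated with a toy calculation. This is a minimal sketch, not the paper's construction: it assumes a hypothetical target $y = (2 + z_1)\,z_2$ with $z \sim \mathrm{Unif}(\{\pm 1\}^2)$ and a hypothetical one-parameter model $f(z; w) = 2 + w z_1$, and it reads a gradient query as the population gradient $\mathbb{E}[\partial_w \ell(f(z;w), y)]$ at $w = 0$. Since $\mathbb{E}[y \mid z_1] = 0$, no correlation query of the form $\mathbb{E}[y\,\phi(z_1)]$ can detect $z_1$. For the squared loss $\ell(f,y) = (f-y)^2/2$ the gradient $(f-y)\,\partial_w f$ is affine in $y$, so the query reduces to correlations and vanishes. For the absolute loss the gradient $\operatorname{sign}(f-y)\,\partial_w f$ is nonlinear in $y$ and picks up the fact that $|y|$ depends on $z_1$. The names `target`, `grad_query`, `squared_grad`, and `absolute_grad` are illustrative only.

```python
import itertools


def target(z1: int, z2: int) -> int:
    # Hypothetical toy target: E[y | z1] = 0, but |y| = 2 + z1 depends on z1.
    return (2 + z1) * z2


def sign(t: float) -> int:
    return (t > 0) - (t < 0)


def squared_grad(f: float, y: float) -> float:
    # d/df of the squared loss (f - y)^2 / 2; affine in y.
    return f - y


def absolute_grad(f: float, y: float) -> int:
    # (Sub)derivative of the absolute loss |f - y|; nonlinear in y.
    return sign(f - y)


def grad_query(loss_grad, c: float = 2.0) -> float:
    """Exact population gradient d/dw E[loss(c + w * z1, y)] at w = 0,
    with z = (z1, z2) uniform on {-1, +1}^2 (hypothetical model f = c + w * z1)."""
    configs = list(itertools.product([-1, 1], repeat=2))
    total = 0.0
    for z1, z2 in configs:
        y = target(z1, z2)
        # Chain rule: d/dw loss(f(w), y) = loss_grad(f, y) * df/dw, with f(0) = c and df/dw = z1.
        total += loss_grad(c, y) * z1
    return total / len(configs)


print(grad_query(squared_grad))   # → 0.0
print(grad_query(absolute_grad))  # → -0.5
```

The expectations are computed exactly by enumerating the four inputs, so the outputs are deterministic: the squared-loss query is blind to $z_1$, while the absolute-loss query is not.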

Paper Structure

This paper contains 49 sections, 24 theorems, 168 equations, 1 figure, and 1 table.

Key Result

Theorem 5.1

For any junta problem $\mu$ and any loss $\ell$, there exist constants $C>c>0$ (depending on $P$, $\mu$, and the loss, but not on $d$) such that, for query types ${\sf A} \in \{{\sf SQ},{\sf CSQ},{\sf DLQ}_\ell\}$ with the corresponding test-function sets $\Psi_{\sf A}$ defined earlier in the paper:

Figures (1)

  • Figure 1: The function $h_*(\boldsymbol{z})$ used in this simulation has ${\sf Leap}_{\sf CSQ}=3$ but ${\sf Leap}_{\sf SQ}=1$. For the squared loss (left plot), the DF dynamics remain stuck at initialization (no learning), and SGD needs a number of iterations growing faster than linearly in $d$ to escape the saddle. For the absolute loss (center plot) and the other loss (right plot), ${\sf Leap}_{{\sf DLQ}_{\ell}}={\sf Leap}_{\sf SQ}=1$: SGD learns in $\Theta(d)$ steps, and the DF dynamics learn in $O(1)$ continuous time.
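The Leap values in Figure 1 count how many new support coordinates must be found jointly at the hardest step. As a rough illustration only, the sketch below computes a Fourier-support leap for a function on uniform Boolean inputs: given the sets $T$ with nonzero Fourier coefficient $\hat h(T)$, it returns the smallest $k$ such that the whole support can be learned by repeatedly absorbing a set $T$ that adds at most $k$ new coordinates to the already-learned set $U$. Treating this quantity as ${\sf Leap}_{\sf CSQ}$ on the hypercube is an assumption made here; the function `fourier_leap` is hypothetical, and it does not compute ${\sf Leap}_{\sf SQ}$ or ${\sf Leap}_{{\sf DLQ}_\ell}$, which rely on the paper's own definitions.

```python
from typing import FrozenSet, Iterable, List


def fourier_leap(support_sets: Iterable[Iterable[int]]) -> int:
    """Smallest k such that all nonzero Fourier sets can be absorbed one at a time,
    each adding at most k new coordinates to the already-learned set."""
    # Drop the empty set (the constant term), which carries no support information.
    sets: List[FrozenSet[int]] = [frozenset(s) for s in support_sets if s]
    support = frozenset().union(*sets)
    for k in range(len(support) + 1):
        learned: FrozenSet[int] = frozenset()
        progress = True
        while progress:
            progress = False
            for s in sets:
                new = s - learned
                # Absorb s if it adds between 1 and k new coordinates.
                if new and len(new) <= k:
                    learned = learned | new
                    progress = True
        if learned == support:
            return k
    return len(support)  # not reached: k = len(support) always succeeds


print(fourier_leap([(1,), (1, 2), (1, 2, 3)]))  # → 1 (staircase)
print(fourier_leap([(1,), (1, 2, 3)]))          # → 2
print(fourier_leap([(1, 2, 3)]))                # → 3 (pure 3-parity)
```

Greedy absorption suffices for a fixed $k$ because learning coordinates only shrinks $|T \setminus U|$ for every remaining set, so the order in which sets are absorbed does not change which sets eventually become reachable.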

Theorems & Definitions (53)

  • Remark 3.1
  • Remark 3.2
  • Definition 1
  • Remark 4.1
  • Definition 2: Detectable Subsets
  • Theorem 5.1
  • Remark 5.2
  • Remark 5.3
  • Definition 3
  • Proposition 6.1: ${\sf SQ}$ versus ${\sf CSQ}$
  • ...and 43 more