Gradient Descent Fails to Learn High-frequency Functions and Modular Arithmetic

Rustem Takhanov, Maxat Tezekbayev, Artur Pak, Arman Bolatov, Zhenisbek Assylbekov

TL;DR

This work investigates why gradient-based learning struggles to learn high-frequency and modular arithmetic functions. By modeling targets as $h_a(x)=\psi(ax)$ with a 1-periodic $\psi$ of bounded variation, together with its discrete analogue on ${\mathbb Z}_p$, the authors bound the gradient variance as $\operatorname{Var}(\mathcal{H}_A, \mathbf{w}) \in \tilde{\mathcal{O}}(1/\sqrt{A})$ and $\operatorname{Var}(\mathring{\mathcal{H}}_p, \mathbf{w}) \in \tilde{\mathcal{O}}(1/\sqrt{p})$, showing barren plateaus for large frequency or base. They connect these gradient-variance bounds to the SQ dimension, proving lower bounds that imply hardness for SQ algorithms as well as for gradient-based methods. The analysis combines the Boas–Bellman inequality with ergodic-translation techniques on the 2D torus and uses discrete Fourier transforms to handle $p$-periodic functions. Experiments on high-frequency waves on ${\mathbb R}$ and on modular multiplication corroborate the theory, illustrating the practical difficulty of learning such tasks with SGD-like optimization and standard neural networks. Overall, the results delineate fundamental limits on gradient-based learnability for high-frequency and modular arithmetic targets and suggest deeper connections to SQ hardness and grokking phenomena.
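A minimal sketch of the quantity behind these bounds, under assumptions made here rather than taken from the paper: the square-loss definition common in this line of work, $\operatorname{Var}(\mathcal{H}, \mathbf{w}) = \mathbb{E}_{h}\|\nabla L_h(\mathbf{w}) - \mathbb{E}_{h'}\nabla L_{h'}(\mathbf{w})\|^2$ with $L_h(\mathbf{w}) = \mathbb{E}_x (f_{\mathbf{w}}(x) - h(x))^2$; a hypothetical one-hidden-layer tanh network $f_{\mathbf{w}}$ at a random initialization; the input encoding $t = x/p$; and the targets $(ax \bmod p)/p$ for $a \in \mathbb{Z}_p^\ast$. For square loss, each gradient deviates from the class mean by $-2\,\mathbb{E}_x[(h - \bar h)(x)\,\nabla f_{\mathbf{w}}(x)]$, so only the projections of the centered targets onto the per-input gradient features matter. The function `gradient_variance_ratio` is illustrative and is not the paper's experimental setup.

```python
import numpy as np

def gradient_variance_ratio(p, width=32, seed=0):
    """Toy estimate of Var(H, w) / E_x ||grad f_w(x)||^2 for the (assumed)
    class {x -> (a*x mod p)/p : a in Z_p^*}, with x uniform on Z_p."""
    rng = np.random.default_rng(seed)
    # Hypothetical network f_w(t) = v . tanh(W t + b) at random initialization.
    W = rng.normal(size=width)
    b = rng.normal(size=width)
    v = rng.normal(size=width) / np.sqrt(width)

    x = np.arange(p)
    t = x / p                              # assumed input encoding t = x/p
    z = np.tanh(np.outer(t, W) + b)        # hidden activations, shape (p, width)
    dz = 1.0 - z ** 2                      # derivative of tanh
    # Per-input gradient of f_w with respect to (v, W, b), shape (p, 3*width).
    G = np.hstack([z, dz * v * t[:, None], dz * v])

    a = np.arange(1, p)
    H = (np.outer(a, x) % p) / p           # targets h_a(x), shape (p-1, p)
    C = H - H.mean(axis=0)                 # h_a - mean target over the class
    # Square loss: grad L_h - E_h' grad L_h' = -2 E_x[(h - h_bar)(x) grad f_w(x)].
    D = -2.0 * (C @ G) / p
    var = np.mean(np.sum(D ** 2, axis=1))  # average over a in Z_p^*
    return var / np.mean(np.sum(G ** 2, axis=1))
```

Dividing by $\mathbb{E}_x\|\nabla f_{\mathbf{w}}(x)\|^2$ mirrors the normalization described for Figure 2. With this toy model the ratio should shrink as $p$ grows, for example from $p=101$ to $p=1009$, because the centered targets are nearly orthogonal to the smooth gradient features for almost every $a$.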

Abstract

Classes of target functions containing a large number of approximately orthogonal elements are known to be hard to learn by Statistical Query (SQ) algorithms. Recently, this classical fact has re-emerged in the theory of gradient-based optimization of neural networks. In this framework, the hardness of a class is usually quantified by the variance of the gradient with respect to a random choice of a target function. The set of functions of the form $x\mapsto ax \bmod p$, where $a$ is taken from ${\mathbb Z}_p$, has recently attracted attention from deep learning theorists and cryptographers. This class can be understood as a subset of $p$-periodic functions on ${\mathbb Z}$ and is tightly connected with a class of high-frequency periodic functions on the real line. We present a mathematical analysis of the limitations and challenges associated with using gradient-based learning techniques to learn a high-frequency periodic function or modular multiplication from examples. We show that the variance of the gradient is negligibly small in both cases when either the frequency or the prime base $p$ is large. This, in turn, prevents such learning algorithms from succeeding.
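One way to see the connection between the two classes, checked exactly with rational arithmetic: for integers $x$, $(ax \bmod p)/p = \psi(a \cdot x/p)$ with the sawtooth $\psi(t) = t - \lfloor t \rfloor$, a 1-periodic function of bounded variation. Modular multiplication therefore samples the wave $t \mapsto \psi(at)$ on the grid $t \in \frac{1}{p}{\mathbb Z}$, and $x \mapsto ax \bmod p$ is $p$-periodic on ${\mathbb Z}$. The choice of $\psi$ and the helper names below are hypothetical, for illustration only.

```python
import math
from fractions import Fraction

def mod_mult(a, x, p):
    """Modular multiplication x -> a*x mod p on the integers (Python's % is non-negative)."""
    return (a * x) % p

def sawtooth(t):
    """Assumed 1-periodic wave psi(t) = t - floor(t), of bounded variation."""
    return t - math.floor(t)

def check_connection(a, p, xs):
    # (a*x mod p)/p equals psi(a * t) sampled at t = x/p, exactly.
    return all(Fraction(mod_mult(a, x, p), p) == sawtooth(Fraction(a * x, p)) for x in xs)

def check_periodicity(a, p, xs):
    # x -> a*x mod p is p-periodic on Z, including negative x.
    return all(mod_mult(a, x + p, p) == mod_mult(a, x, p) for x in xs)
```

For instance, `mod_mult(5, 3, 13)` is `2`, and `sawtooth(Fraction(15, 13))` is `Fraction(2, 13)`. Using `Fraction` avoids floating-point rounding, so the identity is verified exactly rather than approximately.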

Paper Structure

This paper contains 19 sections, 23 theorems, 122 equations, and 6 figures.

Key Result

Theorem 1

There exist universal constants $C, A_0>0$ such that for any $A>A_0$, …

Figures (6)

  • Figure 1: Learning the high-frequency wave on ${\mathbb R}$ with a 3-layer dense network with ReLU activation. For each $A$, the coefficient $a$ was sampled randomly 5 times from ${\mathbb Z}_A$, and the MSE loss, averaged over these runs, is plotted against the epoch. The horizontal asymptote corresponds to ${\rm MSE}=\frac{1}{12}$.
  • Figure 2: Verifying the statement of Theorem \ref{r-var}. For prime numbers $p$ in $[300,3000]$, we plot the left-hand side of (\ref{discrete}) divided by the average squared norm of the neural network's gradient (\ref{eq:g_w}) and multiplied by $\sqrt{p}$. The resulting curve is of order $\tilde{\mathcal{O}}(1)$; in fact, it decreases as $p$ grows.
  • Figure 3: Learning the parity bit of multiplication modulo $p$ with a 3-layer width-1000 dense network. Darker shades correspond to longer bit lengths. For each bit length $n$, $p$ is chosen randomly from the prime numbers in the interval $[2^{n-1}, 2^n-1]$.
  • Figure 4: $\mathop{\mathrm{\mathbb{E}}}\limits_{i\sim\{1,\ldots,20\}}\left[\frac{v(\mathbf{w}_i)}{g(\mathbf{w}_i)}\cdot p\right]$ against $p$ for modular multiplication.
  • Figure 5: Mean squared covariance between two multiplications, $a\cdot X$ and $b\cdot X$, when $X$ is a random variable uniformly distributed on $\mathbb{Z}_p^\ast$.
  • ...and 1 more figure
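The sketch below relates to Figures 1 and 5, under stated assumptions. The first function assumes the wave in Figure 1 is the sawtooth $\psi(ax)$ with $\psi(t) = t - \lfloor t \rfloor$ and $x$ uniform on $[0,1)$ (the exact $\psi$ is not given in this summary). For any nonzero integer $a$, $\psi(ax)$ is then uniform on $[0,1)$, so the best constant predictor $1/2$ has MSE $\int_0^1 (y - 1/2)^2\,dy = \frac{1}{12}$, which would explain the horizontal asymptote as the loss of a model that predicts only the mean. The second function estimates the quantity in Figure 5 with a normalization chosen here, which may differ from the figure's: values scaled to $(aX \bmod p)/p$, covariances taken over $X$ uniform on $\mathbb{Z}_p^\ast$, and squares averaged over ordered pairs $a \neq b$.

```python
import numpy as np

def constant_predictor_mse(a, n=1_000_000, seed=0):
    """Monte Carlo MSE of predicting 1/2 for the assumed target psi(a*x), x ~ U[0, 1)."""
    rng = np.random.default_rng(seed)
    x = rng.random(n)
    y = np.mod(a * x, 1.0)               # psi(a x) with psi(t) = t - floor(t)
    return np.mean((y - 0.5) ** 2)       # tends to 1/12 as n grows

def mean_squared_covariance(p):
    """Mean of Cov((aX mod p)/p, (bX mod p)/p)^2 over a != b, X uniform on Z_p^*."""
    r = np.arange(1, p)                  # elements of Z_p^*
    V = (np.outer(r, r) % p) / p         # row a: values of (aX mod p)/p for X = 1..p-1
    C = np.cov(V, bias=True)             # rows are variables, equally likely columns
    off = ~np.eye(p - 1, dtype=bool)     # keep only pairs a != b
    return np.mean(C[off] ** 2)
```

Pairs with $b = -a$ are perfectly anti-correlated, since $(aX \bmod p) + (bX \bmod p) = p$ for $X \neq 0$, but there are only $p-1$ such ordered pairs, so the average still decays as $p$ grows (compare `mean_squared_covariance(31)` with `mean_squared_covariance(307)`).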

Theorems & Definitions (49)

  • Remark 1
  • Remark 2
  • Remark 3
  • Theorem 1
  • Lemma 2: Boas–Bellman inequality
  • Proof of Theorem \ref{ergodic}
  • Lemma 3: Koksma–Hlawka inequality (Kuipers & Niederreiter, 2012)
  • Lemma 4: Erdős–Turán–Koksma inequality (Kuipers & Niederreiter, 2012)
  • Lemma 5
  • Proof
  • ...and 39 more
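Two of the listed tools can be checked numerically. The sketch uses textbook forms, which may differ in detail from the paper's statements: the Boas–Bellman inequality $\sum_i \langle x, y_i\rangle^2 \le \|x\|^2\big(\max_i\|y_i\|^2 + (\sum_{i\neq j}\langle y_i, y_j\rangle^2)^{1/2}\big)$, which bounds the total projection of a vector onto a nearly orthogonal family, and Koksma's inequality $\big|\frac{1}{N}\sum_n f(x_n) - \int_0^1 f\big| \le V(f)\,D_N^\ast$, the one-dimensional case of Koksma–Hlawka, with the star discrepancy computed by the exact formula for sorted points. The Erdős–Turán–Koksma inequality is not sketched, and the function names are hypothetical.

```python
import numpy as np

def boas_bellman_holds(x, Y):
    """Check sum_i <x, y_i>^2 <= ||x||^2 (max_i ||y_i||^2 + sqrt(sum_{i != j} <y_i, y_j>^2)).
    Rows of Y are the vectors y_i."""
    lhs = np.sum((Y @ x) ** 2)
    gram = Y @ Y.T
    off = gram - np.diag(np.diag(gram))  # off-diagonal inner products <y_i, y_j>
    rhs = (x @ x) * (np.max(np.diag(gram)) + np.sqrt(np.sum(off ** 2)))
    return lhs <= rhs * (1 + 1e-12)      # small slack for rounding

def star_discrepancy(points):
    """Exact 1D star discrepancy: 1/(2N) + max_i |x_(i) - (2i - 1)/(2N)| for sorted x_(i)."""
    xs = np.sort(points)
    n = len(xs)
    i = np.arange(1, n + 1)
    return 1.0 / (2 * n) + np.max(np.abs(xs - (2 * i - 1) / (2 * n)))

def koksma_holds(f, variation, integral, points):
    """Check |mean of f(x_n) - integral of f over [0, 1]| <= V(f) * D*_N."""
    err = abs(np.mean(f(points)) - integral)
    return err <= variation * star_discrepancy(points) + 1e-12
```

For example, with the points $\{n\sqrt{2}\}$, $n = 1, \ldots, 500$, both $f(t) = t^2$ ($V = 1$, integral $1/3$) and $f(t) = \cos 2\pi t$ ($V = 4$, integral $0$) satisfy Koksma's bound. When the $y_i$ are centered targets and $x$ is a gradient feature, the left-hand side of Boas–Bellman is, up to normalization, the kind of sum that appears in a gradient variance; this is offered as intuition, not as the paper's proof.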