
Attention Learning is Needed to Efficiently Learn Parity Function

Yaomengxi Han, Debarghya Ghoshdastidar

TL;DR

This paper analyzes learning the $k$-parity problem in a realizable, noiseless setting and compares two architectures: two-layer FFNNs and transformers with multi-head attention. It proves that transformers with $k$ trainable attention heads can learn $k$-parity using only $O(k)$ parameters, while FFNNs require at least $\Omega(n)$ parameters, establishing a strong parameter-efficiency advantage for attention-based learning. It also shows that this advantage vanishes when the attention heads are fixed: the norms and head counts must then scale with $n$, which highlights the necessity of learning attention. The results imply that attention mechanisms enable efficient feature learning for low-sensitivity tasks, with potential implications for understanding generalization dynamics in practical transformer models and for guiding future work on broader low-sensitivity function classes.

Abstract

Transformers, with their attention mechanisms, have emerged as the state-of-the-art architectures for sequential modeling and empirically outperform feed-forward neural networks (FFNNs) across many fields, such as natural language processing and computer vision. However, their generalization ability, particularly for low-sensitivity functions, remains less studied. We bridge this gap by analyzing transformers on the $k$-parity problem. Daniely and Malach (NeurIPS 2020) show that FFNNs with one hidden layer and $O(nk^7 \log k)$ parameters can learn $k$-parity, where the input length $n$ is typically much larger than $k$. In this paper, we prove that FFNNs require at least $\Omega(n)$ parameters to learn $k$-parity, while transformers require only $O(k)$ parameters, surpassing the theoretical lower bound needed by FFNNs. We further prove that this parameter efficiency cannot be achieved with fixed attention heads. Our work establishes transformers as theoretically superior to FFNNs in learning the parity function, showing how their attention mechanisms enable parameter-efficient generalization in functions with low sensitivity.
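For readers unfamiliar with the task, the following is a minimal sketch of the $k$-parity target, assuming bits encoded as 0/1 and a hidden index set `S` of size $k$; the paper's own encoding may differ (for example $\pm 1$). The names `k_parity` and `sample_example` are hypothetical and used only for illustration.

```python
import random

def k_parity(x, S):
    """Label of x: parity (XOR) of the bits whose indices are in S.

    Assumption: bits are 0/1; the remaining n - k positions are irrelevant.
    """
    return sum(x[i] for i in S) % 2

def sample_example(n, S, rng):
    """Draw one input uniformly from {0,1}^n and label it (noiseless, realizable)."""
    x = [rng.randint(0, 1) for _ in range(n)]
    return x, k_parity(x, S)

# Example: n = 10 input bits, only k = 3 of them determine the label.
rng = random.Random(0)
S = [1, 4, 7]
x, y = sample_example(10, S, rng)
```

Flipping any bit outside `S` leaves the label unchanged, while flipping any bit inside `S` flips it. A learner therefore has to locate the $k$ relevant positions among $n$.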

Paper Structure

This paper contains 20 sections, 15 theorems, 62 equations, 2 figures.

Key Result

Proposition 3

Assume $k\leq n$. There exists a hypothesis class $\mathcal{H}_{\text{FFNN-}1^k}\subseteq \mathcal{H}_{\text{FFNN-}1}$ that expresses $k$-parity, and each $h\in \mathcal{H}_{\text{FFNN-}1^k}$ has exactly $k$ neurons and $2k+2$ distinct parameters. Furthermore, there exists a class $\mathcal{H}'_{\ba

Figures (2)

  • Figure 1: The architecture of the transformer and the example workflow to classify the parity of some given input. Given a binary string that consists of $7$ tokens as input, the embedding layer (in green) will embed each token into a concatenation of a positional embedding and a token embedding $\mathbf w_j = f_{\text{pos}}(j) \circ f_{\text{emb}}(x_j)$. An extra token embedding $\mathbf w_0$ will be prepended as the embedding of the CLS token. In the encoding layer (in red), each attention head $i$ will calculate attention scores $\boldsymbol{\gamma_i}$ for all of the seven embeddings with softmax. Then, each head will calculate its own vector $\mathbf v_i$ by taking the sum of the $7$ embeddings weighted by its own attention score: $\mathbf v_i = \sum_{j=1}^n \gamma^{(i)}_j\cdot\mathbf w_j$. These vectors will then be averaged into an attention vector $\mathbf v^* = \frac{1}{m}\sum_{i\in[m]}\mathbf v_i$, which will be the input of the two-layer feed-forward neural network (in blue).
  • Figure 2: Two heat maps of soft attention training. When there are no neighboring bits, each head attends to a separate bit with a score very close to 1 (left sub-figure). If neighboring bits exist (right sub-figure), a pair of attention heads can learn the same direction, which lies midway between the positional embeddings of the neighboring bits.
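
The workflow in the Figure 1 caption can be sketched in plain Python as follows. This is an illustration under stated assumptions, not the authors' implementation: the one-hot positional and token embeddings, the dot-product scoring against a per-head direction vector, the hand-set "sharp" heads, and the counting readout that stands in for the two-layer FFNN are all hypothetical choices, and the CLS token $\mathbf w_0$ is omitted. Only the softmax scores, the weighted sum $\mathbf v_i = \sum_{j=1}^n \gamma^{(i)}_j\cdot\mathbf w_j$, and the average $\mathbf v^* = \frac{1}{m}\sum_{i\in[m]}\mathbf v_i$ follow the caption.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def embed(j, bit, n):
    """Assumed embedding w_j: one-hot position (n dims) joined with one-hot token (2 dims)."""
    position = [1.0 if t == j else 0.0 for t in range(n)]
    token = [1.0, 0.0] if bit == 0 else [0.0, 1.0]
    return position + token

def attention_vector(x, heads):
    """Compute v* = (1/m) * sum_i v_i, where v_i = sum_j gamma_j^(i) * w_j."""
    n = len(x)
    w = [embed(j, x[j], n) for j in range(n)]
    dim = n + 2
    head_vectors = []
    for direction in heads:
        # Assumed scoring rule: dot product between the head's direction and each embedding.
        scores = [sum(a * b for a, b in zip(direction, w_j)) for w_j in w]
        gamma = softmax(scores)
        head_vectors.append(
            [sum(gamma[j] * w[j][c] for j in range(n)) for c in range(dim)]
        )
    m = len(heads)
    return [sum(v[c] for v in head_vectors) / m for c in range(dim)]

def make_sharp_heads(S, n, beta=50.0):
    """Hand-set heads (not trained): one per relevant bit, aligned with its positional embedding."""
    return [[beta if t == s else 0.0 for t in range(n)] + [0.0, 0.0] for s in S]

def predict_parity(x, heads):
    """Stand-in readout for the FFNN: recover the count of attended ones, then take it mod 2."""
    v_star = attention_vector(x, heads)
    ones_fraction = v_star[-1]  # last coordinate: average weight on the token value 1
    return round(len(heads) * ones_fraction) % 2

# Seven input tokens as in Figure 1; the relevant bits S are an arbitrary choice.
x = [1, 0, 1, 1, 0, 0, 1]
heads = make_sharp_heads([0, 2, 3], n=len(x))
label = predict_parity(x, heads)
```

With each head placing almost all of its attention on one relevant bit, as in the left heat map of Figure 2, the last coordinate of $\mathbf v^*$ is approximately the fraction of relevant bits equal to 1, so a small network only needs to map that value to its parity. In this sketch the number of heads is $k$ and does not grow with $n$.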

Theorems & Definitions (21)

  • Definition 1: $k$-parity learning
  • Definition 2: Expressivity and learnability of $\mathcal{H}$
  • Proposition 3: Number of parameters needed by FFNNs and transformers to express $k$-parity
  • Proposition 4: Number of parameters needed by FFNNs to learn $k$-parity
  • Theorem 5: Transformers with learnable attention heads can learn $k$-parity
  • Lemma 6: Smoothness of $\hat{y}$
  • Lemma 7: Smoothness of $\mathcal{L}_{\mathcal{D}}$
  • Lemma 8: $\mu$-PL condition on the expected risk
  • Remark 9: Transformers are more parameter-efficient than FFNNs for learning $k$-parity
  • Remark 10: Transformers learn $k$-parity with uniform distributions and any $k$
  • ...and 11 more