Fourier Circuits in Neural Networks and Transformers: A Case Study of Modular Arithmetic with Multiple Inputs

Chenyang Li, Yingyu Liang, Zhenmei Shi, Zhao Song, Tianyi Zhou

TL;DR

This work analyzes how neural nets and Transformers learn modular addition tasks by forming Fourier-based circuits. By casting training as a max-margin problem and transforming it into the discrete Fourier domain, it shows that, with sufficient width, each hidden neuron selects a single Fourier frequency, which enables exact modular-addition solutions. Empirical results with one-hidden-layer nets and one-layer Transformers corroborate the theoretically predicted Fourier patterns in hidden weights and attention, and grokking behavior is observed and interpreted through margin-based theory. The findings provide mechanistic insight into how SGD-driven learning organizes representations for algebraic tasks and suggest implications for robustness and algorithm design. The study also discusses connections to parity, SQ hardness, and implicit bias, and outlines limitations and future directions for extending these results to broader architectures and tasks.
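To make "each hidden neuron selects a single Fourier frequency" concrete, the sketch below computes the idealized logits that a sum of single-frequency features can express for $k$-input modular addition. It is an illustration under our own assumptions (odd modulus $p$, frequencies $1,\dots,(p-1)/2$, a hypothetical function name `fourier_logits`), not the paper's trained network or its exact construction. Each input $x_j$ contributes its own phase $e^{i\omega x_j}$; the product of these phases is $e^{i\omega(x_1+\cdots+x_k)}$, so the logit for a candidate answer $c$ is $\sum_\omega \cos(\omega(x_1+\cdots+x_k-c))$.

```python
import numpy as np

def fourier_logits(inputs, p):
    """Idealized Fourier-circuit logits for (x_1 + ... + x_k) mod p (sketch; p odd)."""
    freqs = np.arange(1, (p - 1) // 2 + 1)      # frequencies 1..(p-1)/2
    omega = 2 * np.pi * freqs / p                # angular frequencies, shape (F,)
    # Each input contributes its own phase; their product encodes the sum of the inputs.
    phase = np.prod([np.exp(1j * omega * x) for x in inputs], axis=0)  # shape (F,)
    candidates = np.arange(p)
    # Re[e^{i w (x_1+...+x_k)} * e^{-i w c}] = cos(w * (x_1 + ... + x_k - c))
    per_freq = np.real(phase[:, None] * np.exp(-1j * np.outer(omega, candidates)))
    return per_freq.sum(axis=0)                  # one logit per candidate, shape (p,)

logits = fourier_logits([1, 2, 4], p=5)
print(int(np.argmax(logits)))  # 2, since (1 + 2 + 4) mod 5 = 2
```

Summed over these frequencies, the logit equals $(p-1)/2$ at the correct residue and $-1/2$ at every other residue, so the argmax recovers the sum exactly.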

Abstract

In the evolving landscape of machine learning, a pivotal challenge lies in deciphering the internal representations harnessed by neural networks and Transformers. Building on recent progress toward comprehending how networks execute distinct target functions, our study explores why networks adopt specific computational strategies. We focus on the complex algebraic learning task of modular addition involving $k$ inputs. Our research presents a thorough analytical characterization of the features learned by stylized one-hidden-layer neural networks and one-layer Transformers in addressing this task. A cornerstone of our theoretical framework is the elucidation of how the principle of margin maximization shapes the features adopted by one-hidden-layer neural networks. Let $p$ denote the modulus, $D_p$ the dataset of modular arithmetic with $k$ inputs, and $m$ the network width. We demonstrate that with $m \geq 2^{2k-2} \cdot (p-1)$ neurons, these networks attain a maximum $L_{2,k+1}$-margin on the dataset $D_p$. Furthermore, we establish that each hidden-layer neuron aligns with a specific Fourier spectrum, integral to solving modular addition problems. By correlating our findings with the empirical observations of similar studies, we contribute to a deeper comprehension of the intrinsic computational mechanisms of neural networks. We also observe similar computational mechanisms in the attention matrices of one-layer Transformers. Our work stands as a significant stride in unraveling the complexity of their operation, particularly in the realm of complex algebraic tasks.
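As a quick arithmetic check of the width bound, the sketch below evaluates $m = 2^{2k-2}(p-1)$ and the dataset size $|D_p| = p^k$. We assume $D_p$ contains every $k$-tuple of residues labeled by its sum mod $p$, as the "$p^k$ data points" in the figure captions suggest; the helper names are hypothetical. For $(k,p) = (4,47)$ and $(3,97)$ the bound gives 2944 and 1536, the widths used in Figures 1 and 5.

```python
from itertools import product

def width_bound(k: int, p: int) -> int:
    # Neuron count from the abstract: m >= 2^(2k-2) * (p - 1)
    return 2 ** (2 * k - 2) * (p - 1)

def build_dataset(k: int, p: int):
    # Hypothetical helper: every k-tuple of residues, labeled by its sum mod p
    return [(xs, sum(xs) % p) for xs in product(range(p), repeat=k)]

for k, p in [(4, 47), (3, 97)]:
    print(k, p, width_bound(k, p), p ** k)
# 4 47 2944 4879681
# 3 97 1536 912673
```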

Paper Structure (45 sections, 17 theorems, 127 equations, 10 figures)

Key Result

Lemma 3.7

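Let $f$ be a homogeneous function. For any norm $\| \cdot \|$, if $\gamma^* > 0$, then $\lim_{\lambda\rightarrow 0} \gamma_\lambda = \gamma^*$, where $\gamma_\lambda$ is the normalized margin of a minimizer of the $\lambda$-regularized loss and $\gamma^*$ is the maximum normalized margin.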
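The homogeneity assumption is what makes margins comparable across parameter scales: if $f(c\theta; x) = c^{\nu} f(\theta; x)$ for $c > 0$, dividing the margin by $\|\theta\|^{\nu}$ removes the dependence on scale. The sketch below checks this numerically for a hypothetical stylized one-hidden-layer network with activation $z \mapsto z^k$ (homogeneous of degree $\nu = k+1$) under our assumed reading of the $L_{2,k+1}$ norm. The architecture, function names, and norm convention are illustrative assumptions, not the paper's definitions.

```python
import numpy as np
from itertools import product

def forward(U, W, xs, k):
    # Hypothetical stylized net: neuron i sums its per-slot embeddings U[i, j, x_j],
    # applies z -> z**k, then uses output weights W[i, :]. Degree k + 1 in (U, W).
    pre = sum(U[:, j, x] for j, x in enumerate(xs))   # shape (m,)
    return (pre ** k) @ W                              # logits, shape (p,)

def l2_q_norm(U, W, q):
    # Assumed L_{2,q} norm: L2 norm of each neuron's parameters, then an L_q norm over neurons
    per_neuron = np.sqrt((U ** 2).sum(axis=(1, 2)) + (W ** 2).sum(axis=1))
    return np.sum(per_neuron ** q) ** (1.0 / q)

def normalized_margin(U, W, k, p):
    # Worst-case gap between the true logit and the best wrong logit over all k-tuples
    gaps = []
    for xs in product(range(p), repeat=k):
        logits = forward(U, W, xs, k)
        y = sum(xs) % p
        gaps.append(logits[y] - np.delete(logits, y).max())
    return min(gaps) / l2_q_norm(U, W, k + 1) ** (k + 1)

rng = np.random.default_rng(0)
k, p, m = 2, 5, 8
U, W = rng.normal(size=(m, k, p)), rng.normal(size=(m, p))
print(np.isclose(normalized_margin(U, W, k, p),
                 normalized_margin(3.0 * U, 3.0 * W, k, p)))  # True
```

Scaling all parameters by 3 multiplies every logit gap by $3^{k+1}$ and the normalizer $\|\theta\|_{2,k+1}^{k+1}$ by the same factor, so the normalized margin is unchanged.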

Figures (10)

  • Figure 1: Cosine shape of the trained embeddings (hidden-layer weights) and the corresponding Fourier power spectrum. The two-layer network with $m=2944$ neurons is trained on the $k=4$-sum mod-$p=47$ addition dataset. We split the whole dataset ($p^k = 47^4$ data points) evenly into training and test sets. Each row represents a random neuron from the network. The left figure shows the final trained embeddings, with red dots indicating the true weight values; the pale blue interpolation is the function that shares the same Fourier spectrum. The right figure shows their Fourier power spectrum. These results are consistent with our analysis in Lemma 4.2. See the corresponding figures in the appendix on neural-network experiments for similar results when $k$ is 3 or 5.
  • Figure 2: Coverage of all Fourier frequencies and the maximum normalized power of the embeddings (hidden-layer weights). The one-hidden-layer network with $m=2944$ neurons is trained on the $k=4$-sum mod-$p=47$ addition dataset. Let $\widehat{u}[i]$ denote the Fourier transform of $u[i]$, and define the maximum normalized power as $\max_i |\widehat{u}[i]|^2 / \left( \sum_j |\widehat{u}[j]|^2 \right)$. Mapping each neuron to the frequency that attains its maximum normalized power, (a) shows the final frequency distribution of the embeddings; as in our informal construction lemma, the distribution over frequencies is almost uniform. (b) shows the maximum normalized power of the network at random initialization. (c) shows that, in frequency space, the embeddings of the final trained network are one-sparse, i.e., the maximum normalized power is almost 1 for every neuron. This is consistent with our max-margin analysis in the informal construction lemma. See the corresponding figures in the appendix on neural-network experiments for results when $k$ is 3 or 5.
  • Figure 3: 2-dimensional cosine shape of the trained $W^{KQ}$ (attention weights) and their Fourier power spectrum. The one-layer Transformer with $m=160$ attention heads is trained on the $k=4$-sum mod-$p=31$ addition dataset. We split the whole dataset ($p^k = 31^4$ data points) evenly into training and test sets. Each row represents a random attention head from the Transformer. The left figure shows that the final trained attention weights have an apparent 2-dimensional cosine shape. The right figure shows their 2-dimensional Fourier power spectrum. These results are consistent with Figure 1. See the corresponding figures in the appendix on Transformer experiments for similar results when $k$ is 3 or 5.
  • Figure 4: Grokking (models abruptly transition from poor generalization to perfect generalization after a large number of training steps) when learning modular addition with $k=2,3,4,5$ inputs. We train two-layer Transformers with $m=160$ attention heads on the $k=2,3,4,5$-sum mod-$p = 97,31,11,5$ addition datasets, with $50\%$ of the data in the training set, using the AdamW optimizer (loshchilov2017decoupled) with learning rate 1e-3 and weight decay 1e-3. We use different values of $p$ so that the dataset sizes are roughly equal. The blue curves show training accuracy and the red curves show validation accuracy. Grokking appears in all panels; however, it becomes weaker as $k$ increases. See the explanation in the experiments section.
  • Figure 5: Cosine shape of the trained embeddings (hidden-layer weights) and the corresponding Fourier power spectrum. The two-layer network with $m=1536$ neurons is trained on the $k=3$-sum mod-$p=97$ addition dataset. We split the whole dataset ($p^k = 97^3$ data points) evenly into training and test sets. Each row represents a random neuron from the network. The left figure shows the final trained embeddings, with red dots indicating the true weight values; the pale blue interpolation is the function that shares the same Fourier spectrum. The right figure shows their Fourier power spectrum. These results are consistent with our analysis in Lemma 4.2.
  • ...and 5 more figures
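The maximum normalized power used in Figure 2 can be sketched as follows. With a full complex FFT, a real cosine splits its power equally between frequencies $f$ and $p-f$, so a one-sparse neuron would score only about 0.5; we therefore assume the $\pm f$ pair is counted as one frequency and compute the spectrum with `numpy.fft.rfft`. This convention, the function name, and the test vectors are our assumptions, not the paper's code.

```python
import numpy as np

def max_normalized_power(u):
    # Power at non-negative frequencies via a real FFT (assumed convention: f and p - f merged)
    power = np.abs(np.fft.rfft(u)) ** 2
    return power.max() / power.sum(), int(power.argmax())

p, f = 47, 5
idx = np.arange(p)
cosine_neuron = np.cos(2 * np.pi * f * idx / p + 0.7)   # single-frequency weights
random_neuron = np.random.default_rng(0).normal(size=p)  # stand-in for random initialization

print(max_normalized_power(cosine_neuron))  # close to (1.0, 5)
print(max_normalized_power(random_neuron))  # power spread across many frequencies
```

The returned index is the frequency a neuron would be mapped to in panel (a); values near 1 correspond to the one-sparse embeddings in panel (c).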

Theorems & Definitions (45)

  • Definition 3.1
  • Definition 3.2
  • Definition 3.3
  • Definition 3.4
  • Definition 3.5
  • Definition 3.6
  • Lemma 3.7: wei2019regularization, Theorem 4.1
  • Theorem 4.1: Main result, informal version of the formal main theorem
  • Proof: Proof sketch of Theorem 4.1
  • Lemma 4.2: Informal version of the corresponding formal max-margin solution lemma
  • ...and 35 more