Fourier Head: Helping Large Language Models Learn Complex Probability Distributions

Nate Gillman, Daksh Aggarwal, Michael Freeman, Saurabh Singh, Chen Sun

TL;DR

This paper introduces a neural network layer, constructed using Fourier series, that can substitute for any linear layer when the outputs of an LLM should have a more continuous structure. It also provides theoretical evidence that this layer better learns signal from data while ignoring high-frequency noise.

Abstract

As the quality of large language models has improved, there has been increased interest in using them to model non-linguistic tokens. For example, the Decision Transformer recasts agentic decision making as a sequence modeling problem, using a decoder-only LLM to model the distribution over the discrete action space for an Atari agent. However, when adapting LLMs to non-linguistic domains, it remains unclear whether softmax over discrete bins captures the continuous structure of the tokens and the potentially complex distributions needed for high-quality token generation. We introduce a neural network layer, constructed using Fourier series, which we can easily substitute for any linear layer if we want the outputs to have a more continuous structure. We perform extensive analysis on synthetic datasets, as well as on large-scale decision making and time series forecasting tasks. We also provide theoretical evidence that this layer can better learn signal from data while ignoring high-frequency noise. All of our results support the effectiveness of our proposed Fourier head in scenarios where the underlying data distribution has a natural continuous structure. For example, the Fourier head improves a Decision Transformer agent's returns across four benchmark Atari games by as much as 377%, and increases a state-of-the-art time series foundation model's forecasting performance by 3.5% across 20 benchmarks unseen during training.
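
To make the idea of a Fourier-series output layer concrete, here is a minimal pure-Python sketch. The abstract only states that the layer is constructed using Fourier series and replaces a linear layer. The specific construction below (a linear map producing complex coefficients, the squared modulus of the resulting trigonometric polynomial as a nonnegative density on $[-1, 1]$, and evaluation at bin centers) is an assumption of this sketch and may differ from the authors' implementation. All names (`fourier_head`, `linear`, `n_freq`, `num_bins`) are hypothetical.

```python
import cmath
import math
import random

def linear(x, weight, bias):
    """Plain dense layer: weight is a list of rows, one per output unit."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def fourier_head(x, weight, bias, num_bins):
    """Map an input vector to a categorical distribution over num_bins bins.

    Assumed construction (not taken from the source text): the linear layer
    emits 2 * (n_freq + 1) reals, read as complex numbers a_0..a_N. The density
    |sum_l a_l exp(i*pi*l*z)|^2 / (2 * sum_l |a_l|^2) is nonnegative and
    integrates to 1 on [-1, 1]; it is then evaluated at the bin centers.
    """
    raw = linear(x, weight, bias)
    a = [complex(raw[2 * l], raw[2 * l + 1]) for l in range(len(raw) // 2)]
    n_freq = len(a) - 1
    # Autocorrelation c_k = sum_l a_{l+k} * conj(a_l): Fourier coefficients of |.|^2.
    c = [sum(a[l + k] * a[l].conjugate() for l in range(n_freq - k + 1))
         for k in range(n_freq + 1)]
    c0 = c[0].real  # equals sum_l |a_l|^2
    centers = [-1 + (2 * j + 1) / num_bins for j in range(num_bins)]
    density = [
        0.5 + sum((c[k] / c0 * cmath.exp(1j * math.pi * k * z)).real
                  for k in range(1, n_freq + 1))
        for z in centers
    ]
    density = [max(d, 0.0) for d in density]  # guard against rounding below zero
    total = sum(density)
    return [d / total for d in density]

# Hypothetical sizes: 3 input features, 4 frequencies, 16 output bins.
rng = random.Random(0)
n_freq, num_bins, in_dim = 4, 16, 3
weight = [[rng.uniform(-1, 1) for _ in range(in_dim)] for _ in range(2 * (n_freq + 1))]
bias = [rng.uniform(-1, 1) for _ in range(2 * (n_freq + 1))]
pmf = fourier_head([0.3, -1.2, 0.7], weight, bias, num_bins)
```

In this sketch the number of frequencies limits how quickly the predicted probabilities can vary from one bin to the next, which is one way a layer of this kind could favor smooth distributions over high-frequency noise.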

Paper Structure

This paper contains 31 sections, 9 theorems, 49 equations, 19 figures, 9 tables, 1 algorithm.

Key Result

Theorem 3.3

(Fourier head scaling law.) Consider a Fourier head with input dimension $n$, output dimension $m$, and $N$ frequencies. Suppose that $1 \ll N < \frac{m}{2}$. Then the following are true:

Figures (19)

  • Figure 1: We task an MLP with learning to approximate a continuous bimodal density using a categorical distribution and a cross-entropy objective. We observe that a standard linear head fails to distinguish between the two modes, and overfits to high-frequency noise in the training set. In contrast, our proposed Fourier head learns a smoother, more accurate categorical distribution.
  • Figure 2: Comparison between the PMFs learned by the linear head, GMM head, and the Fourier head, for two of the datasets in the toy example: Gaussian and Beta. (The GMM dataset is in Figure \ref{fig:toy_example_1d_cropped}.) We observe that the Fourier head learns a smoother categorical distribution than the linear head over its predicted values. Furthermore, the Fourier head better fits the true conditional PDF; this is reflected in the KL divergence and smoothness metrics.
  • Figure 3: We demonstrate that the baseline Llama model does a poor job simulating Gaussian sampling, as measured by the Total Variation Distance between the ground truth quantized Gaussian histogram and the empirical histogram of samples. We find that LoRA fine-tuning improves the results by a factor of $\approx 2.07$, and that using the Fourier head improves the output distribution by a factor of $\approx 4.86$.
  • Figure 4: We present empirical results for how the quantity of Fourier frequencies impacts returns and smoothness for the imitation learning task. For normalized returns, higher is better; for smoothness, lower is better. We can see that the Fourier agent achieves higher normalized returns than the linear baseline agent when sufficiently many Fourier frequencies are used, while still learning smoother next-action distributions.
  • Figure 5: Truncated square waves framed as densities and their smoothness.
  • ...and 14 more figures
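
The captions above compare distributions using KL divergence (Figure 2) and Total Variation Distance against a quantized Gaussian histogram (Figure 3). The sketch below shows the standard definitions of both quantities for categorical distributions in plain Python. The bin edges, sample count, and function names are illustrative assumptions, not the paper's experimental setup.

```python
import math
import random

def quantized_gaussian_pmf(mu, sigma, edges):
    """Mass of N(mu, sigma^2) in each bin, renormalized over the covered range."""
    cdf = [0.5 * (1 + math.erf((e - mu) / (sigma * math.sqrt(2)))) for e in edges]
    mass = [hi - lo for lo, hi in zip(cdf, cdf[1:])]
    total = sum(mass)
    return [m / total for m in mass]

def empirical_pmf(samples, edges):
    """Histogram of samples over the bins; samples outside the edges are dropped."""
    counts = [0] * (len(edges) - 1)
    for s in samples:
        for j in range(len(counts)):
            if edges[j] <= s < edges[j + 1]:
                counts[j] += 1
                break
    n = sum(counts)
    return [c / n for c in counts]

def total_variation(p, q):
    """TVD(p, q) = 0.5 * sum_i |p_i - q_i|, a value in [0, 1]."""
    return 0.5 * sum(abs(a - b) for a, b in zip(p, q))

def kl_divergence(p, q, eps=1e-12):
    """KL(p || q) = sum_i p_i * log(p_i / q_i); q is floored at eps to avoid log(0)."""
    return sum(a * math.log(a / max(b, eps)) for a, b in zip(p, q) if a > 0)

# Illustrative setup: 12 bins of width 0.5 on [-3, 3], 5000 standard normal samples.
edges = [-3 + 0.5 * j for j in range(13)]
rng = random.Random(0)
samples = [rng.gauss(0.0, 1.0) for _ in range(5000)]
truth = quantized_gaussian_pmf(0.0, 1.0, edges)
observed = empirical_pmf(samples, edges)
tvd = total_variation(truth, observed)
kl = kl_divergence(truth, observed)
```

A lower `tvd` means the empirical histogram is closer to the quantized ground truth, which is the sense in which the Figure 3 caption reports an improvement.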

Theorems & Definitions (17)

  • Definition 3.2: Smoothness metric for categorical distributions
  • Theorem 3.3
  • Definition A.1: Discrete convolution
  • Definition A.2: Continuous convolution
  • Lemma A.2
  • Lemma A.3: Asymptotic expansion of Riemann zeta function
  • Lemma A.3
  • Lemma A.3
  • Theorem A.3
  • Proof: Proof of Claim 2 of Theorem 3.3 (Fourier head scaling law)
  • ...and 7 more