Wavelet GPT: Wavelet Inspired Large Language Models

Prateek Verma

TL;DR

This work introduces Wavelet GPT, a wavelet-inspired mechanism that imposes a multi-scale structure on intermediate embeddings in a GPT-style decoder to exploit the multi-scale nature of real-world data. By applying causal Haar wavelet transforms (or learnable equivalents) to a subset of embedding coordinates, the model accesses multiple temporal resolutions during next-token prediction without adding parameters. Across text, symbolic music, and raw audio, the approach yields 40–60% faster pre-training while achieving comparable or better performance, with strong gains on Long Range Arena benchmarks. The method generalizes to diverse input representations and demonstrates that multi-scale, signal-processing-inspired priors can meaningfully accelerate learning in large language models while preserving causality and efficiency.
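
A minimal sketch of the mechanism described above, under loudly stated assumptions: it uses a trailing moving average over the last 2^j tokens as a non-decimated, causal stand-in for the Haar approximation at level j, transforms only the first half of the coordinates, and assigns levels to coordinates with a made-up schedule. The paper's exact coordinate-to-level mapping and boundary handling are not given in this summary, and the function names are hypothetical. Each output at position t reads only positions up to t, so causality holds, and nothing is learned.

```python
# Hypothetical sketch of a parameter-free multi-scale structure on embeddings.
# Assumptions (not taken from the paper): embeddings are a T x D list of lists,
# only the first D // 2 coordinates are transformed, and coordinate d uses a
# causal window of 2 ** (1 + d % max_level) tokens.


def causal_haar_approx(signal, window):
    """Trailing mean of the last `window` samples up to and including t
    (fewer samples are available near the start of the sequence)."""
    out = []
    for t in range(len(signal)):
        past = signal[max(0, t - window + 1): t + 1]
        out.append(sum(past) / len(past))
    return out


def impose_multiscale(embeddings, max_level=4):
    """Return a copy of a T x D embedding matrix in which the first D // 2
    coordinates are replaced by causal approximations at different scales.
    No learnable parameters are introduced."""
    num_tokens = len(embeddings)
    dim = len(embeddings[0])
    result = [list(row) for row in embeddings]
    for d in range(dim // 2):
        window = 2 ** (1 + d % max_level)  # hypothetical level schedule
        column = [embeddings[t][d] for t in range(num_tokens)]
        for t, value in enumerate(causal_haar_approx(column, window)):
            result[t][d] = value
    return result


# Usage: 6 tokens, 4 coordinates. Coordinate 0 uses a 2-token window,
# coordinate 1 a 4-token window; coordinates 2 and 3 pass through unchanged.
emb = [[float(t), float(t), 1.0, 1.0] for t in range(6)]
out = impose_multiscale(emb, max_level=2)
print([row[0] for row in out])  # [0.0, 0.5, 1.5, 2.5, 3.5, 4.5]
print([row[1] for row in out])  # [0.0, 0.5, 1.0, 1.5, 2.5, 3.5]
```

The coarser window in coordinate 1 makes that dimension change more slowly than coordinate 0, which is the "different rates" idea; changing a later token leaves every earlier output unchanged, which is the property next-token prediction requires.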

Abstract

Large Language Models (LLMs) have ushered in a new wave of artificial intelligence advancements impacting every scientific field and discipline. We live in a world where most of the data around us, e.g., text, audio, and music, has a multi-scale structure. This paper infuses LLMs with a traditional signal processing idea, namely wavelets, during pre-training to take advantage of this structure. Without adding any extra parameters to a GPT-style LLM architecture in an academic setup, we achieve the same pre-training performance almost twice as fast in text, audio, and images. This is done by imposing a structure on intermediate embeddings. When trained for the same number of training steps, we achieve significant performance gains, comparable to those from pre-training a larger neural architecture. Further, we show that this extends to the Long Range Arena benchmark and to several input representations, such as characters, BPE tokens, bytes, waveforms, math expressions, and image pixels. Our architecture gives every next-token prediction access to intermediate embeddings at different temporal resolutions in every decoder block. We hope this will pave the way for incorporating multi-rate signal processing into pre-training.
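
The "almost twice as fast" claim can be read as a ratio of training budgets. A minimal sketch, assuming speedup is the baseline's epoch count divided by the Same Performance Epoch (SPE), the epoch at which the wavelet model first matches the baseline (the Figure 5 caption uses a 25-epoch baseline). The function name and the SPE values below are hypothetical illustrations, not reported results, and the sketch ignores any per-epoch compute overhead, so wall-clock speedup may differ.

```python
def speedup_from_spe(spe, baseline_epochs=25):
    """Speedup implied by a hypothetical Same Performance Epoch (SPE).

    Assumes speedup = baseline_epochs / spe, i.e. equal cost per epoch for
    both models. Also returns the fraction of the baseline's epoch budget
    that is saved.
    """
    if not 0 < spe <= baseline_epochs:
        raise ValueError("spe must lie in (0, baseline_epochs]")
    speedup = baseline_epochs / spe
    fraction_saved = 1 - spe / baseline_epochs
    return speedup, fraction_saved


# Hypothetical value: matching a 25-epoch baseline at epoch 12.5
# corresponds to a 2x speedup, i.e. half of the epoch budget saved.
print(speedup_from_spe(12.5))  # (2.0, 0.5)
```

Returning both numbers keeps two readings of "X% faster" apart: a speedup factor of 1.6x and a 60% reduction in training budget are different claims.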

Paper Structure (15 sections, 3 equations, 6 figures, 1 table)

Figures (6)

  • Figure 1: Manipulating signals between GPT decoder blocks by computing a 1-D causal discrete Haar wavelet transform (or a learnable approximation) at different levels, capturing multi-scale structure for each signal. (Right) From gao2006non, illustrating non-stationary signal processing. The leftmost route follows the approximation coefficients, modeling the coarsest to finest scales.
  • Figure 2: (Bottom L): A 3-level filter bank tree generates signals at different resolutions. Approximation coefficients are computed by applying a wavelet's impulse response and recursively downsampling. (Top L): Approximation and detail coefficients are iteratively calculated via first-order averages/differences and downsampling until a single scalar represents the signal. (R): For a 32-length signal, the Haar wavelet captures the coarsest to finest approximations; redrawn from wavelet-tutorial. Embeddings evolve at different rates via causal wavelet approximation, with coarse (level 5) and fine (level 2) resolutions, embedding multi-scale information into decoder layers for every token.
  • Figure 3: (Left) Toy example showing embeddings before and after imposing the multi-rate structure. Different embedding dimensions advance at distinct rates while maintaining causality, as seen from patterns dispersing from dimension 64 to 0. (Right) Validation loss during pre-training on text-8: the learnable multi-scale structure reaches comparable performance nearly twice as fast or, for the same number of steps, gives a performance boost akin to adding decoder layers. With a 32-dim model on text-8, our architecture yields a speedup similar to that seen for 128-dim and shallower models. On the LRA image benchmark, it gives a 10% performance increase without adding any parameters.
  • Figure 4: Results for natural language, symbolic music, and raw audio. Our model reaches baseline performance faster, almost twice as fast on a shrunk-down GPT. For the same number of epochs, it shows substantial gains in pre-training performance, equivalent to a much larger architecture. The black vertical line denotes the epoch at which our architecture achieves the same performance as the baseline architecture.
  • Figure 5: Comparison of negative log-likelihood (NLL) scores for our architecture across three modalities, with and without the wavelet-based fixed/learnable (L) structure. (Left) Table of NLL scores and speedup, listing the Same Performance Epoch (SPE), i.e., the epoch at which our model matches the baseline trained for 25 epochs, and relative GPU hours. (R) FSD-50K Audio Transformer top-5 accuracy results. Vertical green lines mark the highest accuracy achieved and the point where our proposed method reaches the same accuracy 60% faster, with no extra parameters.
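
The averages/differences path in Figure 2 (Top L) can be sketched in a few lines. This assumes the unnormalized Haar convention (pairwise averages and half-differences) rather than the orthonormal 1/√2 scaling, and the function names are hypothetical. Unlike the causal per-token approximation used inside the model, this decimated transform halves the length at every level, so a 32-sample signal is reduced to a single scalar after five levels.

```python
def haar_step(x):
    """One Haar analysis step: pairwise averages (approximation) and
    half-differences (detail), each downsampled by 2."""
    approx = [(x[i] + x[i + 1]) / 2 for i in range(0, len(x), 2)]
    detail = [(x[i] - x[i + 1]) / 2 for i in range(0, len(x), 2)]
    return approx, detail


def haar_pyramid(signal):
    """Decompose a length-2**L signal until one scalar remains.
    Returns (final approximation, details from finest to coarsest)."""
    n = len(signal)
    if n == 0 or n & (n - 1):
        raise ValueError("signal length must be a power of two")
    approx, details = list(signal), []
    while len(approx) > 1:
        approx, detail = haar_step(approx)
        details.append(detail)
    return approx, details


def haar_reconstruct(approx, details):
    """Invert haar_pyramid using x[2i] = a + d and x[2i + 1] = a - d."""
    x = list(approx)
    for detail in reversed(details):
        x = [v for a, d in zip(x, detail) for v in (a + d, a - d)]
    return x


signal = [1.0, 3.0, 5.0, 7.0, 2.0, 4.0, 6.0, 8.0]
approx, details = haar_pyramid(signal)
print(approx)   # [4.5]  (the mean of the signal)
print(details)  # [[-1.0, -1.0, -1.0, -1.0], [-2.0, -2.0], [-0.5]]
print(haar_reconstruct(approx, details) == signal)  # True
```

The final scalar is the mean of the signal (the coarsest approximation), and keeping every detail level makes the decomposition exactly invertible.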