Rethinking Attention Output Projection: Structured Hadamard Transforms for Efficient Transformers

Shubham Aggarwal, Lokendra Kumar

TL;DR

Structured Hadamard-based models show a validation loss that falls more steeply as a function of training FLOPs than that of their dense counterparts, suggesting more favorable compute utilization during training.

Abstract

The dense output projection in multi-head attention scales quadratically with model dimension, contributing significantly to parameter count, memory footprint, and inference cost. We propose replacing this projection with a fixed, parameter-free Walsh-Hadamard Transform followed by a lightweight learnable affine rescaling, eliminating approximately 25 percent of attention parameters per block while preserving global cross-head interaction through an orthogonal, norm-preserving transformation. Across different model sizes, we demonstrate that this structured substitution maintains comparable or slightly superior downstream task performance on standard benchmarks, while achieving up to 7 percent aggregate parameter reduction, 8.9 percent peak memory savings, and 6.6 percent throughput improvement at scale, with efficiency gains growing monotonically with model size, batch size, and sequence length. Interestingly, we observe that structured Hadamard-based models exhibit a steeper validation loss curve relative to training FLOPs compared to their dense counterparts, suggesting more favorable compute utilization during training.
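
As a rough illustration of the substitution the abstract describes, the sketch below applies an orthonormal fast Walsh-Hadamard transform to a concatenated head vector and then an elementwise scale and shift. The function names, the `gamma` and `beta` parameters, the per-channel form of the affine rescaling, and the power-of-two length requirement are all assumptions made for illustration. The abstract does not specify these details, so this should not be read as the authors' implementation.

```python
import math


def fwht(x):
    """Orthonormal fast Walsh-Hadamard transform (butterfly form).

    Assumes len(x) is a power of two. Returns a new list.
    """
    n = len(x)
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")
    y = list(x)
    h = 1
    while h < n:  # log2(n) butterfly stages
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = y[i], y[i + h]
                y[i], y[i + h] = a + b, a - b
        h *= 2
    scale = 1.0 / math.sqrt(n)  # makes the transform norm preserving
    return [v * scale for v in y]


def hadamard_output_projection(concat_heads, gamma, beta):
    """Hypothetical stand-in for the dense output projection.

    concat_heads: concatenated per-head outputs for one token.
    gamma, beta: assumed learnable per-channel scale and shift.
    """
    mixed = fwht(concat_heads)  # fixed, parameter-free mixing across heads
    return [g * m + b for g, m, b in zip(gamma, mixed, beta)]
```

With `gamma` set to ones and `beta` to zeros, the layer reduces to the orthonormal transform and leaves the vector norm unchanged. Under these assumptions the only learnable parameters are the `2c` entries of `gamma` and `beta`, in place of the `c^2` weights of a dense projection.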

Paper Structure (29 sections, 8 equations, 8 figures, 4 tables)

Figures (8)

  • Figure 1: Comparative diagram of dense and Hadamard layers.
  • Figure 2: Computational flow of the Fast Walsh-Hadamard Transform (FWHT) showing the butterfly network structure with three stages.
  • Figure 3: Forward FLOPs for dense matrix multiplication ($c^{2}$) versus the Fast Walsh-Hadamard Transform ($c\log_{2}c$) across embedding dimensions used in GPT-2 (base: $c{=}768$). FWHT requires zero stored weights.
  • Figure 4: Comparison of three baseline models and three variants of our method across different sizes. Our models converge slightly slower (left) but show better scaling of validation loss with compute, with a steeper trend vs. FLOPs (right).
  • Figure 5: Prefill latency (ms) and throughput (tokens/s) as a function of sequence length at batch size 128. Our method consistently reduces latency and increases throughput across all model scales, with gains widening at longer sequences.
  • ...and 3 more figures
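
The operation counts quoted in the Figure 3 caption and the abstract's "approximately 25 percent" figure can be checked with simple arithmetic. The sketch below uses the caption's cost models ($c^{2}$ and $c\log_{2}c$) as given. The parameter estimate additionally assumes four bias-free $c \times c$ attention matrices (query, key, value, output) and a `2c`-parameter affine rescaling; neither assumption is stated in this section. The function names are hypothetical.

```python
import math


def dense_forward_cost(c):
    """Per-token cost model for the dense projection, as in the Figure 3 caption."""
    return c * c


def fwht_forward_cost(c):
    """Per-token cost model for the FWHT, as in the Figure 3 caption."""
    return c * math.log2(c)


def attention_param_fraction_removed(c):
    """Fraction of attention weights removed per block.

    Assumes four bias-free c x c matrices (Q, K, V, O), with O replaced
    by a 2c-parameter scale and shift.
    """
    dense_attention = 4 * c * c
    removed = c * c - 2 * c
    return removed / dense_attention


if __name__ == "__main__":
    for c in (768, 1024):
        ratio = dense_forward_cost(c) / fwht_forward_cost(c)
        frac = attention_param_fraction_removed(c)
        print(f"c={c}: dense/FWHT cost ratio {ratio:.1f}, params removed {frac:.4f}")
```

Under these assumptions the removed fraction is $1/4 - 1/(2c)$, which approaches 25 percent as $c$ grows, and the cost ratio $c/\log_{2}c$ widens with model dimension.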