Flash Multi-Head Feed-Forward Network

Minshen Zhang, Xiang Hu, Jianguo Li, Wei Wu, Kewei Tu

TL;DR

Problem: traditional FFNs dominate Transformer parameters and computation, and a naïve Multi-Head FFN suffers from memory and scaling inefficiencies. Approach: FlashMHF combines scale-balanced parallel FFN sub-networks with an I/O-aware fused kernel that computes activations in SRAM, avoiding large intermediate tensors. Findings: across 128M–1.3B models, FlashMHF outperforms the SwiGLU FFN on perplexity and downstream tasks, while reducing peak memory by 3–5x and achieving up to a 1.08x inference speedup. Significance: shows that the multi-head FFN can be a superior, scalable architectural principle for dense Transformer components, enabling more capable and efficient language models.
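To make the memory issue concrete, here is a naive MH-FFN in NumPy under stated assumptions. The model dimension is split into H heads of size d_h, and each head runs its own SwiGLU FFN. The per-head parameters K, U, V have the hypothetical shape [H, d_ff, d_h]. The function name `naive_mh_ffn` and this parameter layout are illustrative, not the authors' code.

```python
import numpy as np


def silu(z: np.ndarray) -> np.ndarray:
    # SiLU(z) = z * sigmoid(z)
    return z / (1.0 + np.exp(-z))


def naive_mh_ffn(x, K, U, V):
    """Naive multi-head SwiGLU FFN (hypothetical layout, illustration only).

    x:       [T, H * d_h]    token activations
    K, U, V: [H, d_ff, d_h]  per-head gate, up, and down parameters
    Returns the output [T, H * d_h] and the materialized intermediate [T, H, d_ff].
    """
    T = x.shape[0]
    H, d_ff, d_h = K.shape
    q = x.reshape(T, H, d_h)                       # split the model dim into H heads
    gate = np.einsum("thd,hfd->thf", q, K)         # per-head Q K^T: [T, H, d_ff]
    up = np.einsum("thd,hfd->thf", q, U)           # per-head Q U^T: [T, H, d_ff]
    inter = silu(gate) * up                         # SwiGLU activation, fully materialized
    out = np.einsum("thf,hfd->thd", inter, V)      # per-head product with V: [T, H, d_h]
    return out.reshape(T, H * d_h), inter
```

Because every head keeps an intermediate of width d_ff in this sketch, the materialized activation has T·H·d_ff elements. A single SwiGLU FFN with the same intermediate width has T·d_ff, so the naive version's footprint grows linearly with the head count.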

Abstract

We explore the Multi-Head FFN (MH-FFN) as a replacement for the FFN in the Transformer architecture, motivated by the structural similarity between single-head attention and the FFN. While multi-head mechanisms enhance expressivity in attention, naively applying them to FFNs raises two challenges. First, memory consumption scales with the head count. Second, as models scale, the ratio between the growing intermediate size and the fixed head dimension becomes imbalanced, which degrades scalability and expressive power. To address these challenges, we propose Flash Multi-Head FFN (FlashMHF), with two key innovations: an I/O-aware fused kernel that computes outputs online in SRAM, akin to FlashAttention, and a design using dynamically weighted parallel sub-networks to maintain a balanced ratio between the intermediate and head dimensions. Validated on models from 128M to 1.3B parameters, FlashMHF consistently improves perplexity and downstream task accuracy over SwiGLU FFNs, while reducing peak memory usage by 3–5x and accelerating inference by up to 1.08x. Our work establishes the multi-head design as a superior architectural principle for FFNs, presenting FlashMHF as a powerful, efficient, and scalable alternative to FFNs in Transformers.
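One way to picture dynamically weighted parallel sub-networks is sketched below for a single head. The whole parameterization is an assumption. The head's intermediate width d_ff is split into E sub-networks of width d_sub = d_ff / E, so that d_sub / d_h can stay fixed as d_ff grows. Per-token weights come from a softmax over a hypothetical gating projection G. The function name `weighted_subnetworks` and the gate design are illustrative, and the paper's actual weighting scheme may differ.

```python
import numpy as np


def silu(z: np.ndarray) -> np.ndarray:
    # SiLU(z) = z * sigmoid(z)
    return z / (1.0 + np.exp(-z))


def softmax(z: np.ndarray, axis: int = -1) -> np.ndarray:
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def weighted_subnetworks(q, K, U, V, G):
    """Single-head output from E parallel SwiGLU sub-networks (hypothetical).

    q:       [T, d_h]         input slice for one head
    K, U, V: [E, d_sub, d_h]  per-sub-network parameters, d_sub = d_ff // E
    G:       [d_h, E]         hypothetical gating projection
    """
    gate = np.einsum("td,efd->tef", q, K)                     # [T, E, d_sub]
    up = np.einsum("td,efd->tef", q, U)                       # [T, E, d_sub]
    sub_out = np.einsum("tef,efd->ted", silu(gate) * up, V)  # [T, E, d_h]
    w = softmax(q @ G, axis=-1)                               # [T, E] per-token weights
    return np.einsum("te,ted->td", w, sub_out)                # weighted sum: [T, d_h]
```

With G set to zero the weights are uniform. The output then reduces to the full-width SwiGLU head scaled by 1/E, which makes a convenient sanity check. An input-dependent G lets each token emphasize different sub-networks.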

Paper Structure

This paper contains 21 sections, 17 equations, 8 figures, 5 tables, and 5 algorithms.

Figures (8)

  • Figure 1: Structural Symmetry.
  • Figure 2: Memory limitation of MH-FFN.
  • Figure 3: (a) Parallel FFN Sub-Networks. (b) $\widetilde{\mathrm{SRAMFFN}}$ loads blocks of $\mathbf{Q}$ in the outer loop and blocks of $\mathbf{K}$, $\mathbf{U}$, $\mathbf{V}$ in the inner loop, computing $\text{SiLU}(\mathbf{Q}\mathbf{K}^{\top})$, $\mathbf{Q}\mathbf{U}^{\top}$, and the corresponding multiplication with $\mathbf{V}$ in SRAM.
  • Figure 4: Comparison of the baseline, Parametric KV, FlashMHF, and MH-FFN at the 128M and 370M scales.
  • Figure 5: Training at the 370M model scale to investigate the best head dimension. Analysis: (a) shows the full training loss; for clarity, (b) and (c) zoom in on later training steps. FlashMHF with $d_h = 64$ and $d_h = 128$ achieves lower training loss and lower evaluation loss on the PG19 validation split.
  • ...and 3 more figures
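The loop order described for Figure 3(b) can be mirrored in plain NumPy. An outer loop runs over blocks of Q (tokens), and an inner loop runs over blocks of K, U, V (intermediate units). Each block's contribution to the output is accumulated as it is computed. This sketch assumes the standard SwiGLU form $(\text{SiLU}(\mathbf{Q}\mathbf{K}^{\top}) \odot \mathbf{Q}\mathbf{U}^{\top})\mathbf{V}$ and reproduces only the arithmetic. It does not model SRAM, kernel fusion, or the multi-head and sub-network structure. The function name `tiled_swiglu` and the block sizes `block_q` and `block_f` are arbitrary illustrative choices.

```python
import numpy as np


def silu(z: np.ndarray) -> np.ndarray:
    # SiLU(z) = z * sigmoid(z)
    return z / (1.0 + np.exp(-z))


def tiled_swiglu(Q, K, U, V, block_q=16, block_f=32):
    """Blocked SwiGLU FFN following the outer/inner loop order of Figure 3(b).

    Q:       [T, d]     inputs
    K, U, V: [d_ff, d]  gate, up, and down parameters
    Only a [block_q, block_f] slice of the intermediate exists at any time.
    """
    T = Q.shape[0]
    d_ff, d_out = K.shape[0], V.shape[1]
    O = np.zeros((T, d_out), dtype=Q.dtype)
    for i in range(0, T, block_q):                  # outer loop: blocks of Q
        Qb = Q[i:i + block_q]
        acc = np.zeros((Qb.shape[0], d_out), dtype=Q.dtype)
        for j in range(0, d_ff, block_f):           # inner loop: blocks of K, U, V
            h = silu(Qb @ K[j:j + block_f].T) * (Qb @ U[j:j + block_f].T)
            acc += h @ V[j:j + block_f]              # partial sum over intermediate units
        O[i:i + block_q] = acc
    return O
```

SiLU gating acts elementwise on each intermediate unit, so the output is an exact sum of per-block partial products. Unlike FlashAttention, no running maximum or softmax rescaling is needed. Slices past the end of an array are truncated, so sizes that are not multiples of the block sizes are handled.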