Mixture-of-Depths Attention

Lianghui Zhu, Yuxin Fang, Bencheng Liao, Shijie Wang, Tianheng Cheng, Zilong Huang, Chen Chen, Lai Wei, Yutao Zeng, Ya Wang, Yi Lin, Yu Li, Xinggang Wang

Abstract

Scaling depth is a key driver for large language models (LLMs). Yet as LLMs grow deeper, they often suffer from signal degradation: informative features formed in shallow layers are gradually diluted by repeated residual updates, which makes them harder to recover in deeper layers. We introduce mixture-of-depths attention (MoDA), a mechanism that allows each attention head to attend both to sequence KV pairs at the current layer and to depth KV pairs from preceding layers. We further describe a hardware-efficient algorithm for MoDA that resolves non-contiguous memory-access patterns, reaching 97.3% of FlashAttention-2's efficiency at a sequence length of 64K. Experiments on 1.5B-parameter models show that MoDA consistently outperforms strong baselines. Notably, it lowers average perplexity by 0.2 across 10 validation benchmarks and increases average performance by 2.11% on 10 downstream tasks, with only 3.7% additional FLOPs. We also find that combining MoDA with post-norm yields better performance than combining it with pre-norm. These results suggest that MoDA is a promising primitive for depth scaling. Code is released at https://github.com/hustvl/MoDA.
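The abstract and Figure 1 describe each query attending to sequence KV pairs at the current layer plus depth KV pairs at the same position from preceding layers, and Figure 5 refers to a combined-softmax formulation. The pure-Python sketch below shows one reading of that mechanism for a single head: sequence scores and depth scores are normalized by one joint softmax. The function names, the list-based shapes, and the 1/sqrt(d) scaling are assumptions made for illustration; this is not the released hardware-efficient MoDA kernel.

```python
import math


def softmax(scores):
    # Numerically stable softmax over a list of floats.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def moda_head(q, seq_k, seq_v, depth_k, depth_v):
    """Hypothetical single-head MoDA sketch (not the released kernel).

    q, seq_k, seq_v: lists of shape [T][d] at the current layer l.
    depth_k, depth_v: lists of shape [l][T][d] holding the KV pairs
    written by the preceding layers 0..l-1 (the depth memory).
    """
    T, d = len(q), len(q[0])
    scale = 1.0 / math.sqrt(d)  # assumed standard dot-product scaling
    out = []
    for t in range(T):
        # Causal sequence KV (positions 0..t) plus depth KV at position t only.
        keys = seq_k[: t + 1] + [layer[t] for layer in depth_k]
        vals = seq_v[: t + 1] + [layer[t] for layer in depth_v]
        # One softmax over the concatenated Sequence KV | Depth KV scores.
        w = softmax([scale * dot(q[t], k) for k in keys])
        out.append([sum(wj * v[c] for wj, v in zip(w, vals)) for c in range(d)])
    return out


# Depth memory from one preceding layer (l = 1), T = 2, d = 1.
q = [[0.0], [0.0]]
seq_k, seq_v = [[0.0], [0.0]], [[0.0], [0.0]]
depth_k, depth_v = [[[0.0], [0.0]]], [[[6.0], [9.0]]]
print(moda_head(q, seq_k, seq_v, depth_k, depth_v))
```

Passing empty depth lists reduces the sketch to plain causal attention, so in this reading MoDA strictly extends sequence attention. Using a single joint softmax, rather than two separately normalized attentions, lets sequence and depth entries compete for the same attention mass, which is the quantity the Figure 5 heatmaps visualize.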

Paper Structure (22 sections, 7 equations, 5 figures, 7 tables, 1 algorithm)

Figures (5)

  • Figure 1: We propose mixture-of-depths attention (MoDA) to address information dilution in modern LLMs in a dynamic and hardware-efficient way. Compared with vanilla causal sequence attention, MoDA additionally allows each query to attend to depth memories, i.e., the depth KV pairs $\{K_i, V_i\}_{i=0}^{l-1}$ from preceding layers at the same query position.
  • Figure 2: Comparison of MoDA with a strong open-source baseline, OLMo2 [olmo2024olmo2], in terms of validation loss and downstream performance under the 1.5B-parameter setting. Models using MoDA achieve lower C4 [raffel2020c4] validation loss than OLMo2 and better downstream performance on HellaSwag [zellers2019hellaswag], WinoGrande [sakaguchi2021winogrande], and ARC-Challenge [clark2018arc].
  • Figure 3: Conceptual comparison of mechanisms that use the depth stream. (a) Depth Residual [he2016resnet] is the standard residual connection along depth: it reads the current representation and writes back by addition. (b) Depth Dense [huang2017densenet, pagliardini2024denseformer] reads a set of historical representations and linearly projects them back to width $D$; it writes back by concatenation along depth, preserving all intermediate states. (c) Depth Attention, which we introduce as an intermediate formulation, uses attention to read historical depth KV pairs in a data-dependent way and writes back by concatenating the current layer's keys and values along depth. (d) Mixture-of-Depths Attention (MoDA), our upgraded version of Depth Attention, combines depth attention with standard sequence attention. It writes both the current layer's output and its KV pairs to the depth streams for subsequent layers.
  • Figure 4: Hardware view of MoDA depth-cache access. Left: flash-compatible, hardware-efficient MoDA keeps a depth KV cache of length $T\times L$ for each sequence, so each query may scan a long concatenated depth KV. Right: chunk-aware MoDA groups queries into chunks of size $C$ and reorganizes the depth KV by chunk, reducing the effective depth span from $T\times L$ to $(C\times L)/G$ per chunk, where $G$ is the GQA group number. This layout makes depth-KV computation more efficient and reduces memory-access overhead.
  • Figure 5: Heatmaps of Mixture-of-Depths Attention (MoDA) under the combined-softmax formulation. Columns correspond to uniformly sampled layers $\{0, 11, 23, 35\}$, and rows correspond to randomly selected heads in each layer. The first column shows attention over sequence KV only, while the other columns show attention over the concatenated Sequence KV | Depth KV; the red dashed line marks the boundary between the two KV blocks. Across layers and heads, substantial attention mass is consistently assigned to the depth-KV block, indicating that MoDA makes effective use of depth information in addition to standard sequence attention.
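Figure 4 contrasts a flat depth KV cache of length $T\times L$ with a chunk-aware layout. The pure-Python sketch below illustrates only the indexing idea behind the right-hand panel, under our own assumptions: depth entries are represented as (layer, position) index pairs rather than tensors, T is a multiple of C, entries inside a block are stored position-major, and the GQA factor $G$ is ignored, so each block holds $C\times L$ entries. The helper names chunk_major_depth_layout and visible_depth_entries are hypothetical and are not taken from the released code.

```python
def chunk_major_depth_layout(T, L, C):
    """Hypothetical chunk-major depth-KV layout (index pairs, not tensors).

    Chunk c stores the depth entries (layer i, position t) for its C
    positions and all L layers as one contiguous block of C * L entries,
    position-major. The GQA factor G from Figure 4 is ignored here.
    """
    assert T % C == 0, "sketch assumes T is a multiple of C"
    return [
        [(i, t) for t in range(c * C, (c + 1) * C) for i in range(L)]
        for c in range(T // C)
    ]


def visible_depth_entries(block, t, l):
    # A query at position t in layer l may only use depth KV from its own
    # position written by preceding layers (i < l); the rest is masked.
    return [j for j, (i, pos) in enumerate(block) if pos == t and i < l]


T, L, C = 8, 4, 2
blocks = chunk_major_depth_layout(T, L, C)
print(len(blocks), len(blocks[0]), T * L)  # 4 8 32
print(visible_depth_entries(blocks[3 // C], t=3, l=2))  # [4, 5]
```

In this toy setting each chunk reads a block of 8 entries instead of the 32-entry flat cache, and a query at position 3 in layer 2 keeps only block entries 4 and 5, i.e. (0, 3) and (1, 3). A per-query mask is still needed because a chunk's block also holds the depth KV of the chunk's other positions and of layers at or above the query's own layer.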