Mask Tokens as Prophet: Fine-Grained Cache Eviction for Efficient dLLM Inference
Jianuo Huang, Yaojie Zhang, Yicun Yang, Benhao Huang, Biqing Qi, Dongrui Liu, Linfeng Zhang
TL;DR
This work tackles the high memory and computation cost of KV caches in diffusion LLMs with MaskKV, a data-driven, training-free cache eviction framework. It introduces Mask-Voting, which identifies pivotal prompt tokens via mask-query attention, and a two-stage offline budgeting strategy (layer first, then head) that allocates a fixed KV budget so that critical components are not evicted. On LongBench with LLaDA-8B-Instruct and Dream-7B-Instruct, MaskKV preserves up to 94% of full-cache performance with only 256 KV pairs and delivers up to 31× speedups at 32k contexts, while also reducing peak memory. The approach offers a practical path to long-context diffusion LLM inference, with implications for efficient bidirectional attention in large language models more broadly.
Abstract
Diffusion large language models (dLLMs) are a promising alternative to the dominant autoregressive models (ARMs) because they can decode in parallel, but this ability comes with substantial computation and memory costs. Specifically, the cache mechanism for bidirectional attention in dLLMs demands a large memory footprint, restricting their ability to handle long contexts in resource-limited settings. Existing cache eviction strategies are designed for ARMs and ignore the unique characteristics of dLLMs, leading to unsatisfactory performance. To address these challenges, we introduce MaskKV, a training-free cache eviction framework tailored to dLLMs that focuses on the effect of mask tokens. MaskKV is built on two key innovations: (1) a mask-query guided scoring mechanism that leverages attention weights to identify and evict less critical prompt tokens for each head; (2) an adaptive cache budgeting strategy that improves efficiency by reducing allocation in intermediate layers and concentrating resources on prompt-preferring heads. With MaskKV on LLaDA, compressing the KV cache to only 256 pairs (less than 5% of tokens) retains 94% of the full-cache performance on LongBench and achieves up to 31× acceleration at a 32k prompt length. The code is publicly available at: https://github.com/jianuo-huang/MaskKV
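
To make the mask-query guided scoring idea concrete, the following is a minimal NumPy sketch, not the authors' implementation. It assumes that a prompt token's score for a given head is the mean attention it receives from mask-token queries (the abstract does not specify the aggregation), and that per-head budgets are already given rather than derived by the paper's budgeting strategy. The names `mask_query_scores`, `select_kept_tokens`, and `head_budgets` are hypothetical.

```python
import numpy as np

def mask_query_scores(attn, mask_positions, prompt_len):
    """Score prompt tokens by the attention they receive from mask-token queries.

    attn: array of shape (heads, seq, seq); rows are queries, columns are keys.
    mask_positions: indices of the mask tokens in the sequence.
    prompt_len: number of leading positions that belong to the prompt.
    Returns an array of shape (heads, prompt_len).
    Assumption: scores are averaged over mask queries.
    """
    return attn[:, mask_positions, :prompt_len].mean(axis=1)

def select_kept_tokens(scores, head_budgets):
    """Keep the top-scoring prompt positions for each head.

    head_budgets: one KV budget per head (assumed to be given).
    Returns a list of sorted index arrays, one per head.
    """
    kept = []
    for h, budget in enumerate(head_budgets):
        k = min(int(budget), scores.shape[1])
        # Stable sort on negated scores gives a deterministic descending order.
        order = np.argsort(-scores[h], kind="stable")
        kept.append(np.sort(order[:k]))
    return kept

# Toy example: 2 heads, 4 prompt tokens followed by 2 mask tokens.
attn = np.zeros((2, 6, 6))
attn[0, 4:, :4] = [0.1, 0.4, 0.2, 0.3]  # head 0: mask queries favor tokens 1 and 3
attn[1, 4:, :4] = [0.4, 0.1, 0.3, 0.2]  # head 1: mask queries favor token 0

scores = mask_query_scores(attn, mask_positions=[4, 5], prompt_len=4)
kept = select_kept_tokens(scores, head_budgets=[2, 1])
```

In this toy setup, head 0 keeps two prompt positions and head 1 keeps one, which mirrors the idea of concentrating the cache budget on heads that attend more strongly to the prompt. Evicting the KV pairs of the unselected positions is left out of the sketch.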
