AMLA: MUL by ADD in FlashAttention Rescaling
Qichen Liao, Chengqiu Hu, Fangzheng Miao, Bao Li, Yiyang Liu, Junlong Lyu, Lirui Jiang, Jun Wang, Lingchao Zheng, Jun Li, Yuwei Fan
TL;DR
This work tackles the decode-phase bottlenecks of Multi-Head Latent Attention (MLA) with Ascend MLA (AMLA), a co-designed kernel for Huawei Ascend NPUs. AMLA replaces the FP32 multiplications in output rescaling with integer additions by reinterpreting FP32 bits as INT32: F × 2^n = AS_FP32(AS_INT32(F) + n × 2^23), which holds for normal FP32 values as long as the adjusted exponent stays in range. It also updates the output in place in global memory (GM) through AtomicAdd, which eliminates data movement of large intermediate tensors. A Preload Pipeline and hierarchical tiling overlap Cube and Vector work to maximize FLOPS utilization (FU), reaching up to 86.8% FU on Ascend 910, compared with the up-to-66.7% FU that FlashMLA reaches on NVIDIA H800 SXM5 GPUs. The approach is numerically stable, has been integrated into Huawei's CANN with a public release planned, and targets efficient long-context decoding in LLMs.
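The reinterpretation identity above can be checked on a CPU. The Python sketch below is illustrative only: `as_int32`, `as_fp32`, and `mul_pow2_by_add` are hypothetical helper names standing in for the AS_INT32/AS_FP32 bit reinterpretations, emulated here with the standard `struct` module. Adding n × 2^23 to the INT32 view shifts the 8-bit exponent field (bits 23 to 30) by n, which multiplies a normal FP32 value by 2^n. The sketch deliberately does not handle zeros, subnormals, inf/NaN, or exponent overflow and underflow.

```python
import struct


def as_int32(f: float) -> int:
    """Reinterpret the bits of an FP32 value (f is rounded to FP32) as a signed INT32."""
    return struct.unpack("<i", struct.pack("<f", f))[0]


def as_fp32(i: int) -> float:
    """Reinterpret the bits of a signed INT32 as an FP32 value."""
    return struct.unpack("<f", struct.pack("<i", i))[0]


def mul_pow2_by_add(f: float, n: int) -> float:
    """Compute f * 2**n by adding n to the FP32 exponent field (bits 23..30).

    Assumes f is a normal FP32 number and the shifted exponent stays in the
    normal range; zeros, subnormals, inf/NaN, and overflow are not handled.
    """
    return as_fp32(as_int32(f) + n * (1 << 23))


print(mul_pow2_by_add(1.5, 3))     # 12.0
print(mul_pow2_by_add(-0.75, -2))  # -0.1875
```

The sign bit is untouched because the addition only carries within the exponent field while the exponent stays between 1 and 254, so the same trick works for negative values.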
Abstract
Multi-head Latent Attention (MLA) significantly reduces KVCache memory usage in Large Language Models, but it introduces substantial computational overhead and intermediate variable expansion. This makes efficient hardware implementation challenging, especially during the decode phase. This paper introduces Ascend MLA (AMLA), a high-performance kernel specifically optimized for Huawei's Ascend NPUs. AMLA is built on two core innovations: (1) a novel FlashAttention-based algorithm that replaces floating-point multiplications with integer additions for output block rescaling, leveraging the binary correspondence between FP32 and INT32 representations; (2) a Preload Pipeline strategy with hierarchical tiling that maximizes FLOPS utilization: the Preload Pipeline achieves Cube-bound performance, while hierarchical tiling overlaps data movement and computation within the Cube core. Experiments show that on Ascend 910 NPUs (integrated in CloudMatrix384), AMLA achieves up to 614 TFLOPS, reaching 86.8% of the theoretical maximum FLOPS and outperforming the state-of-the-art open-source FlashMLA implementation, whose FLOPS utilization reaches up to 66.7% on NVIDIA H800 SXM5. The AMLA kernel has been integrated into Huawei's CANN and will be released soon.
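To illustrate how an integer addition can stand in for the multiply in FlashAttention-style output rescaling, the Python sketch below runs single-query online softmax one key at a time. It assumes, for illustration, that the running max is kept as an integer in base-2 units (scores multiplied by log2(e) and rounded up). That choice makes every rescale factor an exact power of two, which `scale_pow2` applies with an INT32 addition. This is one way to obtain power-of-two factors and is not claimed to be AMLA's exact scheme. All function names are hypothetical, and block tiling, the Cube/Vector split, and AtomicAdd updates in global memory are not modeled.

```python
import math
import struct


def _as_int32(x: float) -> int:
    """Reinterpret an FP32 bit pattern as a signed INT32 (x is rounded to FP32)."""
    return struct.unpack("<i", struct.pack("<f", x))[0]


def _as_fp32(i: int) -> float:
    """Reinterpret a signed INT32 bit pattern as an FP32 value."""
    return struct.unpack("<f", struct.pack("<i", i))[0]


def scale_pow2(x: float, n: int) -> float:
    """Return x * 2**n using an integer add on the FP32 exponent field.

    Zeros, subnormals, inf/NaN, and results that would leave the normal
    exponent range fall back to an ordinary multiply.
    """
    bits = _as_int32(x)
    exp = (bits >> 23) & 0xFF  # biased exponent field
    if 0 < exp < 255 and 0 < exp + n < 255:
        return _as_fp32(bits + (n << 23))
    return x * 2.0 ** n


def attention_pow2_rescale(q, keys, values):
    """Single-query softmax attention computed online, one key per step.

    Hypothetical scheme: the running max m is an integer in base-2 units,
    so each output rescale factor 2**(m_old - m_new) is a power of two.
    """
    log2e = 1.0 / math.log(2.0)
    m = None                      # integer running max (base-2 units)
    denom = 0.0                   # running softmax denominator
    out = [0.0] * len(values[0])  # unnormalized output accumulator
    for k, v in zip(keys, values):
        s = sum(qi * ki for qi, ki in zip(q, k)) * log2e
        m_new = math.ceil(s) if m is None else max(m, math.ceil(s))
        if m is not None and m_new != m:
            n = m - m_new  # negative integer exponent shift
            out = [scale_pow2(o, n) for o in out]  # rescale by ADD, not MUL
            denom = scale_pow2(denom, n)
        m = m_new
        p = 2.0 ** (s - m)  # in (0, 1] because s <= m
        denom += p
        out = [o + p * vj for o, vj in zip(out, v)]
    return [o / denom for o in out]


def attention_reference(q, keys, values):
    """Plain softmax attention for comparison."""
    scores = [sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    mx = max(scores)
    w = [math.exp(s - mx) for s in scores]
    z = sum(w)
    return [sum(wi * v[j] for wi, v in zip(w, values)) / z
            for j in range(len(values[0]))]
```

Because the same integer max appears in both the accumulated output and the denominator, it cancels in the final division, so `attention_pow2_rescale` should match `attention_reference` up to FP32 rounding. The fallback branch in `scale_pow2` keeps zeros and out-of-range exponents correct at the cost of a real multiply.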
