NeuronMM: High-Performance Matrix Multiplication for LLM Inference on AWS Trainium
Dinghong Song, Jierui Xu, Weichu Yang, Pengfei Su, Dong Li
TL;DR
NeuronMM targets the bottlenecks of LLM inference on AWS Trainium by combining block-aligned SVD compression with a Trainium-specific kernel fusion strategy (TrainiumFusion). The approach rewrites each large matmul XW as a hardware-friendly XUV computation, where U and V are low-rank SVD factors of the weight matrix W. It minimizes data movement through on-chip caching and two-stage on-chip execution, and it eliminates intermediate transposes through implicit transposition and tiling. Across nine datasets and four LLMs, the results show substantial gains: an average matmul kernel speedup of 1.35x (up to 2.22x) and an average end-to-end inference speedup of 1.66x (up to 2.49x), with a modest accuracy loss that can be recovered through LoRA fine-tuning. The work demonstrates a practical hardware-algorithm co-design that exploits Trainium's on-chip SRAM, Tensor Engine, and memory hierarchy, and it is released as open source to extend efficient LLM inference on Trainium.
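To make the XUV rewrite concrete: if the weight W (k x n) is approximated by a rank-r truncated SVD and the singular values are folded into one factor (one common choice), then W ≈ UV with U of shape k x r and V of shape r x n, so XW ≈ (XU)V. The sketch below counts FLOPs and weight parameters for both forms and rounds r up to a multiple of a block size. The block size of 128, the helper names, and the 512 x 4096 x 4096 shape are illustrative assumptions, not NeuronMM's actual parameters, and reading "block-aligned" as "rank rounded up to the tile size" is likewise an assumption.

```python
import math

# Assumed block size for illustration: 128, matching the 128-wide tiles
# commonly used on Trainium's Tensor Engine. Not taken from NeuronMM.
BLOCK = 128


def block_aligned_rank(target_rank: int, block: int = BLOCK) -> int:
    """Round a desired SVD rank up to the next multiple of `block` (at least one block)."""
    return max(block, math.ceil(target_rank / block) * block)


def matmul_flops(m: int, k: int, n: int) -> int:
    """FLOPs of an (m x k) @ (k x n) product, counting one multiply and one add per term."""
    return 2 * m * k * n


def xuv_flops(m: int, k: int, n: int, r: int) -> int:
    """FLOPs of (X @ U) @ V with X: m x k, U: k x r, V: r x n."""
    return matmul_flops(m, k, r) + matmul_flops(m, r, n)


def breakeven_rank(k: int, n: int) -> float:
    """Rank below which X @ U @ V needs fewer FLOPs (and U, V fewer parameters) than X @ W."""
    return k * n / (k + n)


if __name__ == "__main__":
    m, k, n = 512, 4096, 4096       # hypothetical token count and projection shape
    r = block_aligned_rank(1000)     # 1000 rounds up to 1024 = 8 blocks of 128
    dense, low_rank = matmul_flops(m, k, n), xuv_flops(m, k, n, r)
    print(f"rank={r}  dense={dense:,}  xuv={low_rank:,}  ratio={low_rank / dense:.2f}")
    print(f"params: dense={k * n:,}  xuv={k * r + r * n:,}  "
          f"breakeven rank={breakeven_rank(k, n):.0f}")
```

For this hypothetical 4096 x 4096 projection, the rank-1024 factorization needs half the FLOPs and half the weight parameters of the dense matmul, and any rank below k·n/(k+n) = 2048 saves work. Raw FLOP counts are only part of the picture on Trainium, since data movement through the memory hierarchy is what the fusion techniques target.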
Abstract
AI accelerators, customized for AI workloads, provide cost-effective and high-performance solutions for training and inference. Trainium, an AI accelerator recently developed by Amazon Web Services (AWS), offers an attractive option for LLM training and inference through its heterogeneous architecture. However, achieving high performance on the Trainium architecture is challenging because of its systolic array design and its special data-layout requirements. In this paper, we design high-performance matrix multiplication (matmul), a critical compute kernel, for LLM inference on Trainium. We introduce a series of Trainium-specific techniques, based on kernel fusion and novel caching strategies, that reduce data movement across the software-managed memory hierarchy, maximize SRAM bandwidth, and avoid expensive matrix transposes. Our evaluation on nine datasets and four recent LLMs shows that our system substantially outperforms the state-of-the-art matmul implemented by AWS on Trainium: at the matmul kernel level, it achieves an average speedup of 1.35x (up to 2.22x), which translates to an average speedup of 1.66x (up to 2.49x) for end-to-end LLM inference.
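The kernel-fusion and transpose-avoidance ideas can be illustrated with a small pure-Python model. It assumes a hypothetical engine primitive `mm_tn(a, b)` that returns a^T @ b, contracting over the leading dimension of both operands, which is a common way systolic-array matmuls are exposed. Under that assumption, stage 1 can emit (X_tile @ U)^T directly, which is already the operand layout stage 2 contracts over, so no explicit transpose is issued. Because the r-wide intermediate is produced and consumed per token tile, the full XU product is never materialized, standing in for keeping it in on-chip SRAM. All names and the tile size are illustrative, not NeuronMM's kernel code.

```python
from typing import List

Matrix = List[List[float]]


def transpose(a: Matrix) -> Matrix:
    """Return a^T as a new list-of-lists."""
    return [list(col) for col in zip(*a)]


def mm_tn(a: Matrix, b: Matrix) -> Matrix:
    """Hypothetical engine primitive: return a^T @ b.

    a is (k x p), b is (k x q); the contraction runs over their shared leading
    dim k, so the first operand is consumed in transposed form without an
    explicit transpose.
    """
    k, p, q = len(a), len(a[0]), len(b[0])
    return [[sum(a[t][i] * b[t][j] for t in range(k)) for j in range(q)]
            for i in range(p)]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Reference a @ b, built from mm_tn with an explicit transpose."""
    return mm_tn(transpose(a), b)


def xuv_fused(x_t: Matrix, u: Matrix, v: Matrix, tile: int = 2) -> Matrix:
    """Compute X @ U @ V one token tile at a time.

    x_t is X^T (k x m), u is U (k x r), v is V (r x n). Each tile's r x tile
    intermediate lives only inside the loop body (a stand-in for on-chip SRAM),
    and stage 1 already emits it in the transposed layout stage 2 consumes.
    """
    m = len(x_t[0])
    out: Matrix = []
    for start in range(0, m, tile):
        x_t_tile = [row[start:start + tile] for row in x_t]  # k x tile slice of X^T
        t_t = mm_tn(u, x_t_tile)   # stage 1: U^T @ X_tile^T = (X_tile @ U)^T, r x tile
        out.extend(mm_tn(t_t, v))  # stage 2: (X_tile @ U) @ V, tile x n, no transpose issued
    return out


if __name__ == "__main__":
    X = [[1.0, 2.0, 0.0], [0.0, 1.0, 1.0], [2.0, 0.0, 1.0]]  # m = 3 tokens, k = 3
    U = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]                 # k x r, r = 2
    V = [[1.0, 2.0, 0.0, 1.0], [0.0, 1.0, 1.0, 0.0]]         # r x n, n = 4
    fused = xuv_fused(transpose(X), U, V, tile=2)
    assert fused == matmul(matmul(X, U), V)
    print(fused)  # [[1.0, 4.0, 2.0, 1.0], [1.0, 4.0, 2.0, 1.0], [3.0, 7.0, 1.0, 3.0]]
```

The design point being modeled is operand ordering: by producing the transposed intermediate, stage 1 hands stage 2 exactly the operand it contracts over, which is one plausible reading of how a transpose can become implicit.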
