Dissecting and Re-architecting 3D NAND Flash PIM Arrays for Efficient Single-Batch Token Generation in LLMs
Yongjoo Jang, Sangwoo Hwang, Hojin Lee, Sangwoo Jung, Donghun Lee, Wonbo Shim, Jaeha Kung
TL;DR
This work tackles the memory and latency challenges of serving large language models by offloading single-batch token generation to a re-architected 3D NAND flash processing-in-memory (PIM) platform. It introduces an optimized plane configuration and an H-tree bus to enable in-die computation, along with tiling and mapping strategies for the matrix-vector multiplication (MVM) workloads in LLMs and a QLC-SLC hybrid KV caching scheme. Experimental results show a 2.4× speedup over four RTX 4090 GPUs running vLLM and performance comparable to four A100 GPUs with only 4.9% latency overhead. The proposed PIM architecture fits within a 4.98 mm² die area under the memory array, with no extra area overhead. The approach offers a cost-effective, scalable way to accelerate token generation in LLMs by placing high-density 3D NAND PIM within the memory hierarchy.
Abstract
The advancement of large language models has led to models with billions of parameters, significantly increasing memory and compute demands. Serving such models on conventional hardware is challenging due to limited DRAM capacity and high GPU costs. In this work, we propose offloading single-batch token generation to a 3D NAND flash processing-in-memory (PIM) device, leveraging its high storage density to overcome the DRAM capacity wall. We explore 3D NAND flash configurations and present a re-architected PIM array with an H-tree network for optimal latency and cell density. Along with a well-chosen PIM array size, we develop operation tiling and mapping methods for LLM layers, achieving a 2.4× speedup over four RTX 4090 GPUs with vLLM and comparable performance to four A100 GPUs with only 4.9% latency overhead. Our detailed area analysis reveals that the proposed 3D NAND flash PIM architecture can be integrated within a 4.98 mm² die area under the memory array, without extra area overhead.
