
Dissecting and Re-architecting 3D NAND Flash PIM Arrays for Efficient Single-Batch Token Generation in LLMs

Yongjoo Jang, Sangwoo Hwang, Hojin Lee, Sangwoo Jung, Donghun Lee, Wonbo Shim, Jaeha Kung

TL;DR

This work tackles the memory and latency challenges of serving large language models (LLMs) by offloading single-batch token generation to a re-architected 3D NAND flash processing-in-memory (PIM) platform. It introduces an optimized plane configuration and an H-tree bus that enable in-die computation, tiling and mapping strategies for the matrix-vector multiplication (MVM) workloads in LLMs, and a QLC-SLC hybrid KV caching scheme. Experiments show a 2.4× speedup over four RTX 4090 GPUs running vLLM and performance comparable to four A100 GPUs with only 4.9% latency overhead, while the PIM logic fits within a 4.98 mm² die area under the memory array. The approach offers a cost-effective, scalable way to accelerate LLM token generation by placing high-density 3D NAND PIM within the memory hierarchy.
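The tiling idea can be illustrated with a small sketch. It assumes that one PIM array computes a single tile_rows × tile_cols block of a weight matrix at a time and that partial sums from blocks sharing the same output rows are merged by a binary reduction tree, used here as a stand-in for the H-tree bus. The function names and tile sizes are hypothetical; this is not the paper's actual tiling or mapping rule.

```python
# Hypothetical sketch: tile y = W @ x into fixed-size blocks, as a PIM array of
# tile_rows x tile_cols cells might see it. Tile sizes and the reduction order
# are illustrative assumptions, not the paper's mapping.
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def tree_reduce(values: Vector) -> float:
    """Sum partial results pairwise, level by level, like a binary reduction tree."""
    level = list(values)
    while len(level) > 1:
        nxt = [level[i] + level[i + 1] for i in range(0, len(level) - 1, 2)]
        if len(level) % 2:  # an odd leftover element passes up to the next level
            nxt.append(level[-1])
        level = nxt
    return level[0] if level else 0.0


def tiled_mvm(w: Matrix, x: Vector, tile_rows: int, tile_cols: int) -> Vector:
    """Compute W @ x block by block.

    Each block yields a partial output vector for its slice of x; partials from
    blocks that cover the same output rows are combined with tree_reduce.
    """
    n_out, n_in = len(w), len(x)
    y: Vector = []
    for r0 in range(0, n_out, tile_rows):
        r1 = min(r0 + tile_rows, n_out)
        partials = []  # partials[k]: outputs of rows r0..r1-1 from the k-th column tile
        for c0 in range(0, n_in, tile_cols):
            c1 = min(c0 + tile_cols, n_in)
            partials.append(
                [sum(w[r][c] * x[c] for c in range(c0, c1)) for r in range(r0, r1)]
            )
        for i in range(r1 - r0):
            y.append(tree_reduce([p[i] for p in partials]))
    return y


if __name__ == "__main__":
    W = [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]
    print(tiled_mvm(W, [1, 1, 1, 1], tile_rows=2, tile_cols=2))  # [10, 26, 42]
```

The result matches a direct MVM regardless of tile size; only the order in which partial sums are accumulated changes, which is the degree of freedom a tiling and mapping scheme trades against array size and bus latency.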

Abstract

The advancement of large language models has led to models with billions of parameters, significantly increasing memory and compute demands. Serving such models on conventional hardware is challenging due to limited DRAM capacity and high GPU costs. Thus, in this work, we propose offloading single-batch token generation to a 3D NAND flash processing-in-memory (PIM) device, leveraging its high storage density to overcome the DRAM capacity wall. We explore 3D NAND flash configurations and present a re-architected PIM array with an H-tree network for optimal latency and cell density. Along with a well-chosen PIM array size, we develop operation tiling and mapping methods for LLM layers, achieving a 2.4× speedup over four RTX 4090 GPUs with vLLM and performance comparable to four A100 GPUs with only 4.9% latency overhead. Our detailed area analysis shows that the proposed 3D NAND flash PIM architecture can be integrated within a 4.98 mm² die area under the memory array, without extra area overhead.

Paper Structure

This paper contains 16 sections, 6 equations, 14 figures, 2 tables.

Figures (14)

  • Figure 1: Challenges in LLM token generation: (a) substantial memory requirements and (b) higher token generation latency than summarization (OPT-30B on 4×RTX 4090).
  • Figure 2: (a) A hierarchical NAND flash architecture from memory cell arrays to an SSD controller. (b) A plane consists of a 3D memory cell array and peripheral circuits. To activate one of the 3D-stacked wordlines (WLs), a staircase region provides multi-layer WL connections. A bitline (BL) and a bitline select (BLS) intersect at a string.
  • Figure 3: (a) A top view and (b) a side view of a 3D NAND flash plane when a page read is performed.
  • Figure 4: Simple example of 3D NAND flash PIM operation.
  • Figure 5: Comparison of time per output token (i.e., TPOT) with OPT-30B between the conventional and the proposed 3D NAND PIM architecture.
  • ...and 9 more figures
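Figures 1 and 5 frame the problem in terms of memory capacity and time per output token (TPOT). A rough estimate makes both concrete. It assumes OPT-30B has about 30 billion FP16 parameters and uses the RTX 4090's published 24 GB capacity and roughly 1 TB/s (1008 GB/s) memory bandwidth; these are illustrative assumptions, not measurements from the paper. In single-batch generation every weight is read once per token, so a batch-1 MVM performs about one FLOP per weight byte and TPOT is bounded below by how fast the weights can be streamed.

```python
# Back-of-envelope sketch of why single-batch token generation is memory-bound.
# All model and hardware figures are assumptions chosen for illustration.
import math

PARAMS = 30e9            # assumed parameter count for OPT-30B (approximate)
BYTES_PER_PARAM = 2      # assumed FP16 weights
GPU_MEM_BYTES = 24e9     # RTX 4090: 24 GB of GDDR6X
GPU_BW_BYTES_S = 1008e9  # RTX 4090: ~1008 GB/s peak memory bandwidth
NUM_GPUS = 4


def weight_bytes(params: float, bytes_per_param: int) -> float:
    """Total bytes needed to store the model weights."""
    return params * bytes_per_param


def min_gpus_for_weights(total_bytes: float, gpu_mem: float) -> int:
    """Smallest GPU count whose combined memory holds the weights (ignores KV cache)."""
    return math.ceil(total_bytes / gpu_mem)


def tpot_lower_bound_s(total_bytes: float, gpu_bw: float, num_gpus: int) -> float:
    """Each token streams every weight once; assumes perfect tensor parallelism."""
    return total_bytes / (gpu_bw * num_gpus)


def mvm_arithmetic_intensity(bytes_per_param: int) -> float:
    """FLOPs per weight byte for a batch-1 MVM: one multiply and one add per weight."""
    return 2 / bytes_per_param


if __name__ == "__main__":
    wb = weight_bytes(PARAMS, BYTES_PER_PARAM)
    print(f"weights: {wb / 1e9:.0f} GB")  # 60 GB
    print(f"GPUs for weights alone: {min_gpus_for_weights(wb, GPU_MEM_BYTES)}")  # 3
    t = tpot_lower_bound_s(wb, GPU_BW_BYTES_S, NUM_GPUS)
    print(f"TPOT lower bound: {t * 1e3:.1f} ms")  # 14.9 ms
    print(f"intensity: {mvm_arithmetic_intensity(BYTES_PER_PARAM):.1f} FLOP/byte")  # 1.0
```

The bound ignores KV cache reads and inter-GPU communication, so real TPOT is higher. An intensity of about 1 FLOP/byte is far below the compute-to-bandwidth ratio of modern GPUs, which is why computing where the weights are stored is attractive for the batch-1 case.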