RTop-K: Ultra-Fast Row-Wise Top-K Selection for Neural Network Acceleration on GPUs
Xi Xie, Yuebo Luo, Hongwu Peng, Caiwen Ding
TL;DR
This work targets efficient row-wise top-$k$ selection on GPUs for neural network workloads, especially MaxK-GNNs. It introduces RTop-K, a binary search-based top-$k$ algorithm that operates independently on each row, with optional early stopping to trade accuracy for speed, and it provides theoretical and empirical analysis of the iteration behavior. The GPU kernel uses a three-stage pipeline (load, search, select) and warp-level primitives to minimize memory traffic, producing the exact top-$k$ elements per row, or an approximation of them when early stopping is enabled. Empirically, the kernel achieves average speedups of up to $11.49\times$ with early stopping and $7.29\times$ without it over PyTorch, and end-to-end MaxK-GNN training speeds up by roughly $12\%$ to $33\%$ across multiple models and datasets, while test accuracy remains robust under early stopping.
Abstract
Top-k selection algorithms are fundamental in a wide range of applications, including high-performance computing, information retrieval, big data processing, and neural network model training. In this paper, we present RTop-K, a highly efficient parallel row-wise top-k selection algorithm specifically designed for GPUs. RTop-K leverages a binary search-based approach to optimize row-wise top-k selection, providing a scalable and accelerated solution. We conduct a detailed analysis of early stopping in our algorithm, showing that it effectively maintains the testing accuracy of neural network models while substantially improving performance. Our GPU implementation of RTop-K demonstrates superior performance over state-of-the-art row-wise top-k GPU implementations, achieving an average speed-up of up to 11.49$\times$ with early stopping and 7.29$\times$ without early stopping. Moreover, RTop-K accelerates the overall training workflow of MaxK-GNNs, delivering speed-ups ranging from 11.97% to 33.29% across different models and datasets.
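To make the binary search idea concrete, the following is a minimal CPU sketch in plain Python. It is not the authors' GPU kernel: the function names `rtopk_row` and `rtopk`, the `max_iter` parameter, and the tie-handling and fill rules are assumptions chosen for illustration. The sketch searches for a threshold between the row minimum and maximum such that exactly $k$ elements are at or above it, and it stops early if an iteration budget is exhausted.

```python
def rtopk_row(row, k, max_iter=None):
    """Return k indices of (approximately) the largest values in `row`.

    Illustrative sketch only; `max_iter=None` searches until the threshold
    interval can no longer shrink, which gives an exact result.
    """
    if not 0 < k <= len(row):
        raise ValueError("k must satisfy 0 < k <= len(row)")

    # Invariant: at least k elements are >= lo.
    lo, hi = min(row), max(row)
    iterations = 0
    while max_iter is None or iterations < max_iter:
        mid = (lo + hi) / 2
        if mid == lo or mid == hi:
            break  # interval exhausted (ties or float resolution)
        count = sum(1 for x in row if x >= mid)
        if count == k:
            lo = mid  # exactly k elements pass this threshold
            break
        if count > k:
            lo = mid  # threshold too low, raise it
        else:
            hi = mid  # threshold too high, lower it
        iterations += 1

    # Select: elements >= hi are taken first, then the remaining slots
    # are filled from [lo, hi) in index order. Under early stopping this
    # fill step is where the approximation comes from.
    strong = [i for i, x in enumerate(row) if x >= hi]
    weak = [i for i, x in enumerate(row) if lo <= x < hi]
    return (strong + weak)[:k]


def rtopk(matrix, k, max_iter=None):
    """Apply the row-wise selection independently to every row."""
    return [rtopk_row(row, k, max_iter) for row in matrix]
```

For example, `rtopk([[0.1, 0.9, 0.4, 0.7, 0.3]], 2)` returns `[[1, 3]]`, the positions of 0.9 and 0.7. With a small `max_iter`, the returned indices always include the row maximum, but the remaining picks may not be the true top-$k$ elements.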
