Exploiting Student Parallelism for Efficient GPU Inference of BERT-like Models in Online Services
Weiyan Wang, Yilun Jin, Yiming Zhang, Victor Junqiu Wei, Han Tian, Li Chen, Jinbao Xue, Yangyu Tao, Di Wang, Kai Chen
TL;DR
This work tackles the latency bottleneck of deploying large BERT-like models in online services by introducing Student Parallelism, which distills a deep teacher into a group of shallow, parallel students trained via stacking distillation and boosting. Adaptive pruning allows dynamic adjustment of the number of students to handle workload bursts without substantial accuracy loss, while online inference leverages multi-GPU deployment, MPS sharing, and a length-aware buffer to minimize waiting and padding. The approach achieves state-of-the-art latency reductions and throughput improvements, enabling compression of deep models to very shallow depths without sacrificing accuracy. The results demonstrate practical viability for real-world online NLP services that must handle stochastic, length-variable workloads with tight latency constraints.
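The following sketch shows why a length-aware buffer reduces padding. It rests on an assumed policy, since the buffer's actual design is not specified here: buffered requests are sorted by token length and cut into fixed-size batches, so each batch is padded only to the longest of several similar-length sequences. `Request`, `fifo_batches`, and `length_aware_batches` are hypothetical names. The sketch also ignores waiting time and timeouts, which a real online buffer must bound.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Request:
    req_id: int
    length: int  # number of input tokens


def padding_tokens(batch: List[Request]) -> int:
    """Pad tokens added when every sequence is padded to the batch's longest one."""
    if not batch:
        return 0
    longest = max(r.length for r in batch)
    return sum(longest - r.length for r in batch)


def fifo_batches(buffer: List[Request], batch_size: int) -> List[List[Request]]:
    """Baseline: batch requests in arrival order, ignoring their lengths."""
    return [buffer[i:i + batch_size] for i in range(0, len(buffer), batch_size)]


def length_aware_batches(buffer: List[Request], batch_size: int) -> List[List[Request]]:
    """Assumed policy: sort buffered requests by length, then cut consecutive
    runs so that each batch holds sequences of similar length."""
    ordered = sorted(buffer, key=lambda r: r.length)
    return [ordered[i:i + batch_size] for i in range(0, len(ordered), batch_size)]


def total_padding(batches: List[List[Request]]) -> int:
    return sum(padding_tokens(b) for b in batches)


# Short and long requests arrive interleaved.
lengths = [12, 120, 15, 118, 10, 125, 14, 119]
reqs = [Request(i, n) for i, n in enumerate(lengths)]
print(total_padding(fifo_batches(reqs, 4)))          # → 447
print(total_padding(length_aware_batches(reqs, 4)))  # → 27
```

For the same eight requests, grouping by length cuts the padded tokens from 447 to 27. The cost, not modeled here, is that holding requests back to form well-matched batches adds waiting time.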
Abstract
Owing to their high accuracy, BERT-like models have been widely adopted in text mining and web search. However, large BERT-like models suffer from inefficient online inference and face two problems on GPUs: (1) their high accuracy relies on a large model depth, which linearly increases the sequential computation on GPUs; (2) stochastic and dynamic online workloads incur extra costs from batching and padding. We therefore present \sys for the real-world setting of GPU inference on online workloads. At its core, \sys adopts stacking distillation and a boosting ensemble, distilling the original deep model into a group of shallow but virtually stacked student models that run in parallel. This allows \sys to reach a lower model depth (e.g., two layers) than other approaches, and hence the lowest inference latency, while maintaining accuracy. In addition, adaptive student pruning adjusts the number of students to the changing online workload. In particular, during occasional workload bursts it can temporarily decrease the number of students, with minimal accuracy loss, to improve system throughput. Comprehensive experiments verify its effectiveness: \sys outperforms the baselines by $1.6\times$ to $4.1\times$ in latency while maintaining accuracy, and achieves up to $22.27\times$ higher throughput during workload bursts.
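The core idea is that shallow students trained by boosting jointly approximate a deep teacher. Their outputs can be summed, and the last students can be dropped under load. The sketch below illustrates only that boosting structure, using a deliberately tiny stand-in:

- each "student" is a one-split regression stump on scalar inputs rather than a shallow transformer;
- the "teacher" is `sin(3x)` rather than a deep BERT;
- pruning simply keeps the first `k` students.

All names (`distill_students`, `ensemble_predict`, `lr`, `k`) are hypothetical. They do not reflect the paper's actual stacking-distillation procedure or pruning policy.

```python
import math
from typing import List, Optional, Tuple

# A "student" here is a regression stump: (threshold, left_output, right_output).
Student = Tuple[float, float, float]


def fit_stump(xs: List[float], residuals: List[float]) -> Student:
    """Least-squares regression stump (one split) fitted to the residuals."""
    best_err, best = math.inf, (0.0, 0.0, 0.0)
    for thr in sorted(set(xs)):
        left = [r for x, r in zip(xs, residuals) if x <= thr]
        right = [r for x, r in zip(xs, residuals) if x > thr]
        lv = sum(left) / len(left)
        rv = sum(right) / len(right) if right else 0.0
        err = sum((r - lv) ** 2 for r in left) + sum((r - rv) ** 2 for r in right)
        if err < best_err:
            best_err, best = err, (thr, lv, rv)
    return best


def predict(student: Student, x: float) -> float:
    thr, lv, rv = student
    return lv if x <= thr else rv


def distill_students(xs, teacher_out, n_students, lr=0.5) -> List[Student]:
    """Boosting-style distillation: student i is fitted to the residual that
    students 0..i-1 leave w.r.t. the teacher, shrunk by the learning rate lr."""
    students: List[Student] = []
    ensemble = [0.0] * len(xs)
    for _ in range(n_students):
        residuals = [t - e for t, e in zip(teacher_out, ensemble)]
        thr, lv, rv = fit_stump(xs, residuals)
        student = (thr, lr * lv, lr * rv)
        students.append(student)
        ensemble = [e + predict(student, x) for e, x in zip(ensemble, xs)]
    return students


def ensemble_predict(students: List[Student], x: float, k: Optional[int] = None) -> float:
    """No student reads another student's output, so at inference they could run
    in parallel; their outputs are simply summed. k keeps only the first k
    students, a simplified stand-in for adaptive pruning under load."""
    active = students if k is None else students[:k]
    return sum(predict(s, x) for s in active)


def mse(students, xs, teacher_out, k=None) -> float:
    errs = [(ensemble_predict(students, x, k) - t) ** 2 for x, t in zip(xs, teacher_out)]
    return sum(errs) / len(errs)


# Toy "teacher": a smooth function standing in for the deep model's logits.
xs = [i / 10 for i in range(20)]
teacher_out = [math.sin(3 * x) for x in xs]
students = distill_students(xs, teacher_out, n_students=8)
for k in (1, 2, 4, 8):
    print(k, round(mse(students, xs, teacher_out, k), 4))
```

Each student is fitted to what its predecessors missed, so training error is non-increasing in `k`. This is the intuition behind dropping the last students during a burst at a small accuracy cost. In the setting described above, the throughput gain would come from running fewer student networks on the GPUs rather than from cheaper stumps.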
