Exploiting Student Parallelism for Efficient GPU Inference of BERT-like Models in Online Services

Weiyan Wang, Yilun Jin, Yiming Zhang, Victor Junqiu Wei, Han Tian, Li Chen, Jinbao Xue, Yangyu Tao, Di Wang, Kai Chen

TL;DR

This work tackles the latency bottleneck of deploying large BERT-like models in online services by introducing Student Parallelism, which distills a deep teacher into a group of shallow, parallel students trained via stacking distillation and boosting. Adaptive pruning allows dynamic adjustment of the number of students to handle workload bursts without substantial accuracy loss, while online inference leverages multi-GPU deployment, MPS sharing, and a length-aware buffer to minimize waiting and padding. The approach achieves state-of-the-art latency reductions and throughput improvements, enabling compression of deep models to very shallow depths without sacrificing accuracy. The results demonstrate practical viability for real-world online NLP services that must handle stochastic, length-variable workloads with tight latency constraints.
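The adaptive pruning idea in the TL;DR can be illustrated with a minimal sketch. It assumes a hypothetical cost model: one serving round can afford a fixed `budget` of student forward passes, and serving `queue_len` requests with `k` students costs `k * queue_len`. The names `choose_num_students` and `pruned_predict` are illustrative only, not the paper's API, and students are pruned from the tail on the assumption that later boosting stages fit smaller residuals.

```python
from typing import Callable, List, Sequence

# Hypothetical: a "student" is any callable mapping an input to a scalar score.
Student = Callable[[Sequence[float]], float]


def choose_num_students(queue_len: int, max_students: int, budget: int) -> int:
    """Hypothetical pruning policy under an assumed per-round compute budget.

    Serving queue_len requests with k students is assumed to cost
    k * queue_len forward passes. Keep every student under light load,
    shed students during bursts, and never go below one student.
    """
    if queue_len <= 0:
        return max_students
    return max(1, min(max_students, budget // queue_len))


def pruned_predict(students: List[Student], x: Sequence[float], k: int) -> float:
    """Boosting-style ensemble output that uses only the first k students.

    Pruning from the tail is assumed to cost the least accuracy, because
    later students were trained on smaller residuals.
    """
    return sum(s(x) for s in students[:k])
```

For example, with four students and a budget of 100 passes, a queue of 50 requests keeps two students, and a burst of 500 requests falls back to a single student.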

Abstract

Owing to their high accuracy, BERT-like models have been widely adopted in text mining and web search. However, large BERT-like models suffer from inefficient online inference and face two problems on GPUs: (1) their high accuracy relies on a large model depth, which linearly increases the sequential computation on GPUs; (2) stochastic and dynamic online workloads incur extra costs from batching and padding. Therefore, we present Student Parallelism for the real-world setting of GPU inference on online workloads. At its core, Student Parallelism adopts stacking distillation and boosting ensemble, distilling the original deep model into a group of shallow but virtually stacked student models that run in parallel. This enables Student Parallelism to reach a lower model depth (e.g., two layers) than other methods and the lowest inference latency while maintaining accuracy. In addition, adaptive student pruning adjusts the number of students dynamically according to changing online workloads. In particular, during occasional workload bursts, it can temporarily decrease the number of students with minimal accuracy loss to improve system throughput. Comprehensive experiments verify its effectiveness: Student Parallelism outperforms the baselines by $1.6\times$ to $4.1\times$ in latency while maintaining accuracy, and achieves up to $22.27\times$ higher throughput under workload bursts.
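The boosting-ensemble view of distillation can be sketched with plain gradient boosting on MSE: each new student fits the residual that the teacher's output leaves after all previous students, and the final prediction is the sum of student outputs. This is a simplified stand-in, not the paper's training procedure: scalar teacher outputs and one-split regression stumps replace shallow transformer students, and the "intermediate layer imitates previous students" part of stacking distillation is not modeled. All function names are hypothetical.

```python
from typing import Callable, List, Sequence

Stump = Callable[[float], float]


def fit_stump(xs: Sequence[float], residuals: Sequence[float]) -> Stump:
    """Fit a one-split regression stump to the residuals under squared error."""
    pts = sorted(set(xs))
    # Candidate thresholds: midpoints between consecutive distinct inputs.
    thresholds = [(a + b) / 2 for a, b in zip(pts, pts[1:])] or [pts[0]]
    best = None
    for thr in thresholds:
        left = [r for x, r in zip(xs, residuals) if x <= thr]
        right = [r for x, r in zip(xs, residuals) if x > thr]
        lv = sum(left) / len(left) if left else 0.0
        rv = sum(right) / len(right) if right else 0.0
        err = sum((r - (lv if x <= thr else rv)) ** 2 for x, r in zip(xs, residuals))
        if best is None or err < best[0]:
            best = (err, thr, lv, rv)
    _, thr, lv, rv = best
    return lambda x: lv if x <= thr else rv


def distill_students(xs: Sequence[float], teacher_out: Sequence[float],
                     num_students: int) -> List[Stump]:
    """Student i fits the residual left by students 1..i-1 (boosting on MSE)."""
    students: List[Stump] = []
    ensemble = [0.0] * len(xs)
    for _ in range(num_students):
        residual = [t - e for t, e in zip(teacher_out, ensemble)]
        student = fit_stump(xs, residual)
        students.append(student)
        ensemble = [e + student(x) for e, x in zip(ensemble, xs)]
    return students


def ensemble_output(students: List[Stump], x: float) -> float:
    # Each student is evaluated independently, so these calls could run in parallel.
    return sum(s(x) for s in students)


def mse(students: List[Stump], xs: Sequence[float], teacher_out: Sequence[float]) -> float:
    return sum((ensemble_output(students, x) - t) ** 2
               for x, t in zip(xs, teacher_out)) / len(xs)
```

Because the ensemble output is a sum, no student depends on another at inference time; the students are "stacked" only in how they were trained, which is what lets them run side by side instead of in sequence.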

Paper Structure (10 sections, 1 theorem, 8 equations, 5 figures, 1 table, 2 algorithms)

Key Result

theorem 1

The MSE in Student Parallelism has the lower bound 0 and is a Lipschitz differentiable loss function (for any $L > 1$, we always have …). Let $F^{(0)}, F^{(1)}, \dots$ be the sequence of combined hypotheses generated by the Student Parallelism training algorithm, using small enough step sizes $\alpha_{i} := -\frac{\langle \nabla L_{\text{boost}}(T, B^{(i-1)}), S^{(i)} \rangle}{L\,\|S^{(i)}\|^{2}}$. …
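The step size in the theorem can be computed directly once a concrete loss is fixed. This sketch assumes $L_{\text{boost}}(T, B) = \tfrac{1}{2}\|B - T\|^2$, with $T$ the teacher outputs, $B^{(i-1)}$ the current ensemble outputs, and $S^{(i)}$ the new student's outputs, all as flat vectors. Under that assumption the gradient with respect to $B$ is $B - T$ and its Lipschitz constant is 1, which is why any $L > 1$ satisfies the condition. The function names are hypothetical.

```python
from typing import Sequence


def dot(a: Sequence[float], b: Sequence[float]) -> float:
    return sum(x * y for x, y in zip(a, b))


def boost_step_size(teacher: Sequence[float], ensemble: Sequence[float],
                    student: Sequence[float], lipschitz: float) -> float:
    """alpha_i = -<grad L_boost(T, B^(i-1)), S^(i)> / (L * ||S^(i)||^2).

    Assumes L_boost(T, B) = 0.5 * ||B - T||^2, so grad_B L_boost = B - T.
    """
    if lipschitz <= 0:
        raise ValueError("Lipschitz constant must be positive")
    norm_sq = dot(student, student)
    if norm_sq == 0:
        raise ValueError("student output must be non-zero")
    grad = [b - t for t, b in zip(teacher, ensemble)]
    return -dot(grad, student) / (lipschitz * norm_sq)


def half_sq_loss(teacher: Sequence[float], ensemble: Sequence[float]) -> float:
    return 0.5 * sum((b - t) ** 2 for t, b in zip(teacher, ensemble))
```

With $T = (1, 2)$, $B^{(0)} = (0, 0)$, $S^{(1)} = (1, 2)$ and $L = 2$, the gradient is $(-1, -2)$, so $\alpha_1 = 5 / (2 \cdot 5) = 0.5$, and the updated ensemble $(0.5, 1)$ lowers the loss from 2.5 to 0.625.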

Figures (5)

  • Figure 1: Influence of different model and input factors on latency (red solid line) and throughput (green dashed line)
  • Figure 2: Comparison of the trade-off between accuracy and latency (the upper left corner is optimal)
  • Figure 3: Virtually stacked students: the intermediate layer imitates all the previous students and the top layer reduces the residual error.
  • Figure 4: Overview of online inference and the procedure of Student Parallelism: student allocation and task dispatching
  • Figure 5: Length-aware Buffer: A new sequence sample (e.g., the red one, whose length is 30) can be merged with the previous element into a small batch, since they share the same bin and the element is not full. Otherwise, the new sequence sample (e.g., the blue one of 52 tokens, which cannot join the full element) is added to the buffer tail as a new element.
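The length-aware buffer described in the Figure 5 caption can be sketched as a FIFO of small batches ("elements"). A new sample joins the tail element only if its length falls in the same bin and that element is not full; otherwise it starts a new element. The fixed-width bins (`bin_width`), the per-element capacity (`max_element_size`), the FIFO dequeue, and padding to the longest sample within an element are all assumptions made for illustration, as are the class and method names.

```python
from collections import deque
from dataclasses import dataclass, field
from typing import Deque, List


@dataclass
class Element:
    """A small batch whose sample lengths fall in the same bin."""
    bin_id: int
    samples: List[List[int]] = field(default_factory=list)


class LengthAwareBuffer:
    def __init__(self, bin_width: int, max_element_size: int, pad_id: int = 0):
        self.bin_width = bin_width                # hypothetical fixed-width length bins
        self.max_element_size = max_element_size  # hypothetical element capacity
        self.pad_id = pad_id
        self.elements: Deque[Element] = deque()

    def add(self, tokens: List[int]) -> None:
        bin_id = len(tokens) // self.bin_width
        tail = self.elements[-1] if self.elements else None
        if (tail is not None and tail.bin_id == bin_id
                and len(tail.samples) < self.max_element_size):
            tail.samples.append(tokens)                      # merge with the previous element
        else:
            self.elements.append(Element(bin_id, [tokens]))  # new element at the buffer tail

    def pop_batch(self) -> List[List[int]]:
        """Dequeue the oldest element, padding only to its own longest sample."""
        elem = self.elements.popleft()
        max_len = max(len(s) for s in elem.samples)
        return [s + [self.pad_id] * (max_len - len(s)) for s in elem.samples]
```

With `bin_width=64` and `max_element_size=2`, lengths 20 and 30 merge into one element, a 52-token sample starts a new element because the tail is full (mirroring the blue sample in the caption), and a 100-token sample falls into a different bin. Because samples in one element have similar lengths, padding stays small.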

Theorems & Definitions (1)

  • theorem 1