Knowledge Distillation on Spatial-Temporal Graph Convolutional Network for Traffic Prediction
Mohammad Izadi, Mehran Safayani, Abdolreza Mirzaei
TL;DR
This work tackles the tight real-time constraint of traffic prediction by combining knowledge distillation (KD) with pruning to compress an ST-GCN-based model. It introduces a space-time distillation loss $L_{\text{STCD}}$ that fuses response-based and hidden-layer learning with temporal and spatial correlation distillation, enabling a lightweight student to approach the teacher's accuracy. A pruning-driven architecture search (Algorithm 1) derives an efficient student that retains only a small fraction of the parameters and is fine-tuned with KD after each pruning stage. Experiments on PeMSD7 and PeMSD8 show substantial execution-time reductions (orders of magnitude) with little or no loss in predictive performance, validating the approach for real-time traffic forecasting.
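The summary does not spell out $L_{\text{STCD}}$, so the following is only a minimal NumPy sketch of how a loss of this shape could be assembled, not the authors' formulation. It assumes hidden features of shape (time steps, nodes, channels), summarizes temporal and spatial correlation as cosine-similarity matrices between time steps and between nodes, and realizes the hidden-layer term only through those matrices so that teacher and student may have different channel widths. The function names and the weights `alpha`, `beta`, `gamma` are hypothetical.

```python
import numpy as np


def cosine_gram(x: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between the rows of a 2-D array."""
    norms = np.linalg.norm(x, axis=1, keepdims=True)
    x_unit = x / np.maximum(norms, 1e-12)  # guard against all-zero rows
    return x_unit @ x_unit.T


def temporal_correlation(h: np.ndarray) -> np.ndarray:
    """h: (T, N, C). One vector per time step -> T x T similarity matrix."""
    return cosine_gram(h.reshape(h.shape[0], -1))


def spatial_correlation(h: np.ndarray) -> np.ndarray:
    """h: (T, N, C). One vector per node -> N x N similarity matrix."""
    return cosine_gram(h.transpose(1, 0, 2).reshape(h.shape[1], -1))


def mse(a: np.ndarray, b: np.ndarray) -> float:
    return float(np.mean((a - b) ** 2))


def st_distillation_loss(y_true, y_student, y_teacher, h_student, h_teacher,
                         alpha=0.5, beta=1.0, gamma=1.0):
    """Hypothetical combination of task, response and correlation terms.

    The weights alpha, beta, gamma are illustrative assumptions,
    not values taken from the paper.
    """
    task = mse(y_student, y_true)          # supervised error on ground truth
    response = mse(y_student, y_teacher)   # response-based KD on outputs
    temporal = mse(temporal_correlation(h_student), temporal_correlation(h_teacher))
    spatial = mse(spatial_correlation(h_student), spatial_correlation(h_teacher))
    return (1 - alpha) * task + alpha * response + beta * temporal + gamma * spatial


# Usage with random tensors: a 64-channel teacher and an 8-channel student.
rng = np.random.default_rng(0)
T, N = 12, 5
h_teacher = rng.normal(size=(T, N, 64))
h_student = rng.normal(size=(T, N, 8))
y_true = rng.normal(size=(3, N))  # 3 future steps for N sensors
loss = st_distillation_loss(y_true, y_true + 0.1, y_true, h_student, h_teacher)
print(f"loss = {loss:.4f}")
```

Because cosine similarity ignores feature scale and the matrices are $T \times T$ and $N \times N$ regardless of width, a narrow student can be matched against a wide teacher without a projection layer; a direct feature-MSE term would need one.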
Abstract
Efficient real-time traffic prediction is crucial for reducing transportation time. To predict traffic conditions, we employ a spatio-temporal graph neural network (ST-GNN) that models real-time traffic data as temporal graphs. Despite their modeling power, ST-GNNs often struggle to deliver predictions fast enough for real-world traffic data. Because timely prediction is essential given the dynamic nature of real-time data, we use knowledge distillation (KD) to reduce the execution time of ST-GNNs for traffic prediction. In this paper, we introduce a cost function for training a network with fewer parameters (the student) on knowledge distilled from a complex network (the teacher) while keeping its accuracy close to the teacher's. This distillation incorporates the spatial-temporal correlations captured by the teacher, enabling the student to learn the complex patterns the teacher perceives. A remaining challenge is choosing the student architecture systematically rather than arbitrarily. To address it, we propose an algorithm that uses the cost function to compute pruning scores, thereby searching for a small network architecture, and fine-tunes the network obtained at each pruning stage with KD. Finally, we evaluate our approach on two real-world datasets, PeMSD7 and PeMSD8. The results show that our method keeps the student's accuracy close to the teacher's even when only 3% of the network parameters are retained.
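The pruning algorithm is also described only at a high level. One common way to derive pruning scores from a training loss is the first-order Taylor criterion, which estimates how much the loss would change if a channel were zeroed; the sketch below uses that criterion as a stand-in and is not necessarily the paper's Algorithm 1. It assumes a hypothetical callback `grad_fn` that returns $\partial L_{\text{STCD}} / \partial W$ for one layer's weight matrix (computed by any autodiff framework) and a hypothetical `fine_tune_fn` that runs a round of KD fine-tuning. The halving schedule `stage_keep` is also an assumption, and the 3% target is applied to output channels as a rough proxy for parameter retention.

```python
import numpy as np


def taylor_channel_scores(weight: np.ndarray, grad: np.ndarray) -> np.ndarray:
    """First-order Taylor estimate |sum_i w_i * dL/dw_i| per output channel (row)."""
    contrib = (weight * grad).reshape(weight.shape[0], -1)
    return np.abs(contrib.sum(axis=1))


def prune_mask(scores: np.ndarray, n_keep: int) -> np.ndarray:
    """Boolean mask that keeps the n_keep highest-scoring channels."""
    keep_idx = np.argsort(scores)[::-1][:n_keep]
    mask = np.zeros(scores.size, dtype=bool)
    mask[keep_idx] = True
    return mask


def iterative_prune(weight, grad_fn, fine_tune_fn, stage_keep=0.5, target_keep=0.03):
    """Alternate scoring, pruning and (KD) fine-tuning until at most
    ceil(target_keep * original) channels remain. Returns the pruned
    weight and the channel count after every stage."""
    original = weight.shape[0]
    target = max(1, int(np.ceil(target_keep * original)))
    history = [original]
    while weight.shape[0] > target:
        scores = taylor_channel_scores(weight, grad_fn(weight))
        n_keep = max(target, int(np.ceil(stage_keep * weight.shape[0])))
        weight = weight[prune_mask(scores, n_keep)]  # drop low-score rows
        weight = fine_tune_fn(weight)                # hypothetical KD fine-tuning step
        history.append(weight.shape[0])
    return weight, history


# Toy usage: grad_fn is the gradient of 0.5 * ||W||^2, a placeholder for
# dL_STCD/dW, and fine_tune_fn is a no-op placeholder for KD fine-tuning.
rng = np.random.default_rng(0)
w0 = rng.normal(size=(256, 16))
w_final, history = iterative_prune(w0, grad_fn=lambda w: w, fine_tune_fn=lambda w: w)
print(history)  # [256, 128, 64, 32, 16, 8]
```

Pruning gradually with fine-tuning between stages, rather than cutting to 3% in one shot, gives the student a chance to recover accuracy through distillation at each size. In a full network, removing output channels of one layer also removes the matching input channels of the next layer, which this single-layer sketch omits.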
