
Knowledge Distillation on Spatial-Temporal Graph Convolutional Network for Traffic Prediction

Mohammad Izadi, Mehran Safayani, Abdolreza Mirzaei

TL;DR

This work tackles the tight real-time constraint of traffic prediction by combining knowledge distillation (KD) with pruning to compress an ST-GCN-based model. It introduces a spatial-temporal correlation distillation loss, $L_{\text{STCD}}$, that fuses response-based and hidden-layer learning with temporal and spatial correlation distillation, enabling a lightweight student to approach the teacher's accuracy. A pruning-driven architecture search (Algorithm 1) derives an efficient student that retains only a small fraction of the parameters while benefiting from KD during fine-tuning. Experiments on PeMSD7 and PeMSD8 show substantial execution-time reductions (orders of magnitude) with only minor or no loss in predictive performance, validating the approach for real-time traffic forecasting.

Abstract

Efficient real-time traffic prediction is crucial for reducing transportation time. To predict traffic conditions, we employ a spatio-temporal graph neural network (ST-GNN) to model our real-time traffic data as temporal graphs. Despite their capabilities, ST-GNNs often struggle to deliver efficient real-time predictions on real-world traffic data. Because timely prediction matters given the dynamic nature of real-time data, we employ knowledge distillation (KD) to reduce the execution time of ST-GNNs for traffic prediction. In this paper, we introduce a cost function designed to train a network with fewer parameters (the student) using knowledge distilled from a complex network (the teacher) while keeping its accuracy close to that of the teacher. The cost function incorporates spatial-temporal correlations from the teacher network, enabling the student to learn the complex patterns perceived by the teacher. However, a challenge arises in determining the student network architecture deliberately rather than choosing it arbitrarily. To address this challenge, we propose an algorithm that uses the cost function to calculate pruning scores, which addresses the search for a small network architecture, and jointly fine-tunes the network resulting from each pruning stage using KD. Finally, we evaluate our proposed ideas on two real-world datasets, PeMSD7 and PeMSD8. The results indicate that our method can keep the student's accuracy close to that of the teacher, even when only 3% of the network parameters are retained.
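
To make the idea of distilling spatial-temporal correlations concrete, the following is a minimal NumPy sketch of one way such a cost function could be assembled. It is not the paper's definition of $L_{\text{STCD}}$: the cosine-similarity correlation matrices, the squared-error matching, the function names, and the weights `alpha` and `beta` are all assumptions chosen for illustration. The sketch only shows that node-to-node and step-to-step correlations can be compared between a teacher and a student even when their hidden layers have different channel widths.

```python
import numpy as np


def _unit_rows(x, eps=1e-8):
    """L2-normalise the last axis so dot products become cosine similarities."""
    return x / (np.linalg.norm(x, axis=-1, keepdims=True) + eps)


def spatial_correlation(h):
    """h: (T, N, C) hidden features -> (T, N, N) node-to-node similarity per time step."""
    u = _unit_rows(h)
    return u @ np.swapaxes(u, -1, -2)


def temporal_correlation(h):
    """h: (T, N, C) hidden features -> (N, T, T) step-to-step similarity per node."""
    u = _unit_rows(np.swapaxes(h, 0, 1))  # reorder to (N, T, C)
    return u @ np.swapaxes(u, -1, -2)


def distillation_loss(y_true, y_student, y_teacher, h_student, h_teacher,
                      alpha=0.5, beta=1.0):
    """Illustrative combined loss (assumed form, not the paper's exact formula).

    task     : student error against the ground truth
    response : student error against the teacher's predictions
    spatial  : mismatch between student and teacher node correlations
    temporal : mismatch between student and teacher time-step correlations
    """
    task = np.mean((y_student - y_true) ** 2)
    response = np.mean((y_student - y_teacher) ** 2)
    spatial = np.mean((spatial_correlation(h_student)
                       - spatial_correlation(h_teacher)) ** 2)
    temporal = np.mean((temporal_correlation(h_student)
                        - temporal_correlation(h_teacher)) ** 2)
    return (1 - alpha) * task + alpha * response + beta * (spatial + temporal)
```

Because the correlation matrices have shapes `(T, N, N)` and `(N, T, T)` regardless of the channel count `C`, a narrow student layer can be compared directly with a wide teacher layer without an extra projection.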

Paper Structure (18 sections, 12 equations, 4 figures, 8 tables, 1 algorithm)

Figures (4)

  • Figure 1: The illustration features both our student and teacher models, demonstrating the application of our cost functions to the spatio-temporal graph convolutional network (ST-GCN).
  • Figure 2: Sequence of ST-GCN predictions for future time steps. In each step, the output of the current state is used as the last input graph to predict the next timestep.
  • Figure 3: Comparison of our spatial-temporal correlation distillation loss function $L_{\text{STCD}}$ with the $L_{\text{RD(L2)}}$, $L_{\text{RD(KL)}}$ and $L_{\text{SKD}}$ loss functions. One subfigure shows results for the PeMSD7 dataset, and the other shows results for PeMSD8. The left chart in the top row corresponds to the teacher, while the middle chart in the top row pertains to a student without knowledge distillation. In the second row, the right chart represents our loss function, and the other charts indicate different loss functions.
  • Figure 4: The results highlight the impact of employing the pruning algorithm, demonstrating the average predicted value for a randomly selected node based on 50 data points.
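
The step-by-step prediction described in the Figure 2 caption can be sketched as a simple rollout loop. This is a minimal illustration under stated assumptions: `predict_next` stands in for a trained ST-GCN, `mean_of_window` is a hypothetical toy predictor, and dropping the oldest graph to keep the input window at a fixed length is an assumption not stated in the caption.

```python
import numpy as np


def rollout(predict_next, history, horizon):
    """Predict `horizon` future steps by feeding each output back as input.

    predict_next : callable mapping a (T_in, N) window to an (N,) prediction
    history      : (T_in, N) array of observed node signals
    returns      : (horizon, N) array of predictions
    """
    window = np.array(history, dtype=float)
    outputs = []
    for _ in range(horizon):
        nxt = predict_next(window)  # prediction for the next time step
        outputs.append(nxt)
        # Assumption: slide the window by dropping the oldest step and
        # appending the new prediction as the last input graph.
        window = np.concatenate([window[1:], nxt[None, :]], axis=0)
    return np.stack(outputs)


def mean_of_window(window):
    """Hypothetical stand-in for a trained model: average over the window."""
    return window.mean(axis=0)


history = np.array([[0.0, 0.0], [2.0, 4.0]])  # 2 time steps, 2 nodes
predictions = rollout(mean_of_window, history, horizon=3)
```

Since every step after the first consumes earlier predictions, errors can compound over the horizon, which is one reason multi-step accuracy is usually reported per prediction step.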