Scaling Motion Forecasting Models with Ensemble Distillation

Scott Ettinger, Kratarth Goel, Avikalp Srivastava, Rami Al-Rfou

TL;DR

The paper addresses the challenge of achieving high-accuracy motion forecasting for autonomous systems under limited onboard compute. It introduces a general ensemble distillation framework that first constructs a large, diverse ensemble of motion forecasting models and then distills their multi-modal outputs into a compact student model, preserving accuracy while reducing compute. Empirical results on the Waymo Open Motion Dataset and Argoverse 2 show that ensembles scale performance with compute, achieving podium-level standings, while distilled students retain much of that accuracy at a fraction of the FLOPs. The approach enables real-time, high-quality trajectory prediction for robotics and autonomous driving within strict hardware budgets, offering practical benefits for safety-critical planning and robustness in dynamic scenes.

Abstract

Motion forecasting has become an increasingly critical component of autonomous robotic systems. Onboard compute budgets typically limit the accuracy of real-time systems. In this work we propose methods for improving motion forecasting systems subject to limited compute budgets by combining model ensembling and distillation techniques. The use of ensembles of deep neural networks has been shown to improve generalization accuracy in many application domains. We first demonstrate significant performance gains by creating a large ensemble of optimized single models. We then develop a generalized framework to distill motion forecasting model ensembles into small student models which retain high performance with a fraction of the computing cost. For this study we focus on the task of motion forecasting using real-world data from autonomous driving systems. We develop ensemble models that are very competitive on the Waymo Open Motion Dataset (WOMD) and Argoverse leaderboards. From these ensembles, we train distilled student models which have high performance at a fraction of the compute costs. These experiments demonstrate distillation from ensembles as an effective method for improving accuracy of predictive models for robotic systems with limited compute budgets.
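To make the ensembling step concrete, here is a minimal sketch of pooling the multimodal outputs of $K$ teachers and reducing them to $M_T$ trajectories with non-maximum suppression (NMS), the component named in the Figure 3 caption. It is a sketch under stated assumptions, not the paper's implementation: the endpoint-distance metric, the greedy suppression rule, averaging scores over teachers, the threshold, and the names `ensemble_nms` and `endpoint_distance` are all hypothetical choices for illustration.

```python
import math
from typing import List, Tuple

# A trajectory is a list of (x, y) waypoints; each predicted mode carries a score.
Trajectory = List[Tuple[float, float]]


def endpoint_distance(a: Trajectory, b: Trajectory) -> float:
    """Euclidean distance between final waypoints (an assumed similarity metric)."""
    (ax, ay), (bx, by) = a[-1], b[-1]
    return math.hypot(ax - bx, ay - by)


def ensemble_nms(
    teacher_outputs: List[List[Tuple[Trajectory, float]]],
    num_output_modes: int,
    dist_threshold: float,
) -> List[Tuple[Trajectory, float]]:
    """Pool the modes of K teachers and greedily keep up to M_T diverse ones.

    teacher_outputs[k] holds the (trajectory, score) pairs predicted by teacher k.
    Scores are divided by K so that the pooled set sums to 1 if each teacher's does.
    """
    k = len(teacher_outputs)
    pooled = [(traj, score / k) for modes in teacher_outputs for traj, score in modes]
    pooled.sort(key=lambda item: item[1], reverse=True)  # highest confidence first

    kept: List[Tuple[Trajectory, float]] = []
    for traj, score in pooled:
        if len(kept) == num_output_modes:
            break
        # Suppress a candidate whose endpoint is too close to an already kept mode.
        if all(endpoint_distance(traj, other) > dist_threshold for other, _ in kept):
            kept.append((traj, score))

    # Renormalize the surviving scores into a distribution over the kept modes.
    total = sum(score for _, score in kept)
    return [(traj, score / total) for traj, score in kept]
```

For example, if two teachers both predict a mode ending near $(5, 0)$, only the higher-scoring copy survives and the freed slot goes to a more distinct mode, which is the diversity property an ensemble target should preserve.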

Paper Structure (24 sections, 7 equations, 14 figures, 2 tables)

Figures (14)

  • Figure 1: Ensembles provide a method to improve both the soft-mAP and the minADE metrics for trajectory prediction with a linear increase in compute. But can we do better and achieve similar scaling of performance without the additional compute cost?
  • Figure 2: The Wayformer architecture is a pair of encoder/decoder Transformer networks. This model takes multimodal scene data as input and produces a multimodal distribution of trajectories.
  • Figure 3: Illustration of the ensemble distillation pipeline. A set of $K$ teachers and a non-maximum suppression (NMS) step that outputs $M_T$ trajectories form the ensemble. The student is trained with a groundtruth loss ($\mathcal{L}_{gt}$) and a distillation loss ($\mathcal{L}_{Distill}$).
  • Figure 4: We compare the metrics on the Waymo Open Motion Dataset for our ensemble models and the distilled student models as a function of inference FLOPs (total floating-point operations), relative to the single-head Wayformer model [Nayakanti2022Wayformer]. The relative FLOPs on the x-axis are shown in log scale. Ensemble models are shown in orange, with ensemble size $K$ linearly related to FLOPs. Distilled student models are shown in green. The blue line marks the ensembling SoTA result and the red line marks the non-ensemble SoTA.
  • Figure 5: Ensemble scaling results on the Argoverse dataset. Ensemble models are shown in orange, with ensemble size $K$ linearly related to FLOPs. Distilled student models are shown in green.
  • ...and 9 more figures
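Figure 3 states that the student is trained with a groundtruth loss $\mathcal{L}_{gt}$ and a distillation loss $\mathcal{L}_{Distill}$. The sketch below shows one plausible way to combine them, assuming a weighted sum $\mathcal{L} = \mathcal{L}_{gt} + \lambda \mathcal{L}_{Distill}$ with a hypothetical weight `lam`, and assuming both terms use a winner-take-all loss (regress the closest student mode by average displacement error, plus a negative log-likelihood on that mode's index). For $\mathcal{L}_{Distill}$, the ensemble's $M_T$ trajectories act as pseudo-labels weighted by their ensemble scores. These loss forms and all function names are illustrative assumptions, not the paper's equations.

```python
import math
from typing import List, Tuple

Trajectory = List[Tuple[float, float]]


def ade(a: Trajectory, b: Trajectory) -> float:
    """Average displacement error between two equal-length trajectories."""
    return sum(math.hypot(ax - bx, ay - by) for (ax, ay), (bx, by) in zip(a, b)) / len(a)


def log_softmax(logits: List[float]) -> List[float]:
    """Numerically stable log-softmax over mode logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def wta_loss(student_modes: List[Trajectory], logits: List[float], target: Trajectory) -> float:
    """Winner-take-all loss: regress the closest student mode and classify its index."""
    errors = [ade(mode, target) for mode in student_modes]
    best = min(range(len(errors)), key=errors.__getitem__)
    return errors[best] - log_softmax(logits)[best]


def student_loss(
    student_modes: List[Trajectory],
    logits: List[float],
    ground_truth: Trajectory,
    teacher_modes: List[Tuple[Trajectory, float]],
    lam: float = 1.0,
) -> float:
    """Hypothetical L = L_gt + lam * L_Distill, with teacher modes as weighted pseudo-labels."""
    l_gt = wta_loss(student_modes, logits, ground_truth)
    l_distill = sum(w * wta_loss(student_modes, logits, traj) for traj, w in teacher_modes)
    return l_gt + lam * l_distill
```

Weighting each pseudo-label by its ensemble score lets the student match the ensemble's full multimodal distribution rather than only the single observed future, which is the intuition behind distilling from an ensemble instead of training on ground truth alone.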