Annealed Winner-Takes-All for Motion Forecasting

Yihong Xu, Victor Letzelter, Mickaël Chen, Éloi Zablocki, Matthieu Cord

TL;DR

This work tackles the instability and mode collapse of Winner-Takes-All (WTA) training in multi-hypothesis motion forecasting. The authors replace the WTA objective with an annealed Winner-Takes-All (aWTA) loss, which weights hypotheses through a temperature-controlled softmin, and obtain diverse future predictions from a small, fixed set of hypotheses without post-selection. Across two large real-world datasets (Argoverse 2 and the Waymo Open Motion Dataset) and two recent trajectory predictors (MTR and Wayformer), aWTA yields consistent improvements in key metrics such as minADE and MissRate, and training exhibits a phase transition as the temperature is annealed. The approach plugs into existing transformer-based forecasting models trained with WTA and removes the need for a large hypothesis set and an inference-time post-selection step. The code is publicly released.

Abstract

In autonomous driving, motion prediction aims at forecasting the future trajectories of nearby agents, helping the ego vehicle to anticipate behaviors and drive safely. A key challenge is generating a diverse set of future predictions, commonly addressed using data-driven models with Multiple Choice Learning (MCL) architectures and Winner-Takes-All (WTA) training objectives. However, these methods face initialization sensitivity and training instabilities. Additionally, to compensate for limited performance, some approaches rely on training with a large set of hypotheses, requiring a post-selection step during inference to significantly reduce the number of predictions. To tackle these issues, we take inspiration from annealed MCL, a recently introduced technique that improves the convergence properties of MCL methods through an annealed Winner-Takes-All loss (aWTA). In this paper, we demonstrate how the aWTA loss can be integrated with state-of-the-art motion forecasting models to enhance their performance using only a minimal set of hypotheses, eliminating the need for the cumbersome post-selection step. Our approach can be easily incorporated into any trajectory prediction model normally trained using WTA and yields significant improvements. To facilitate the application of our approach to future motion forecasting models, the code is made publicly available: https://github.com/valeoai/MF_aWTA.
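
The softmin weighting mentioned above can be illustrated with a minimal sketch. It is not the authors' implementation: the function names (`softmin_weights`, `awta_loss`, `temperature_at`), the exponential decay schedule, and its default parameters `t0` and `decay` are assumptions made for illustration. In gradient-based training the weights would typically be treated as constants (excluded from backpropagation); that detail is also assumed here.

```python
import math


def softmin_weights(losses, temperature):
    """Softmin over per-hypothesis losses: a lower loss gets a larger weight."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    best = min(losses)  # subtract the minimum for numerical stability
    exps = [math.exp(-(loss - best) / temperature) for loss in losses]
    total = sum(exps)
    return [e / total for e in exps]


def awta_loss(losses, temperature):
    """Weighted sum of hypothesis losses (weights assumed to be constants)."""
    weights = softmin_weights(losses, temperature)
    return sum(w * loss for w, loss in zip(weights, losses))


def wta_loss(losses):
    """Plain Winner-Takes-All: only the best hypothesis is penalized."""
    return min(losses)


def temperature_at(step, t0=10.0, decay=0.9):
    """Hypothetical exponential annealing schedule: T(step) = t0 * decay**step."""
    return t0 * decay ** step


# Example: three hypotheses with different losses against the ground truth.
losses = [1.0, 2.0, 4.0]
hot = awta_loss(losses, temperature=1e6)    # near-uniform weights
cold = awta_loss(losses, temperature=1e-3)  # nearly all weight on the winner
```

At a high temperature every hypothesis receives a similar weight, so all of them are pulled toward the data. As the temperature decreases, the weighting concentrates on the best hypothesis and the loss approaches plain WTA.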

Paper Structure (12 sections, 7 equations, 6 figures, 4 tables)

Figures (6)

  • Figure 1: Issues of WTA, and our proposed aWTA. Motion forecasting models trained with WTA and a small number of hypotheses (e.g., 6 hyp.) suffer from mode collapse (a), causing a performance drop. A naive solution to improve performance is to increase the number of training hypotheses (b), but this requires a cumbersome post-selection step. Alternatively, aWTA (c) covers more effective modes while using the same minimal number of hypotheses for both training and inference, discarding the post-selection and achieving better performance, for example, -9.41% minADE and -36.67% MissRate with MTR versus (b) after selection.
  • Figure 2: Evolution of mode distribution during training. The predictions are obtained with Wayformer using WTA (left) or aWTA (right) on the Waymo Open Motion Dataset (WOMD). With aWTA, the effective number of hypotheses increases as training progresses. The ground-truth trajectories are shown in green.
  • Figure 3: Phase transition with aWTA loss. Evolution of (averaged) minADE during training of Wayformer on Argoverse 2, comparing WTA (blue) and aWTA (red) training setups. A sudden drop in error is observed around epoch $12$, consistent with the expected behavior of the deterministic annealing procedure. Here, aWTA converges to a better training fit than WTA.
  • Figure 4: Qualitative comparison between aWTA and WTA variants. Predictions are shown on Argoverse 2 (rows 1 and 2) and WOMD (rows 3 and 4) for MTR (rows 1 and 3) and Wayformer (rows 2 and 4) models. The ground-truth trajectories are shown in green.
  • Figure 5: Impact of the initial temperature value $T_0$ ($x$ axis) on the performance ($y$ axis) for aWTA with two different methods, MTR (in red) and Wayformer (in blue). Results are obtained on the Argoverse 2 validation set. The dotted lines are baselines trained with the default WTA loss.
  • ...and 1 more figure
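
The captions above report minADE and MissRate without defining them. The sketch below shows one common way such metrics are computed for a single agent. It is an illustration under stated assumptions, not the benchmarks' official evaluation code: the function names and the 2.0 m final-displacement threshold are hypothetical choices, and the datasets' actual miss criteria may differ.

```python
import math


def ade(prediction, ground_truth):
    """Average displacement error: mean L2 distance over all time steps."""
    distances = [
        math.hypot(px - gx, py - gy)
        for (px, py), (gx, gy) in zip(prediction, ground_truth)
    ]
    return sum(distances) / len(distances)


def min_ade(hypotheses, ground_truth):
    """minADE: the ADE of the hypothesis closest to the ground truth."""
    return min(ade(h, ground_truth) for h in hypotheses)


def is_miss(hypotheses, ground_truth, threshold=2.0):
    """Assumed miss criterion: no hypothesis ends within `threshold` of the
    ground-truth endpoint (threshold value is an illustrative assumption)."""
    gx, gy = ground_truth[-1]
    best_final_error = min(
        math.hypot(h[-1][0] - gx, h[-1][1] - gy) for h in hypotheses
    )
    return best_final_error > threshold


# Example: two hypotheses for an agent moving along the x axis.
ground_truth = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
hypotheses = [
    [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)],  # drifts away from the ground truth
    [(0.0, 0.0), (1.0, 0.0), (2.0, 1.0)],  # close except at the last step
]
score = min_ade(hypotheses, ground_truth)
```

A dataset-level MissRate would then be the fraction of evaluated agents for which `is_miss` returns `True`. Both metrics reward having at least one good hypothesis, which is why mode collapse hurts them.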