Closed-Loop Supervised Fine-Tuning of Tokenized Traffic Models
Zhejun Zhang, Peter Karkus, Maximilian Igl, Wenhao Ding, Yuxiao Chen, Boris Ivanovic, Marco Pavone
TL;DR
This work tackles covariate shift and multimodality in tokenized, multi-agent traffic policies when moving from open-loop training to closed-loop evaluation. It introduces Closest Among Top-K (CAT-K) rollouts: at each step of a fine-tuning rollout, the policy's $K$ most likely actions are considered, and the one whose resulting state is closest to the ground-truth next state is greedily executed. This enables closed-loop supervised training without reinforcement learning. A two-stage pipeline of behavior cloning (BC) pre-training followed by CAT-K closed-loop fine-tuning yields a compact 7M-parameter SMART policy that outperforms a 102M-parameter baseline and achieves state-of-the-art performance on the Waymo Open Sim Agents Challenge. The approach also improves a Gaussian Mixture Model (GMM) policy for the ego vehicle, suggesting that CAT-K applies to multimodal imitation learning with both discrete-token and continuous action spaces.
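The CAT-K selection rule can be illustrated with a toy, single-agent sketch. It assumes (hypothetically) that each action token is a 2D displacement, that the next state is simply the current state plus that displacement, and that the policy is any callable returning one logit per token. The names `cat_k_select`, `cat_k_rollout`, `TOKENS` and `fixed_policy` are illustrative only and are not the API of the released code; the sketch covers only the rollout step, not the fine-tuning loss or the multi-agent SMART model.

```python
import math
from typing import Callable, List, Sequence, Tuple

Vec = Tuple[float, float]

# Hypothetical token vocabulary: each token is a 2D displacement (dx, dy).
TOKENS: List[Vec] = [(1, 0), (0, 1), (-1, 0), (0, -1), (1, 1)]


def cat_k_select(logits: Sequence[float], tokens: Sequence[Vec],
                 state: Vec, gt_next: Vec, k: int) -> int:
    """Return the token index picked by one Closest-Among-Top-K step.

    Assumption (illustrative only): next state = state + token displacement.
    """
    if not 1 <= k <= len(logits):
        raise ValueError("k must be in [1, len(logits)]")
    # Keep the K most likely tokens, ordered from most to least likely.
    top_k = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)[:k]

    def dist_to_gt(i: int) -> float:
        # Euclidean distance between the predicted and ground-truth next state.
        nx, ny = state[0] + tokens[i][0], state[1] + tokens[i][1]
        return math.hypot(nx - gt_next[0], ny - gt_next[1])

    # min() returns the first minimum, so ties go to the more likely token.
    return min(top_k, key=dist_to_gt)


def cat_k_rollout(policy: Callable[[Vec], Sequence[float]],
                  tokens: Sequence[Vec], gt_traj: Sequence[Vec], k: int) -> List[Vec]:
    """Unroll `policy` in closed loop from gt_traj[0], choosing actions with CAT-K."""
    state = gt_traj[0]
    rollout = [state]
    for t in range(1, len(gt_traj)):
        # The policy is queried on its own rolled-out state, not the logged one.
        i = cat_k_select(policy(state), tokens, state, gt_traj[t], k)
        state = (state[0] + tokens[i][0], state[1] + tokens[i][1])
        rollout.append(state)
    return rollout


def fixed_policy(state: Vec) -> List[float]:
    # Hypothetical stand-in policy: the same logits for every state.
    return [2.0, 1.5, 1.0, 0.5, 0.1]


if __name__ == "__main__":
    gt = [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
    print(cat_k_rollout(fixed_policy, TOKENS, gt, k=1))  # [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
    print(cat_k_rollout(fixed_policy, TOKENS, gt, k=5))  # [(0.0, 0.0), (1.0, 1.0), (2.0, 2.0)]
```

In this sketch, `k=1` reduces to always executing the most likely token, so the rollout drifts away from the ground truth, while `k` equal to the vocabulary size lets it follow the ground truth as closely as the tokens allow. Intermediate values keep the rollout near the data while restricting it to actions the policy itself rates as likely.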
Abstract
Traffic simulation aims to learn a policy for traffic agents that, when unrolled in closed loop, faithfully recovers the joint distribution of trajectories observed in the real world. Inspired by large language models, tokenized multi-agent policies have recently become the state-of-the-art in traffic simulation. However, they are typically trained through open-loop behavior cloning, and thus suffer from covariate shift when executed in closed loop during simulation. In this work, we present Closest Among Top-K (CAT-K) rollouts, a simple yet effective closed-loop fine-tuning strategy to mitigate covariate shift. CAT-K fine-tuning only requires existing trajectory data, without reinforcement learning or generative adversarial imitation. Notably, CAT-K fine-tuning enables a small 7M-parameter tokenized traffic simulation policy to outperform a 102M-parameter model from the same model family, achieving the top spot on the Waymo Open Sim Agents Challenge leaderboard at the time of submission. The code is available at https://github.com/NVlabs/catk.
