Enhancing next token prediction based pre-training for jet foundation models
Joschka Birk, Anna Hallin, Gregor Kasieczka, Nikol Madzharova, Ian Pang, David Shih
TL;DR
This work addresses the limitation of token-based inputs in OmniJet-α by introducing a hybrid continuous-input transformer that uses continuous per-particle features as input for both generation and classification, while keeping token IDs as the targets for next token prediction (NTP). It further extends pre-training with masked particle modeling (MPM) and a joint NTP+MPM objective, showing that MPM-based strategies yield superior downstream jet classification without harming generative fidelity. Experiments on the JetClass and top tagging datasets demonstrate that continuous inputs significantly improve classification, and that joint NTP+MPM pre-training achieves classification performance close to MPM-only pre-training while retaining robust generation. The findings suggest a promising path toward data-efficient, simulation-free jet foundation models that balance generative and discriminative tasks, and point toward exploring alternative tokenizations and physical priors.
Abstract
Next token prediction is an attractive pre-training task for jet foundation models, in that it is simulation-free and enables excellent generative capabilities that can transfer across datasets. Here we study multiple improvements to next token prediction, building on the initial work of OmniJet-α. First, instead of tokenizing particles and subsequently using only the token ID as the model input for both the generative and the classification task, we adopt a hybrid setup, which allows us to use continuous feature vectors as model input while using token IDs only in the next token prediction target. Second, we explore a pre-training strategy that combines masked particle modeling and generative learning objectives. Taken together, these changes greatly improve the performance in downstream classification tasks without any loss in generative performance.
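
The hybrid setup and the combined objective described above can be illustrated with a minimal pure-Python sketch of how inputs and targets might be paired. It is not the authors' implementation: the toy feature values, the token IDs, the helper names (`ntp_pairs`, `mpm_pairs`, `joint_loss`), the use of `None` as a stand-in for a mask embedding, and the weighted-sum form of the joint loss are all assumptions made for illustration, and start/stop tokens are omitted.

```python
import math
import random

# Hypothetical toy jet: continuous per-particle features (pT, eta, phi)
# and the token ID some tokenizer assigned to each particle (made-up values).
features = [(50.2, 0.10, -0.30), (31.7, -0.05, 0.12), (12.4, 0.22, 0.05)]
token_ids = [17, 4, 42]

def ntp_pairs(features, token_ids):
    """Next token prediction: the input at step i is the continuous feature
    vector of particle i, the target is the token ID of particle i + 1."""
    return [(features[i], token_ids[i + 1]) for i in range(len(features) - 1)]

def mpm_pairs(features, token_ids, mask_prob, rng):
    """Masked particle modeling: hide the features of randomly chosen
    particles; the target is the token ID of each hidden particle."""
    masked_inputs, targets = [], []
    for feat, tok in zip(features, token_ids):
        if rng.random() < mask_prob:
            masked_inputs.append(None)  # assumed stand-in for a mask embedding
            targets.append(tok)
        else:
            masked_inputs.append(feat)
            targets.append(None)        # no loss on unmasked positions
    return masked_inputs, targets

def cross_entropy(logits, target):
    """Numerically stable cross-entropy of one logit vector vs. a token ID."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def joint_loss(ntp_loss, mpm_loss, weight=0.5):
    """Assumed form of a joint objective: a weighted sum of the two losses."""
    return weight * ntp_loss + (1.0 - weight) * mpm_loss

if __name__ == "__main__":
    rng = random.Random(0)
    print(ntp_pairs(features, token_ids))
    print(mpm_pairs(features, token_ids, mask_prob=0.4, rng=rng))
```

In both cases the model only ever sees continuous vectors on the input side, while the classification-style loss is computed over the discrete token vocabulary, which is the property the hybrid setup relies on.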
