Enhancing next token prediction based pre-training for jet foundation models

Joschka Birk, Anna Hallin, Gregor Kasieczka, Nikol Madzharova, Ian Pang, David Shih

TL;DR

This work addresses the limitation of token-based inputs in OmniJet-α by introducing a hybrid continuous-input transformer that uses continuous per-particle features as input for both generation and classification, while keeping token-IDs as the targets for next token prediction (NTP). It further extends pre-training with masked particle modeling (MPM) and a joint NTP+MPM objective, showing that MPM-based strategies yield better downstream jet classification without harming generative fidelity. Experiments on the JetClass and top tagging datasets show that continuous inputs significantly improve classification, and that joint NTP+MPM pre-training achieves near-MPM classification performance while retaining robust generation. The findings suggest a promising path toward data-efficient, simulation-free jet foundation models that balance generative and discriminative tasks, and motivate exploring alternative tokenizations and physical priors.

Abstract

Next token prediction is an attractive pre-training task for jet foundation models, in that it is simulation free and enables excellent generative capabilities that can transfer across datasets. Here we study multiple improvements to next token prediction, building on the initial work of OmniJet-$\alpha$. First, instead of tokenizing particles and subsequently only using the token-ID as the model input for both the generative and the classification task, we adopt a hybrid setup, which allows us to use continuous feature vectors as model input while only using token-IDs in the next token prediction target. Second, we explore a pre-training strategy that combines masked particle modeling and generative learning objectives. Taken together, these changes greatly improve the performance in downstream classification tasks without any loss in generative performance.
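
The hybrid setup described in the abstract can be illustrated with a minimal NumPy sketch: the model input is a sequence of continuous feature vectors, while the training target at each position is the token-ID of the next particle. Everything below is an assumption made for illustration and is not taken from the paper: the array sizes, the nearest-neighbour codebook used as a stand-in tokenizer, the zero start embedding, and the random linear map used in place of a causal transformer backbone.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical sizes, chosen only for this illustration.
N_PARTICLES, N_FEATURES, VOCAB = 5, 3, 16
END_ID = VOCAB  # extra class index reserved for the end token t_e

# Continuous per-particle feature vectors c_i (illustrative random values).
c = rng.normal(size=(N_PARTICLES, N_FEATURES))

# Stand-in tokenizer (assumption): nearest entry in a random codebook.
codebook = rng.normal(size=(VOCAB, N_FEATURES))
dists = np.linalg.norm(c[:, None, :] - codebook[None, :, :], axis=-1)
t = dists.argmin(axis=1)  # token-IDs t_i, shape (N_PARTICLES,)

# Model input: a start embedding followed by the continuous vectors.
# In a real model the start embedding would be trainable; zeros are a placeholder.
start = np.zeros((1, N_FEATURES))
inputs = np.concatenate([start, c], axis=0)  # shape (N_PARTICLES + 1, N_FEATURES)

# Targets: token-ID of the next particle at each position, then the end token.
targets = np.concatenate([t, [END_ID]])  # shape (N_PARTICLES + 1,)


def cross_entropy(logits, labels):
    """Mean cross-entropy of integer labels under softmax(logits)."""
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


# Placeholder "backbone": a random linear map from features to logits.
# A real model would use a causal transformer over the whole input sequence.
W = rng.normal(size=(N_FEATURES, VOCAB + 1))
logits = inputs @ W
loss = cross_entropy(logits, targets)
```

The point of the sketch is the pairing of `inputs` and `targets`: position 0 sees the start embedding and is trained to predict $t_1$, position $i$ sees $\vec{c}_i$ and is trained to predict $t_{i+1}$, and the last position is trained to predict the end token.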

Paper Structure

This paper contains 21 sections, 12 figures, 2 tables.

Figures (12)

  • Figure 1: The high-level architecture of the OmniJet workflow: (a) the pre-training based on next token prediction, (b) the generative model obtained from unsupervised pre-training and (c) the classification model obtained from supervised fine-tuning. Our new approach uses continuous feature vectors as input, both for the generative and classification tasks. Token-IDs are shown as $t_i$, with $t_s$ corresponding to a start token and $t_e$ corresponding to an end token. The continuous feature vectors are shown as $\vec{c}_i$, and the pseudo-continuous feature vectors (i.e. the decoded token-IDs) are shown as $\vec{d}_i$. In the continuous input case, the start token is a trainable embedding $\vec{d}_s$.
  • Figure 2: Schematic overview of the different pre-training strategies: (a) next token prediction (NTP), (b) masked token prediction (MPM) [Golling:2024abg, Leigh:2024ked], and (c) joint NTP and MPM. Token-IDs are shown as $t_i$, continuous feature vectors as $\vec{c}_i$, pseudo-continuous feature vectors as $\vec{d}_i$, the mask embeddings as $\vec{m}_i$. The vectors $\vec{e}_i$ represent the backbone output and $\vec{n}_i$ and $\vec{y}_i$ represent the next and masked token predictions, respectively. The particles are not $p_\mathrm{T}$-sorted, but with positional encoding the information of the $p_\mathrm{T}$-order within the masked subset is recovered: in this example, the first masked particle (position 1) is assumed to have smaller $p_\mathrm{T}$ than the second masked particle (position 4), which is why they are re-introduced as $\vec{m}_2$ and $\vec{m}_1$ respectively when using positional encoding.
  • Figure 3: Comparison of the jets that are generated by the token-ID-input and the continuous-input model, as well as the jets from the JetClass dataset on particle-level (top row) and jet-level (bottom row). The first and last bins show the under- and overflow bins, respectively. The jets shown for the JetClass dataset are tokenized and subsequently decoded, representing the target of the generative model.
  • Figure 4: Comparison of token-ID input vs. continuous input classification performance: (a) for multi-class classification performance on the JetClass dataset (all 10 jet types, in-distribution transfer learning) and (b) for binary classification performance on the top tagging (out-of-distribution transfer learning) as a function of the number of training jets.
  • Figure 5: Classification performance on the top tagging dataset with different pre-training strategies: (a) the classifier accuracy and (b) the background ($q/g$-jet) rejection at 30 % signal (top-jet) efficiency. All models use continuous feature vectors as input.
  • ...and 7 more figures
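
The background rejection at 30 % signal efficiency quoted in the caption of Figure 5 can be computed from classifier scores roughly as in the sketch below. It assumes the common convention that rejection is the inverse of the background false positive rate at the score threshold that keeps the requested fraction of signal jets. The function name `background_rejection` and the toy scores are hypothetical and do not come from the paper.

```python
import numpy as np


def background_rejection(scores, labels, signal_eff=0.3):
    """Return 1 / (background false positive rate) at a fixed signal efficiency.

    scores: classifier outputs, higher means more signal-like.
    labels: 1 for signal (top jets), 0 for background (q/g jets).
    """
    scores = np.asarray(scores, dtype=float)
    labels = np.asarray(labels)
    sig = scores[labels == 1]
    bkg = scores[labels == 0]
    # Threshold above which roughly `signal_eff` of the signal jets lie.
    threshold = np.quantile(sig, 1.0 - signal_eff)
    fpr = np.mean(bkg >= threshold)
    # No background passes the cut: rejection is unbounded on this sample.
    return np.inf if fpr == 0 else 1.0 / fpr


# Toy example with hand-picked scores (not real data).
sig_scores = np.arange(1, 11) / 10  # 0.1, 0.2, ..., 1.0
bkg_scores = np.array([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.75, 0.9])
scores = np.concatenate([sig_scores, bkg_scores])
labels = np.concatenate([np.ones(10, dtype=int), np.zeros(10, dtype=int)])
rejection = background_rejection(scores, labels, signal_eff=0.3)
```

In the toy example the threshold falls between 0.7 and 0.8, so three of the ten signal jets and two of the ten background jets pass, giving a rejection of about 5. On finite samples the achieved signal efficiency only approximates the requested value, since the threshold is interpolated between neighbouring scores.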