Jasmine: A Simple, Performant and Scalable JAX-based World Modeling Codebase

Mihir Mahajan, Alfred Nguyen, Franz Srambical, Stefan Bauer

TL;DR

Jasmine addresses data scarcity in robotics-style domains by providing a scalable, reproducible world-modeling pipeline built on JAX that trains interactive environments from unlabeled videos. The approach follows the Genie family: a video tokenizer, a latent-action model, and a dynamics transformer. One minimal modification, prepending latent actions instead of adding them, proved essential for faithful CoinRun generations. Through architectural optimizations (FFN expansion, shallower networks) and infrastructure optimizations (Grain data loader, ArrayRecord, mixed precision, FlashAttention), Jasmine converges an order of magnitude faster in wall-clock time than prior open implementations, and it provides bitwise-deterministic training and scalable sharding. The work also contributes an openly released repository, pretrained checkpoints, curated datasets, and an IDE-interaction dataset to support rigorous benchmarking and reproducibility in world-model research.
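The difference between adding and prepending latent actions can be illustrated with a small JAX sketch. This is not taken from the Jasmine repository: the function names, the tensor shapes, and the choice to broadcast the action over every patch token of a frame are assumptions made for illustration only.

```python
import jax
import jax.numpy as jnp


def add_actions(frame_tokens, action_emb):
    """'Adding' variant (assumed form): sum the per-frame action embedding
    into every patch token of that frame. Sequence length is unchanged."""
    # frame_tokens: (T, N, D), action_emb: (T, D)
    return frame_tokens + action_emb[:, None, :]


def prepend_actions(frame_tokens, action_emb):
    """'Prepending' variant (assumed form): insert the action embedding as one
    extra token at the front of each frame. Sequence length grows by one."""
    return jnp.concatenate([action_emb[:, None, :], frame_tokens], axis=1)


# Hypothetical sizes: T frames, N patch tokens per frame, D model width.
T, N, D = 4, 16, 8
key_tokens, key_actions = jax.random.split(jax.random.PRNGKey(0))
frame_tokens = jax.random.normal(key_tokens, (T, N, D))
action_emb = jax.random.normal(key_actions, (T, D))

added = add_actions(frame_tokens, action_emb)          # shape (T, N, D)
prepended = prepend_actions(frame_tokens, action_emb)  # shape (T, N + 1, D)
```

In the prepended form the action occupies its own token position, so attention layers can read it separately from the patch content; in the added form the action signal is mixed into every patch embedding. Which frame each action conditions (for example, the action between frame t and frame t+1) is left out of this sketch.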

Abstract

While world models are increasingly positioned as a pathway to overcoming data scarcity in domains such as robotics, open training infrastructure for world modeling remains nascent. We introduce Jasmine, a performant JAX-based world modeling codebase that scales from single hosts to hundreds of accelerators with minimal code changes. Jasmine achieves an order-of-magnitude faster reproduction of the CoinRun case study compared to prior open implementations, enabled by performance optimizations across data loading, training and checkpointing. The codebase guarantees fully reproducible training and supports diverse sharding configurations. By pairing Jasmine with curated large-scale datasets, we establish infrastructure for rigorous benchmarking pipelines across model families and architectural ablations.
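As a rough illustration of how JAX code can move from one device to many with few changes, the sketch below shards a batch along a named mesh axis using the public `jax.sharding` API. It is a generic data-parallel example, not Jasmine's actual sharding setup: the axis name `"data"`, the array shapes, and the `forward` function are hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D device mesh over whatever devices are available (one CPU works too).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the batch dimension across the "data" axis; replicate the weights.
batch_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

per_device_batch = 2  # hypothetical value
x = jnp.ones((per_device_batch * devices.size, 8))
w = jnp.ones((8, 4))
x = jax.device_put(x, batch_sharding)
w = jax.device_put(w, replicated)


@jax.jit
def forward(x, w):
    # The compiler propagates the input shardings through the computation.
    return x @ w


y = forward(x, w)  # shape (per_device_batch * num_devices, 4)
```

The same script runs unchanged on a single device or on a multi-device host; only the mesh construction reflects the hardware.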

Paper Structure

This paper contains 20 sections, 12 figures, 6 tables.

Figures (12)

  • Figure 1: Autoregressive sampling of Jafar [willi2024jafar] (middle row) and Jasmine (bottom row) on the CoinRun case study with four conditioning frames (conditioning frames not shown). The top row shows the ground-truth sequence.
  • Figure 2: Validation metrics of the CoinRun case study (patch size 4). While the loss (left) is similar between the default Genie configuration and our minimal modification, rollout metrics (middle and right, refer to \ref{sec:experiment-metrics}) differ substantially.
  • Figure 3: Jasmine (blue) converges an order of magnitude faster in wall-clock time than Jafar [willi2024jafar] (orange). We report the train loss since Jafar does not collect validation metrics. Refer to Appendix \ref{sec:arch-ablations} for Jasmine's validation metrics. Jasmine's lower variance stems from a subtle refinement in its batched masking logic (Appendix \ref{sec:jafar-batched-masking}).
  • Figure 4: Autoregressive sampling of Jasmine when adding (middle row) and prepending actions (bottom row) on the CoinRun case study with four conditioning frames (conditioning frames not shown). The top row shows the ground-truth sequence.
  • Figure 5: Architectural ablations of Jasmine's base configuration (refer to \ref{tab:hparams}) on CoinRun. We report loss (left) and rollout metrics (middle and right) of the dynamics model on a validation set.
  • ...and 7 more figures