Jasmine: A Simple, Performant and Scalable JAX-based World Modeling Codebase
Mihir Mahajan, Alfred Nguyen, Franz Srambical, Stefan Bauer
TL;DR
Jasmine addresses data scarcity in domains such as robotics by providing a scalable, reproducible world-modeling pipeline built on JAX that trains interactive environments from unlabeled videos. The pipeline follows the Genie architecture, combining a video tokenizer, a latent-action model, and a dynamics transformer, with one minimal modification, prepending the latent actions, that was found to be essential for faithful CoinRun generations. Through architectural changes (FFN expansion, shallower networks) and infrastructure optimizations (Grain data loader, ArrayRecord, mixed precision, FlashAttention), Jasmine reproduces the CoinRun case study an order of magnitude faster in wall-clock time than prior open implementations, while providing bitwise-deterministic training and scalable sharding. The work also releases the repository, pretrained checkpoints, curated datasets, and an IDE-interaction dataset to support rigorous benchmarking and reproducibility in world-model research.
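To make the latent-action prepend concrete, the following is a minimal JAX sketch, not Jasmine's actual code. It assumes that "prepend" means concatenating one latent-action token in front of each frame's patch tokens along the token axis. The function names, tensor shapes, and the additive variant shown for contrast are all hypothetical and are not taken from the source.

```python
import jax
import jax.numpy as jnp


def prepend_latent_action(patch_tokens, action_embedding):
    """Prepend one latent-action token to each frame's patch tokens.

    Hypothetical shapes (assumptions, not from the source):
      patch_tokens:     (batch, time, num_patches, dim)
      action_embedding: (batch, time, dim)
    Returns an array of shape (batch, time, num_patches + 1, dim).
    """
    action_token = action_embedding[:, :, None, :]  # add a token axis
    return jnp.concatenate([action_token, patch_tokens], axis=2)


def add_latent_action(patch_tokens, action_embedding):
    """Additive conditioning, shown only for contrast (an assumption).

    The action embedding is broadcast over all patches, so the
    sequence length is unchanged.
    """
    return patch_tokens + action_embedding[:, :, None, :]


key_tokens, key_actions = jax.random.split(jax.random.PRNGKey(0))
patch_tokens = jax.random.normal(key_tokens, (2, 4, 16, 8))
action_embedding = jax.random.normal(key_actions, (2, 4, 8))

prepended = prepend_latent_action(patch_tokens, action_embedding)
added = add_latent_action(patch_tokens, action_embedding)

print(prepended.shape)  # (2, 4, 17, 8)
print(added.shape)      # (2, 4, 16, 8)
```

Under this reading, prepending keeps the action as a separate token that the dynamics transformer can attend to, whereas the additive variant mixes the action into every patch embedding.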
Abstract
While world models are increasingly positioned as a pathway to overcoming data scarcity in domains such as robotics, open training infrastructure for world modeling remains nascent. We introduce Jasmine, a performant JAX-based world modeling codebase that scales from single hosts to hundreds of accelerators with minimal code changes. Jasmine achieves an order-of-magnitude faster reproduction of the CoinRun case study compared to prior open implementations, enabled by performance optimizations across data loading, training and checkpointing. The codebase guarantees fully reproducible training and supports diverse sharding configurations. By pairing Jasmine with curated large-scale datasets, we establish infrastructure for rigorous benchmarking pipelines across model families and architectural ablations.
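The sharding and reproducibility claims above can be illustrated with a small JAX sketch. It is not Jasmine's training loop: the toy linear model, the `run` helper, the mesh axis name `"data"`, and all shapes are assumptions made for illustration. The sketch builds a one-dimensional device mesh over whatever devices are visible, shards the batch across it while replicating the parameters, and compares two runs from the same seed.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all visible devices; also works with a single CPU device.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

batch_sharding = NamedSharding(mesh, P("data"))  # split the leading batch axis
replicated = NamedSharding(mesh, P())            # full copy on every device


def loss_fn(params, batch):
    # Toy objective standing in for a real model loss (assumption).
    predictions = batch @ params
    return jnp.mean(predictions ** 2)


@jax.jit
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    return params - 0.1 * grads, loss


def run(seed, num_steps=3):
    """Run a few SGD steps from a fixed seed and return the final state."""
    key_params, key_data = jax.random.split(jax.random.PRNGKey(seed))
    params = jax.device_put(jax.random.normal(key_params, (8, 4)), replicated)
    batch_size = 4 * len(devices)  # keep the batch divisible by the mesh size
    batch = jax.device_put(
        jax.random.normal(key_data, (batch_size, 8)), batch_sharding
    )
    for _ in range(num_steps):
        params, loss = train_step(params, batch)
    return np.asarray(params), float(loss)


params_a, loss_a = run(seed=0)
params_b, loss_b = run(seed=0)

# Two runs with the same seed on the same hardware are compared bit for bit.
print(np.array_equal(params_a, params_b), loss_a == loss_b)
```

In this sketch, moving from one device to many changes only the mesh and the sharding annotations, while `train_step` itself stays untouched. That is the general JAX pattern behind scaling "with minimal code changes"; how Jasmine configures its own meshes is not specified in this section.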
