
UniZero: Generalized and Efficient Planning with Scalable Latent World Models

Yuan Pu, Yazhe Niu, Zhenjie Yang, Jiyuan Ren, Hongsheng Li, Yu Liu

TL;DR

MuZero-style planning excels in short-horizon tasks but struggles with long-term memory and multitask scalability. UniZero introduces a modular, transformer-based latent world model that disentangles the latent state from the implicit latent history and jointly optimizes the world model and the policy, enabling efficient planning in the latent space with the transformer's key-value (KV) cache serving as memory. Across VisualMatch, multitask Atari, Atari 100K, and DMControl, UniZero performs strongly on both memory-dependent and standard tasks, often outperforming strong baselines and scaling well in multitask settings. Ablation studies confirm that SimNorm is important for training stability, while decoding targets offer limited gains; this supports a design that prioritizes decision-relevant latent representations and scalable planning.
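The TL;DR credits SimNorm with training stability but does not define it. Assuming UniZero uses the simplicial normalization popularized by TD-MPC2, SimNorm splits a latent vector into fixed-size groups and applies a softmax within each group, so every group lies on a probability simplex and the latent's scale stays bounded. The pure-Python sketch below illustrates that operation only; the function name `simnorm`, the default group size of 8, and the `temperature` knob are illustrative assumptions, not values taken from the paper.

```python
import math
from typing import List


def simnorm(z: List[float], group_size: int = 8, temperature: float = 1.0) -> List[float]:
    """Hypothetical SimNorm sketch: an independent softmax over each consecutive group."""
    if group_size <= 0 or len(z) % group_size != 0:
        raise ValueError("len(z) must be a positive multiple of group_size")
    out: List[float] = []
    for start in range(0, len(z), group_size):
        group = [x / temperature for x in z[start:start + group_size]]
        top = max(group)  # subtract the max for a numerically stable softmax
        exps = [math.exp(x - top) for x in group]
        total = sum(exps)
        out.extend(e / total for e in exps)
    return out


if __name__ == "__main__":
    # An unbounded latent (note the outlier 100.0) split into two groups of 4.
    z = [3.0, -1.0, 0.5, 100.0, 2.0, 2.0, 2.0, 2.0]
    normed = simnorm(z, group_size=4)
    print([round(x, 3) for x in normed])  # each group sums to 1; every entry lies in [0, 1]
```

Because each group becomes a probability vector, a single large activation (the 100.0 above) cannot inflate the latent's overall magnitude; it only saturates its own group.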

Abstract

Learning predictive world models is crucial for enhancing the planning capabilities of reinforcement learning (RL) agents. Recently, MuZero-style algorithms, leveraging the value equivalence principle and Monte Carlo Tree Search (MCTS), have achieved superhuman performance in various domains. However, these methods struggle to scale in heterogeneous scenarios with diverse dependencies and task variability. To overcome these limitations, we introduce UniZero, a novel approach that employs a modular transformer-based world model to effectively learn a shared latent space. By concurrently predicting latent dynamics and decision-oriented quantities conditioned on the learned latent history, UniZero enables joint optimization of the long-horizon world model and policy, facilitating broader and more efficient planning in the latent space. We show that UniZero significantly outperforms existing baselines in benchmarks that require long-term memory. Additionally, UniZero demonstrates superior scalability in multitask learning experiments conducted on Atari benchmarks. In standard single-task RL settings, such as Atari and DMControl, UniZero matches or even surpasses the performance of current state-of-the-art methods. Finally, extensive ablation studies and visual analyses validate the effectiveness and scalability of UniZero's design choices. Our code is available at https://github.com/opendilab/LightZero.
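The abstract attributes much of MuZero-style performance to MCTS. As background, the pure-Python sketch below shows the pUCT rule published for MuZero for choosing which action to explore at a tree node. The constants c1 = 1.25 and c2 = 19652 are MuZero's; whether UniZero keeps them is an assumption this section does not confirm. `EdgeStats`, `puct_score`, and `select_action` are hypothetical names, and the sketch omits MuZero's min-max normalization of Q and its reward-plus-discounted-value backup, assuming Q values are already on a comparable scale.

```python
import math
from dataclasses import dataclass
from typing import Dict, Hashable


@dataclass
class EdgeStats:
    """Hypothetical per-action statistics stored at a search-tree node."""
    prior: float            # P(s, a) from the policy head
    visit_count: int = 0    # N(s, a)
    value_sum: float = 0.0  # sum of backed-up values through this edge

    def q(self) -> float:
        # Mean backed-up value; unvisited edges score 0.0 in this sketch.
        return self.value_sum / self.visit_count if self.visit_count else 0.0


def puct_score(edge: EdgeStats, parent_visits: int,
               c1: float = 1.25, c2: float = 19652.0) -> float:
    """MuZero-style pUCT: Q(s, a) plus a prior-weighted exploration bonus."""
    bonus = edge.prior * math.sqrt(parent_visits) / (1 + edge.visit_count)
    bonus *= c1 + math.log((parent_visits + c2 + 1) / c2)
    return edge.q() + bonus


def select_action(children: Dict[Hashable, EdgeStats]) -> Hashable:
    """Pick the child action with the highest pUCT score."""
    parent_visits = sum(edge.visit_count for edge in children.values())
    return max(children, key=lambda a: puct_score(children[a], parent_visits))


if __name__ == "__main__":
    children = {
        0: EdgeStats(prior=0.6, visit_count=10, value_sum=4.0),  # well explored, Q = 0.4
        1: EdgeStats(prior=0.4),                                 # never visited
    }
    print(select_action(children))  # 1: the exploration bonus favors the unvisited action
```

The bonus shrinks as an action's own visit count grows, so the search gradually shifts from following the policy prior toward exploiting actions whose backed-up values are high.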

Paper Structure (61 sections, 11 equations, 22 figures, 11 tables, 1 algorithm)

Figures (22)

  • Figure 1: Comparison between the UniZero (Ours) and MuZero-style architectures during training and inference. Left: In the MuZero-style architecture, the recursively unrolled latent representation $s^k_{t}$ is tightly entangled with historical information. During training, it uses only the initial observation of each sequence, leaving the remaining information under-utilized (under-utilization). During inference, the recursively predicted latent representation $s^k_{t-k}$ (with $k=2$ for clarity) serves as the root node in MCTS and is prone to inaccuracies from accumulated prediction errors. These issues are particularly pronounced in tasks that require modeling long-term dependencies. Right: UniZero employs a modular latent world model comprising an encoder, a unified transformer backbone, and decision/dynamics heads. This design explicitly disentangles latent states from the implicit latent history and uses all observations during training (full-utilization). During inference, the directly encoded latent state $z_t$ serves as the root node. By conditioning on a more complete and accessible context $M = (z_{t-H_\text{infer}}, a_{t-H_\text{infer}}, \dots, z_t, a_t)$, UniZero improves prediction accuracy and enables more effective long-term planning in the latent space.
  • Figure 2: Performance comparison of UniZero and MuZero variants in Pong under approximate MDP and POMDP settings. Left: results in the MDP setting. Right: results in the POMDP setting. UniZero consistently outperforms all baselines in both settings, highlighting its robustness and adaptability. MuZero w/ SSL achieves superior sample efficiency in the MDP setting but fails to converge in the POMDP setting due to representation entanglement. Both MuZero w/ Context and UniZero (RNN) show limited performance in both settings, primarily because incomplete context representations lead to prediction errors.
  • Figure 3: MCTS in the learned latent space. The process begins with a new observation $o_1$, which is encoded into a latent state $z_1$ that serves as the root node. The keys and values of the recent memory are retrieved from the transformer's KV cache $KV_M$. Conditioned on the retrieved KV, the search tree then uses the world model to recursively predict the next latent state $\hat{z}$ (an internal node), the reward $\hat{r}$, the policy $p$, and the value $v$. MCTS uses these predictions to produce an improved policy $\pi$.
  • Figure 4: Performance comparison on VisualMatch with increasing memory lengths. MuZero consistently underperforms across all tasks, primarily due to insufficient context information. The performance of SAC-GPT deteriorates significantly as the memory length increases. In contrast, UniZero maintains a high success rate even with extended memory lengths, demonstrating its superior capacity for modeling long-term dependencies.
  • Figure 5: Performance on Atari 100K. UniZero achieves a higher human-normalized median score than MuZero (Reproduced), demonstrating that it also models short-term dependencies effectively. Detailed scores and curves are available in the appendix.
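Figures 1 and 3 describe planning from a directly encoded root $z_t$ while reusing cached keys and values for the recent context $M$. The sketch below is a minimal illustration of that caching idea, not UniZero's implementation: it assumes a single attention head, a single layer, and no positional encoding, and it treats each latent state and action as one generic embedding token. `KVCache`, `recompute_last`, and `max_len` (a token-count stand-in for $H_\text{infer}$, counting tokens rather than $(z, a)$ pairs) are hypothetical names. It checks that attending over cached keys and values gives the same output as recomputing attention over the whole window, and `fork()` shows one way a search branch could extend the context without overwriting the stored memory.

```python
import math
import random
from typing import List

Vector = List[float]
Matrix = List[List[float]]


def matvec(m: Matrix, x: Vector) -> Vector:
    """Multiply a (rows x d) matrix by a length-d vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def attend(q: Vector, keys: List[Vector], values: List[Vector]) -> Vector:
    """Scaled dot-product attention of a single query over the given keys/values."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
    top = max(scores)  # subtract the max score for a numerically stable softmax
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    return [sum(w * v[i] for w, v in zip(weights, values)) / total
            for i in range(len(values[0]))]


class KVCache:
    """Hypothetical single-head, single-layer KV cache over the last `max_len` tokens."""

    def __init__(self, wq: Matrix, wk: Matrix, wv: Matrix, max_len: int) -> None:
        self.wq, self.wk, self.wv = wq, wk, wv
        self.max_len = max_len
        self.keys: List[Vector] = []
        self.values: List[Vector] = []

    def append(self, token: Vector) -> Vector:
        """Cache the new token's key/value and return its attention output."""
        self.keys.append(matvec(self.wk, token))
        self.values.append(matvec(self.wv, token))
        if len(self.keys) > self.max_len:  # sliding window: evict the oldest token
            self.keys.pop(0)
            self.values.pop(0)
        return attend(matvec(self.wq, token), self.keys, self.values)

    def fork(self) -> "KVCache":
        """Copy the cache so a search branch can extend it without mutating the memory."""
        child = KVCache(self.wq, self.wk, self.wv, self.max_len)
        child.keys, child.values = list(self.keys), list(self.values)
        return child


def recompute_last(tokens: List[Vector], wq: Matrix, wk: Matrix, wv: Matrix,
                   max_len: int) -> Vector:
    """Reference: recompute attention for the last token over the whole window."""
    window = tokens[-max_len:]
    keys = [matvec(wk, t) for t in window]
    values = [matvec(wv, t) for t in window]
    return attend(matvec(wq, window[-1]), keys, values)


if __name__ == "__main__":
    rng = random.Random(0)
    d, max_len = 4, 3
    wq, wk, wv = ([[rng.gauss(0, 1) for _ in range(d)] for _ in range(d)] for _ in range(3))
    tokens = [[rng.gauss(0, 1) for _ in range(d)] for _ in range(6)]  # stand-ins for z/a embeddings
    memory = KVCache(wq, wk, wv, max_len)
    for i, tok in enumerate(tokens):
        cached = memory.append(tok)
        full = recompute_last(tokens[: i + 1], wq, wk, wv, max_len)
        assert all(abs(a - b) < 1e-9 for a, b in zip(cached, full))
    print("cached attention matches full recomputation over the window")
```

In this single-layer setting, evicting the oldest entry is exact. In a deeper transformer, cached keys at upper layers already depend on tokens that have since been evicted, so a sliding-window cache there is an approximation rather than an exact equivalent.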
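Figure 5 summarizes Atari 100K results as a human-normalized median score. Under the standard definition (assumed here, since this section does not spell it out), each game's score is rescaled so that a random policy maps to 0 and the human reference maps to 1, and the median is then taken across games. The function names, game names, and reference scores below are hypothetical placeholders, not values from the paper or any benchmark table.

```python
from statistics import median
from typing import Dict, Tuple


def human_normalized_score(score: float, random_score: float, human_score: float) -> float:
    """Per-game HNS: 0.0 matches a random policy, 1.0 matches the human reference."""
    if human_score == random_score:
        raise ValueError("human and random reference scores must differ")
    return (score - random_score) / (human_score - random_score)


def median_hns(agent_scores: Dict[str, float],
               references: Dict[str, Tuple[float, float]]) -> float:
    """Median of per-game HNS; `references` maps game -> (random_score, human_score)."""
    per_game = [human_normalized_score(agent_scores[g], *references[g]) for g in agent_scores]
    return median(per_game)


if __name__ == "__main__":
    # Hypothetical games and reference scores, for illustration only.
    references = {"GameA": (0.0, 100.0), "GameB": (10.0, 20.0), "GameC": (-5.0, 5.0)}
    agent = {"GameA": 50.0, "GameB": 30.0, "GameC": -5.0}
    print(median_hns(agent, references))  # 0.5 (per-game HNS: 0.5, 2.0, 0.0)
```

The median is a common choice because a single game with a very large normalized score (GameB's 2.0 above) barely moves it, whereas it would dominate a mean.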