Multi-Agent Model-Based Reinforcement Learning with Joint State-Action Learned Embeddings
Zhizun Wang, David Meger
TL;DR
This work tackles sample-efficient coordination in partially observable multi-agent environments by combining model-based reinforcement learning with joint state-action representation learning. The proposed MMSA framework uses a world model augmented with the state-action learned embedding (SALE) to generate latent rollouts, and it also feeds SALE into the policy and the per-agent value networks, all under the centralized training and decentralized execution (CTDE) paradigm. A QMIX-style mixing network aggregates the individual Q-values into a joint action value, so that imagined rollouts inform joint decisions. A unified loss combines KL regularization (with KL balancing), reconstruction, temporal-difference (TD) learning, and SALE prediction. Across the SMAC, MAMuJoCo, and Level-Based Foraging benchmarks, MMSA achieves consistent performance gains over baseline methods, and ablation studies confirm that each of its components is needed, indicating strong potential for scalable, data-efficient multi-agent planning in complex domains.
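The unified objective can be written schematically as below. This is a sketch under stated assumptions, not the paper's exact loss: the weights β and λ, the RSSM-style notation (deterministic state h_t, stochastic latent z_t, posterior q_φ, prior p_φ), the per-agent SALE term, and the squared-error forms are all illustrative. Q̄_tot denotes a target network as in QMIX, and KL balancing is written as in DreamerV2, where the stop-gradient sg(·) splits the KL term so that the prior is pulled toward the posterior more strongly (α = 0.8 in DreamerV2) than the posterior is regularized toward the prior.

```latex
\begin{aligned}
\mathcal{L} &= \mathcal{L}_{\text{rec}} + \beta\,\mathcal{L}_{\text{KL}} + \mathcal{L}_{\text{TD}} + \lambda\,\mathcal{L}_{\text{SALE}},\\
\mathcal{L}_{\text{rec}} &= -\log p_\phi(o_t \mid h_t, z_t),\\
\mathcal{L}_{\text{KL}} &= \alpha\,\mathrm{KL}\big[\operatorname{sg}(q_\phi(z_t \mid h_t, o_t)) \,\big\|\, p_\phi(z_t \mid h_t)\big] + (1-\alpha)\,\mathrm{KL}\big[q_\phi(z_t \mid h_t, o_t) \,\big\|\, \operatorname{sg}(p_\phi(z_t \mid h_t))\big],\\
\mathcal{L}_{\text{TD}} &= \Big(r_t + \gamma \max_{\mathbf{a}'} \bar{Q}_{\text{tot}}(\boldsymbol{\tau}_{t+1}, \mathbf{a}', s_{t+1}) - Q_{\text{tot}}(\boldsymbol{\tau}_t, \mathbf{a}_t, s_t)\Big)^2,\\
\mathcal{L}_{\text{SALE}} &= \sum_{i=1}^{n} \big\| g(z^{s}_{i,t}, a_{i,t}) - \operatorname{sg}(z^{s}_{i,t+1}) \big\|_2^2, \qquad z^{s}_{i,t} = \mathrm{AvgL1Norm}(f(o_{i,t})).
\end{aligned}
```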
Abstract
Learning to coordinate many agents in partially observable and highly dynamic environments requires both informative representations and data-efficient training. To address this challenge, we present a novel model-based multi-agent reinforcement learning framework that unifies joint state-action representation learning with imagined rollouts. We design a world model based on a variational autoencoder and augment it with the state-action learned embedding (SALE). SALE is injected into both the imagination module, which forecasts plausible future rollouts, and the joint agent network, whose individual action values are combined through a mixing network to estimate the joint action-value function. By coupling imagined trajectories with SALE-based action values, the agents acquire a richer understanding of how their choices influence collective outcomes, which improves long-term planning and optimization under limited real-environment interaction. Empirical studies on well-established multi-agent benchmarks, including StarCraft II micromanagement (SMAC), Multi-Agent MuJoCo, and Level-Based Foraging, demonstrate consistent gains of our method over baseline algorithms and highlight the effectiveness of joint state-action learned embeddings within a multi-agent model-based paradigm.
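To make the data flow concrete, the pure-Python sketch below feeds a SALE-style embedding into per-agent Q-values and combines them with a QMIX-style monotonic mixer. SALE was introduced with the TD7 algorithm, where z^s = AvgL1Norm(f(s)) and z^sa = g(z^s, a) are trained so that z^sa predicts the next-state embedding. This sketch assumes MMSA follows that recipe per agent and that its mixer follows QMIX's hypernetwork design. `ToySALE`, `agent_q`, `ToyQMIXMixer`, the linear encoders, and all dimensions are hypothetical; weights are random and nothing is trained.

```python
import math
import random


def matvec(m, x):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in m]


def rand_matrix(rng, rows, cols, scale=0.5):
    return [[rng.uniform(-scale, scale) for _ in range(cols)] for _ in range(rows)]


def avg_l1_norm(x, eps=1e-8):
    """AvgL1Norm as described for SALE: divide x by the mean absolute value of its entries."""
    scale = sum(abs(v) for v in x) / len(x)
    return [v / max(scale, eps) for v in x]


def elu(x):
    return [v if v > 0 else math.exp(v) - 1.0 for v in x]


class ToySALE:
    """Hypothetical linear encoders f (observation -> z_s) and g ((z_s, a) -> z_sa)."""

    def __init__(self, obs_dim, act_dim, emb_dim, seed=0):
        rng = random.Random(seed)
        self.f = rand_matrix(rng, emb_dim, obs_dim)
        self.g = rand_matrix(rng, emb_dim, emb_dim + act_dim)

    def z_s(self, obs):
        return avg_l1_norm(matvec(self.f, obs))

    def z_sa(self, obs, action):
        # Concatenate the normalized state embedding with the one-hot action.
        return matvec(self.g, self.z_s(obs) + action)

    def prediction_loss(self, obs, action, next_obs):
        # z_sa is trained to predict the next state's embedding; during training the
        # target z_s(next_obs) would be held fixed (stop-gradient).
        pred, target = self.z_sa(obs, action), self.z_s(next_obs)
        return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def agent_q(sale, head, obs, action):
    """Hypothetical per-agent value head reading (o_i, a_i, z_sa)."""
    features = obs + action + sale.z_sa(obs, action)
    return sum(w * x for w, x in zip(head, features))


class ToyQMIXMixer:
    """QMIX-style mixer: hypernetworks map the global state to mixing weights, and
    abs() keeps those weights non-negative so Q_tot is monotonic in every Q_i."""

    def __init__(self, n_agents, state_dim, hidden=8, seed=1):
        rng = random.Random(seed)
        self.n, self.h = n_agents, hidden
        self.hyper_w1 = rand_matrix(rng, n_agents * hidden, state_dim)
        self.hyper_b1 = rand_matrix(rng, hidden, state_dim)
        self.hyper_w2 = rand_matrix(rng, hidden, state_dim)
        self.hyper_b2 = rand_matrix(rng, 1, state_dim)

    def __call__(self, agent_qs, state):
        w1 = [abs(v) for v in matvec(self.hyper_w1, state)]
        b1 = matvec(self.hyper_b1, state)
        layer = elu([
            sum(w1[i * self.h + j] * agent_qs[i] for i in range(self.n)) + b1[j]
            for j in range(self.h)
        ])
        w2 = [abs(v) for v in matvec(self.hyper_w2, state)]
        b2 = matvec(self.hyper_b2, state)[0]
        return sum(w * u for w, u in zip(w2, layer)) + b2


# Usage: two agents, 4-dim observations, 3 discrete actions (one-hot).
sale = ToySALE(obs_dim=4, act_dim=3, emb_dim=6)
mixer = ToyQMIXMixer(n_agents=2, state_dim=8)
head = rand_matrix(random.Random(2), 1, 4 + 3 + 6)[0]
obs = [[0.1, -0.2, 0.3, 0.0], [0.5, 0.1, -0.4, 0.2]]
acts = [[1.0, 0.0, 0.0], [0.0, 0.0, 1.0]]
qs = [agent_q(sale, head, o, a) for o, a in zip(obs, acts)]
state = obs[0] + obs[1]  # toy global state: concatenated observations
q_tot = mixer(qs, state)
# Toy transition: treat agent 1's observation as agent 0's "next" observation.
sale_loss = sale.prediction_loss(obs[0], acts[0], obs[1])
```

Because every mixing weight passes through `abs()` and ELU is increasing, raising any single agent's Q-value cannot lower Q_tot. This monotonicity is what lets each agent act greedily on its own Q_i during decentralized execution while the mixer is trained centrally.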
