
Multi-Agent Model-Based Reinforcement Learning with Joint State-Action Learned Embeddings

Zhizun Wang, David Meger

TL;DR

This work tackles sample-efficient coordination in partially observable multi-agent environments by combining model-based reinforcement learning with joint state-action representation learning. The proposed MMSA framework uses a SALE-augmented world model to generate latent rollouts and augments both the policy network and the agent value networks with SALE, all within the centralized training with decentralized execution (CTDE) paradigm. A QMIX-style mixing network aggregates the individual Q-values into a joint action value, with the imagined rollouts informing joint decisions, and a unified loss combines KL regularization, reconstruction, TD learning, and SALE prediction terms, with KL balancing applied to the KL term. Across the SMAC, MAMuJoCo, and Level-Based Foraging benchmarks, MMSA achieves consistent performance gains over baseline algorithms, and ablation studies confirm that each of its components is necessary, indicating strong potential for scalable, data-efficient multi-agent planning in complex domains.
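The TL;DR names KL balancing as part of the unified loss. A minimal sketch, assuming the DreamerV2-style form α·KL(sg(q) ‖ p) + (1 − α)·KL(q ‖ sg(p)), where q is the posterior, p the prior, and sg the stop-gradient operator; the diagonal-Gaussian distributions and α = 0.8 are assumptions, not details taken from the paper. Stop-gradient leaves the forward value unchanged, so the balancing shows up only in how the gradient is split, which is why the hypothetical helper `kl_balanced` returns analytic mean gradients alongside the value.

```python
import math


def kl_diag_gaussian(mu_q, std_q, mu_p, std_p):
    """KL(q || p) for diagonal Gaussians, summed over dimensions."""
    total = 0.0
    for mq, sq, mp, sp in zip(mu_q, std_q, mu_p, std_p):
        total += math.log(sp / sq) + (sq ** 2 + (mq - mp) ** 2) / (2 * sp ** 2) - 0.5
    return total


def kl_balanced(mu_q, std_q, mu_p, std_p, alpha=0.8):
    """Value and mean-gradients of alpha*KL(sg(q)||p) + (1-alpha)*KL(q||sg(p)).

    Hypothetical helper: alpha=0.8 follows DreamerV2 and is an assumption here.
    """
    # Stop-gradient does not change the forward value, so both terms equal KL(q || p).
    value = kl_diag_gaussian(mu_q, std_q, mu_p, std_p)
    # Analytic gradients of KL w.r.t. the means:
    #   dKL/dmu_q = (mu_q - mu_p) / std_p^2,   dKL/dmu_p = -(mu_q - mu_p) / std_p^2
    # The prior is trained only by the first term (weight alpha),
    # the posterior only by the second term (weight 1 - alpha).
    grad_mu_q = [(1 - alpha) * (mq - mp) / sp ** 2 for mq, mp, sp in zip(mu_q, mu_p, std_p)]
    grad_mu_p = [-alpha * (mq - mp) / sp ** 2 for mq, mp, sp in zip(mu_q, mu_p, std_p)]
    return value, grad_mu_q, grad_mu_p


value, g_post, g_prior = kl_balanced([1.0], [1.0], [0.0], [1.0])
print(value, g_post, g_prior)  # KL = 0.5; the prior mean gets 4x the gradient of the posterior mean
```

With α = 0.8 the prior is pushed toward the posterior four times harder than the posterior is pulled toward the prior, which discourages regularizing the representation toward a poorly trained prior.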

Abstract

Learning to coordinate many agents in partially observable and highly dynamic environments requires both informative representations and data-efficient training. To address this challenge, we present a novel model-based multi-agent reinforcement learning framework that unifies joint state-action representation learning with imagined roll-outs. We design a world model trained with variational auto-encoders and augment the model using the state-action learned embedding (SALE). SALE is injected into both the imagination module that forecasts plausible future roll-outs and the joint agent network, whose individual action values are combined through a mixing network to estimate the joint action-value function. By coupling imagined trajectories with SALE-based action values, the agents acquire a richer understanding of how their choices influence collective outcomes, leading to improved long-term planning and optimization under limited real-environment interactions. Empirical studies on well-established multi-agent benchmarks, including StarCraft II Micro-Management, Multi-Agent MuJoCo, and Level-Based Foraging challenges, show that our method consistently outperforms baseline algorithms and highlight the effectiveness of joint state-action learned embeddings within a multi-agent model-based paradigm.
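SALE, as introduced with TD7 (Fujimoto et al., 2023), learns a state encoder f and a state-action encoder g by training g(f(s), a) to predict the embedding of the next state, f(s'). The sketch below shows that prediction objective with single linear layers and TD7's AvgL1Norm normalization; the linear encoders, the dimensions, and the class name `LinearSALE` are hypothetical simplifications, and MMSA's actual encoders may differ.

```python
import numpy as np


def avg_l1_norm(x, eps=1e-8):
    """AvgL1Norm (TD7): divide features by their mean absolute value."""
    return x / np.maximum(np.abs(x).mean(axis=-1, keepdims=True), eps)


class LinearSALE:
    """Hypothetical linear SALE encoders: z^s = f(s), z^{sa} = g(z^s, a)."""

    def __init__(self, state_dim, action_dim, emb_dim=16, seed=0):
        rng = np.random.default_rng(seed)
        self.W_f = rng.normal(scale=0.1, size=(state_dim, emb_dim))
        self.W_g = rng.normal(scale=0.1, size=(emb_dim + action_dim, emb_dim))

    def encode_state(self, s):
        # z^s is normalized with AvgL1Norm, as in TD7
        return avg_l1_norm(s @ self.W_f)

    def encode_state_action(self, z_s, a):
        # z^{sa} is computed from the state embedding concatenated with the action
        return np.concatenate([z_s, a], axis=-1) @ self.W_g

    def prediction_loss(self, s, a, s_next):
        """MSE between g(f(s), a) and f(s'); f(s') is treated as a fixed target."""
        z_sa = self.encode_state_action(self.encode_state(s), a)
        target = self.encode_state(s_next)  # would be wrapped in stop-gradient in training
        return float(np.mean((z_sa - target) ** 2))


rng = np.random.default_rng(1)
sale = LinearSALE(state_dim=6, action_dim=3)
s, s_next = rng.normal(size=(32, 6)), rng.normal(size=(32, 6))
a = np.eye(3)[rng.integers(0, 3, size=32)]  # batch of one-hot actions
print(sale.prediction_loss(s, a, s_next))
```

In TD7 these embeddings are then passed to the value function together with a normalized linear embedding of (s, a), which appears to correspond to the $Q_i(z^o_t, z^{oa}_t, \phi^{oa}_t)$ signature in Figure 1.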

Paper Structure (27 sections, 16 equations, 13 figures, 10 tables)

Figures (13)

  • Figure 1: Architecture of the SALE-augmented policy and agent networks in MMSA. Top: the policy network, in which the state $s_t$ is encoded and passed into $\pi_t$ to produce the action. Bottom: the agent network, which encodes the observation and action for computing $Q_i\bigl( z^o_t,\; z^{oa}_t,\;\phi^{oa}_t\bigr)$.
  • Figure 2: Illustration of the world model imagination in MMSA. The input $\mathbf{h}_{t}$ encapsulates the past information, including $\hat{s}_{t-1}$ and $\mathbf{a}_{t-1}$. The agent networks receive $\mathbf{h}_{t}$ and infer $\hat{\mathbf{a}}_{t}$. Taking the normalized joint state-action learned embeddings $({z}_{t}^{\hat{s\mathbf{a}}}, z_{t}^{\hat{s}}, {\phi}_{t}^{\hat{s\mathbf{a}}})$ as input, VAE-1 reconstructs ${z'}_{t}^{\hat{s\mathbf{a}}}$, ${z'}_{t}^{\hat{s}}$, and ${\phi'}_{t}^{\hat{s\mathbf{a}}}$. The outputs are passed into VAE-2 to infer $\mathbf{h'}_{t+1}$.
  • Figure 3: An overview of the MMSA method, illustrating how our model-based MARL framework combines (1) a learned world model with state-action learned embeddings, (2) decentralized agent value networks equipped with SALE, and (3) a QMIX-style mixing network under the CTDE paradigm. The learning process of the world model is shown in Figure 2.
  • Figure 4: Performance of MMSA in Multi-Agent MuJoCo (top row) and in Level-Based Foraging (bottom row). The shaded regions show $95\%$ confidence intervals around the average performance. Top: average episodic return of MMSA compared with competing MARL algorithms in Multi-Agent MuJoCo tasks. Returns are rescaled for readability. Experiments are run for 7M time steps. Bottom: mean episodic return of MMSA compared with other MARL methods in Level-Based Foraging. Each run lasts 2M time steps. On both benchmarks, MMSA outperforms the competing methods in every environment.
  • Figure 5: Performance of MMSA compared with MARL baselines and ablations in SMACv2. (a) Test win rates of MMSA compared with top-performing methods in SMACv2. We plot the median test win rates with the 25th to 75th percentiles, as in jianye2023boosting. Each run lasts 5M time steps. Although MMSA starts slowly, it gradually overtakes baselines such as VDN and QMIX, and its overall performance matches that of HPN-QMIX, the strongest competing method. (b) Ablations of the MMSA architecture. MMSA is compared against variants in which the world model, SALE, KL balancing, or the global state is removed (No-WM, No-SALE, No-KLB, and No-GS, respectively). Performance is averaged over all SMACv2 challenges. Each run lasts 3M time steps.
  • ...and 8 more figures
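Figure 2 describes imagination as repeatedly inferring actions from the latent history $\mathbf{h}_t$ and letting the learned model predict the next latent state. A generic sketch of such a latent rollout loop, assuming the model outputs a diagonal-Gaussian distribution over the next latent that is sampled with the reparameterization trick; MMSA's two-VAE structure is collapsed into a single hypothetical `transition` callable, and `policy` and `reward_fn` are likewise placeholders for learned networks.

```python
import numpy as np


def imagine_rollout(h0, policy, transition, reward_fn, horizon, rng):
    """Roll a learned latent model forward without querying the real environment.

    policy(h) -> joint action, transition(h, a) -> (mean, log_std) of the next latent,
    reward_fn(h, a) -> scalar. All three are hypothetical stand-ins for learned networks.
    """
    h, trajectory = h0, []
    for _ in range(horizon):
        a = policy(h)                                   # agents infer a_t from h_t
        mean, log_std = transition(h, a)                # model predicts p(h_{t+1} | h_t, a_t)
        eps = rng.standard_normal(mean.shape)
        h_next = mean + np.exp(log_std) * eps           # reparameterized sample
        trajectory.append((h, a, reward_fn(h, a), h_next))
        h = h_next
    return trajectory


# Toy usage with fixed linear "networks" (purely illustrative).
rng = np.random.default_rng(0)
A = 0.9 * np.eye(4)
B = 0.1 * np.ones((4, 2))
traj = imagine_rollout(
    h0=np.zeros(4),
    policy=lambda h: np.tanh(h[:2]),
    transition=lambda h, a: (A @ h + B @ a, np.full(4, -2.0)),
    reward_fn=lambda h, a: -float(np.sum(h ** 2)),
    horizon=5,
    rng=rng,
)
print(len(traj), traj[-1][3].shape)
```

Because the loop never calls the environment, each imagined transition costs only a model forward pass, which is where the sample-efficiency argument of model-based methods comes from.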
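Figure 3 and the TL;DR describe a QMIX-style mixing network. In QMIX (Rashid et al., 2018), hypernetworks conditioned on the global state produce the mixing weights, and taking their absolute value keeps $Q_{tot}$ monotonic in every agent's $Q_i$, so each agent's greedy action stays consistent with the joint argmax. A minimal numpy forward pass under that assumption; the single-linear-layer hypernetworks, the sizes, and the class name `QMixer` are simplifications (the original QMIX, for instance, uses a two-layer hypernetwork for the final bias).

```python
import numpy as np


def elu(x):
    # expm1 on the clipped input avoids overflow warnings for large positive x
    return np.where(x > 0, x, np.expm1(np.minimum(x, 0.0)))


class QMixer:
    """Monotonic mixer: Q_tot = w2 . elu(q @ W1 + b1) + b2, with W1, w2 >= 0."""

    def __init__(self, n_agents, state_dim, embed_dim=8, seed=0):
        rng = np.random.default_rng(seed)
        self.n_agents, self.embed_dim = n_agents, embed_dim
        # Hypernetworks: linear maps from the global state to mixing parameters.
        self.hyper_w1 = rng.normal(scale=0.1, size=(state_dim, n_agents * embed_dim))
        self.hyper_b1 = rng.normal(scale=0.1, size=(state_dim, embed_dim))
        self.hyper_w2 = rng.normal(scale=0.1, size=(state_dim, embed_dim))
        self.hyper_b2 = rng.normal(scale=0.1, size=state_dim)

    def __call__(self, agent_qs, state):
        # abs() makes the weights non-negative, so dQ_tot/dQ_i >= 0 for every agent
        w1 = np.abs(state @ self.hyper_w1).reshape(self.n_agents, self.embed_dim)
        b1 = state @ self.hyper_b1
        hidden = elu(agent_qs @ w1 + b1)
        w2 = np.abs(state @ self.hyper_w2)
        b2 = state @ self.hyper_b2                     # state-dependent scalar bias
        return float(hidden @ w2 + b2)


mixer = QMixer(n_agents=3, state_dim=5)
state = np.linspace(-1.0, 1.0, 5)
qs = np.array([0.5, -1.0, 2.0])
bumped = qs + np.array([0.0, 1.0, 0.0])   # raise the second agent's chosen Q-value
print(mixer(qs, state) <= mixer(bumped, state))  # True: Q_tot cannot decrease
```

Dropping the `abs()` would let the mixer represent non-monotonic joint values, but decentralized greedy action selection would then no longer be guaranteed to recover the joint argmax.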
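Figures 4 and 5 summarize runs across seeds in two different ways: a mean with a $95\%$ confidence interval, and a median with the 25th to 75th percentile band. A stdlib-only sketch of both aggregations per evaluation step, assuming `curves[seed][step]` holds one scalar (return or win rate) per seed and step; the normal-approximation interval (1.96 standard errors) and the inclusive, linear-interpolation quantile method are assumptions, since the captions do not say how the bands were computed.

```python
import math
import statistics


def mean_ci95_band(curves):
    """Per step: (mean, lower, upper) with a normal-approximation 95% CI over seeds."""
    band = []
    for values in zip(*curves):                      # values = one step across all seeds
        mean = statistics.fmean(values)
        half_width = 1.96 * statistics.stdev(values) / math.sqrt(len(values))
        band.append((mean, mean - half_width, mean + half_width))
    return band


def median_iqr_band(curves):
    """Per step: (25th percentile, median, 75th percentile) over seeds."""
    band = []
    for values in zip(*curves):
        q25, median, q75 = statistics.quantiles(values, n=4, method="inclusive")
        band.append((q25, median, q75))
    return band


# Toy win rates for 5 seeds at 2 evaluation steps (illustrative numbers only).
win_rates = [[0.2, 0.5], [0.4, 0.6], [0.6, 0.7], [0.8, 0.8], [1.0, 0.9]]
print(mean_ci95_band(win_rates)[0])
print(median_iqr_band(win_rates)[0])  # quartiles of 0.2, 0.4, 0.6, 0.8, 1.0
```

With only a handful of seeds, a Student-t quantile in place of 1.96 would give a somewhat wider interval; the median band is less sensitive to a single outlying seed than the mean.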

Theorems & Definitions (2)

  • Definition D.1
  • Definition E.1