SocialJax: An Evaluation Suite for Multi-agent Reinforcement Learning in Sequential Social Dilemmas

Zihao Guo, Shuqing Shi, Richard Willis, Tristan Tomilin, Joel Z. Leibo, Yali Du

TL;DR

SocialJax is a GPU-accelerated, JAX-based evaluation suite for sequential social dilemmas (SSDs) in multi-agent reinforcement learning, built to address the computational bottlenecks of CPU-based benchmarks. It implements nine grid-world environments and a set of MARL algorithms (IPPO, MAPPO, SVO) with both common and individual rewards, enabling fast, reproducible experimentation and validation of social-dilemma properties via Schelling diagrams. Key contributions include real-time speed-ups of at least 50× over Melting Pot RLlib baselines, environment-specific cooperation metrics, and a unified, open-source pipeline for evaluating prosocial versus antisocial behaviors. In practice, the work lowers the computational barrier to thorough SSD research and provides a standardized platform for comparing MARL approaches under realistic social incentives.

Abstract

Sequential social dilemmas pose a significant challenge in the field of multi-agent reinforcement learning (MARL), requiring environments that accurately reflect the tension between individual and collective interests. Previous benchmarks and environments, such as Melting Pot, provide an evaluation protocol that measures generalization to new social partners in various test scenarios. However, running reinforcement learning algorithms in traditional environments requires substantial computational resources. In this paper, we introduce SocialJax, a suite of sequential social dilemma environments and algorithms implemented in JAX. JAX is a high-performance numerical computing library for Python that enables significant improvements in operational efficiency. Our experiments demonstrate that the SocialJax training pipeline achieves at least 50× speed-up in real-time performance compared to Melting Pot RLlib baselines. Additionally, we validate the effectiveness of baseline algorithms within SocialJax environments. Finally, we use Schelling diagrams to verify the social dilemma properties of these environments, ensuring that they accurately capture the dynamics of social dilemmas.

Paper Structure

This paper contains 26 sections, 5 equations, 11 figures, 6 tables.

Figures (11)

  • Figure 1: The SocialJax suite contains nine multi-agent reinforcement learning environments designed to evaluate social dilemmas. The screenshots in this figure were taken from a third-person visualization perspective. The agents of all environments are restricted to partial observability of their surroundings through a designated observation window.
  • Figure 2: Training curves for a range of SocialJax environments. IPPO (shared parameters) with Common Reward encourages collective interests, leading to higher overall returns, whereas Individual Reward primarily drives selfish behavior, often resulting in lower returns. MAPPO is included as a centralized baseline. The SVO curve represents the case with a social angle $\theta=90^\circ$.
  • Figure 3: Schelling diagrams of the environments, visualizing the relationship between individual incentives and collective outcomes. The dotted line represents the overall average return of cooperators and defectors.
  • Figure 4: An example of using the SocialJax API in the Clean Up environment.
  • Figure 5: Command‑line example for training different algorithms (IPPO and MAPPO) and generating training curves with SocialJax.
  • ...and 6 more figures
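
To make the Schelling diagrams of Figure 3 more concrete, the following is a minimal sketch of how the plotted points could be computed from evaluation returns. It is not SocialJax code: the function name `schelling_points`, the input layout, and the toy numbers are all hypothetical, and it assumes the x-axis is the number of cooperators in the population. For each cooperator count it reports the mean return of cooperators, the mean return of defectors, and the overall average that corresponds to the dotted line in the caption.

```python
from statistics import mean


def schelling_points(returns_by_k):
    """Summarize per-agent returns into Schelling-diagram points.

    returns_by_k (hypothetical layout) maps the number of cooperators k
    to a pair (cooperator_returns, defector_returns), each a list of floats.
    """
    rows = []
    for k in sorted(returns_by_k):
        coop, defect = returns_by_k[k]
        everyone = coop + defect
        rows.append({
            "num_cooperators": k,
            # A group can be empty (k = 0 or k = N); report None in that case.
            "cooperator_mean": mean(coop) if coop else None,
            "defector_mean": mean(defect) if defect else None,
            # Overall average across all agents (the dotted line).
            "overall_mean": mean(everyone),
        })
    return rows


# Illustrative, made-up returns for 3 agents; not results from the paper.
toy_returns = {
    0: ([], [1.0, 1.0, 1.0]),
    1: ([0.5], [2.0, 2.0]),
    2: ([1.5, 1.5], [3.0]),
    3: ([2.5, 2.5, 2.5], []),
}

rows = schelling_points(toy_returns)
for row in rows:
    print(row)
```

In this toy data, defectors earn more than cooperators at every cooperator count where both groups exist, yet the overall average rises as more agents cooperate. That pattern is the tension between individual and collective interests that a Schelling diagram is meant to expose.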