A suite of sequential social dilemma environments for multi-agent reinforcement learning in JAX
Common Rewards: all agents share a single, unified reward signal, aligning them toward the same objective and promoting collaboration and coordination.
Individual Rewards: each agent receives its own reward, which inherently encourages selfish behavior; the sketch below illustrates the difference between the two schemes.
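As a rough illustration of the difference between the two reward schemes (this is not SocialJax's internal implementation, and the helper name `to_common_reward` is ours), a common reward can be obtained by aggregating the per-agent rewards an environment returns:

```python
import jax.numpy as jnp

def to_common_reward(individual_rewards):
    """Replace each agent's reward with the team mean (illustrative only)."""
    team_reward = jnp.mean(individual_rewards)
    return jnp.full_like(individual_rewards, team_reward)

# Three agents with individual rewards 1.0, 0.0 and -1.0 all receive the
# shared reward 0.0 under the common-reward scheme.
print(to_common_reward(jnp.array([1.0, 0.0, -1.0])))
```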
SocialJax leverages JAX's high-performance GPU capabilities to accelerate multi-agent reinforcement learning in sequential social dilemmas. We are committed to providing a more efficient and diverse suite of environments for studying social dilemmas. We provide JAX implementations of the following environments: Coins, Commons Harvest: Open, Commons Harvest: Closed, Clean Up, Territory, and Coop Mining, which are derived from Melting Pot 2.0 and feature commonly studied mixed incentives.
Our blog presents more details and analysis of agent policies and performance.
[2025/04/29] 🚀 Updated Mushrooms environment.
[2025/04/28] 🚀 Updated Gift Refinement environment.
[2025/04/16] ✨ Added MAPPO algorithm for all environments.
First: Clone the repository

git clone https://github.com/cooperativex/SocialJax.git
cd SocialJax

Second: Environment setup
Option one: Using Poetry (make sure you have Python 3.10)

- Install Poetry

  curl -sSL https://install.python-poetry.org | python3 -
  export PATH="$HOME/.local/bin:$PATH"

- Install requirements

  poetry install --no-root
  poetry run pip install jaxlib==0.4.23+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  export PYTHONPATH=./socialjax:$PYTHONPATH

- Run code

  poetry run python algothrims/IPPO/ippo_cnn_coins.py
Option two: conda with requirements.txt

- Create and activate the conda environment

  conda create -n SocialJax python=3.10
  conda activate SocialJax

- Install requirements

  pip install -r requirements.txt
  pip install jaxlib==0.4.23+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
  export PYTHONPATH=./socialjax:$PYTHONPATH

- Run code

  python algothrims/IPPO/ippo_cnn_coins.py
Option three: conda with environment.yml

- Install requirements

  conda env create -f environment.yml
  export PYTHONPATH=./socialjax:$PYTHONPATH

- Run code

  python algothrims/IPPO/ippo_cnn_coins.py
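Whichever option you choose, it is worth confirming that JAX actually sees the GPU before launching training. This is a generic JAX check, not a SocialJax-specific command:

```python
import jax

# Expect one or more GPU device entries if the CUDA-enabled jaxlib wheel
# was installed correctly; otherwise JAX silently falls back to CPU.
print(jax.devices())
print(jax.default_backend())  # "gpu" when CUDA is available, else "cpu"
```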
We introduce the environments and use Schelling diagrams to demonstrate whether they are social dilemmas; a sketch of the underlying check follows the notes below.
| Environment | Description | Schelling Diagram Proof |
|---|---|---|
| Coins | Link | ✓ |
| Commons Harvest: Open | Link | ✓ |
| Commons Harvest: Closed | Link | ✓ |
| Clean Up | Link | ✓ |
| Territory | Link | ✗ |
| Coop Mining | Link | ✓ |
- Due to algorithmic limitations, agents may not always learn the optimal actions. As a result, a Schelling diagram can show that an environment is a social dilemma, but it cannot definitively prove that an environment is not one.
- Territory may therefore not be a social dilemma, but as long as the agents' behaviors are interesting, Territory holds intrinsic value.
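For context, a Schelling diagram plots the average return of a cooperator and of a defector against the number of other cooperating agents. The sketch below shows a simplified version of the kind of check such a diagram supports, covering two of the standard conditions (mutual cooperation beats mutual defection, and there is some temptation to defect); the function name and the payoff arrays are hypothetical inputs estimated from rollouts, not a SocialJax API:

```python
import jax.numpy as jnp

def schelling_check(coop_payoff, defect_payoff):
    """Simplified social-dilemma check from estimated Schelling-diagram curves.

    coop_payoff[k]:   average return of a cooperator when k other agents cooperate
    defect_payoff[k]: average return of a defector when k other agents cooperate
    """
    # Mutual cooperation should beat mutual defection.
    cooperation_pays = coop_payoff[-1] > defect_payoff[0]
    # There should be some temptation to defect: defecting beats cooperating
    # for at least one number of other cooperators.
    temptation = jnp.any(defect_payoff > coop_payoff)
    return bool(cooperation_pays) and bool(temptation)

# Toy payoff curves for k = 0..4 other cooperators: the check passes.
print(schelling_check(jnp.array([1.0, 2.0, 3.0, 4.0, 5.0]),
                      jnp.array([2.0, 3.0, 4.0, 4.5, 4.8])))
```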
SocialJax's interface follows JaxMARL, which takes inspiration from PettingZoo and Gymnax.
You can create an environment using the make function:
import jax
import socialjax

env = socialjax.make('clean_up')

Find more fixed policy examples.
import jax
import jax.numpy as jnp
from socialjax import make

num_agents = 7
env = make('clean_up', num_agents=num_agents)

rng = jax.random.PRNGKey(259)
rng, _rng = jax.random.split(rng)

# Reset the environment to obtain the initial observations and state.
obs, state = env.reset(_rng)

for t in range(100):
    rng, *rngs = jax.random.split(rng, num_agents + 1)
    # Sample a random action for each agent from its action space.
    actions = [jax.random.choice(
        rngs[a],
        a=env.action_space(0).n,
        p=jnp.array([0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
    ) for a in range(num_agents)]
    # Step the environment forward with the sampled joint action.
    obs, state, reward, done, info = env.step_env(
        rng, state, [a for a in actions]
    )

You can test the speed of our environments by running speed_test_random.py or using the Colab notebook.
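Because the environments are written in pure JAX, rollouts can also be jit-compiled and vectorised over many environment copies, which is where most of the speed-up over CPU-based suites comes from. The sketch below illustrates the idea, assuming the reset/step_env interface above is compatible with jax.vmap and jax.jit (num_envs, step_fn and the batching itself are our additions, not part of the SocialJax API):

```python
import jax
from socialjax import make

num_agents = 7
num_envs = 64  # number of parallel environment copies (illustrative)
env = make('clean_up', num_agents=num_agents)

def step_fn(key, state):
    # Sample one random action per agent, then advance this environment copy.
    key, *act_keys = jax.random.split(key, num_agents + 1)
    actions = [jax.random.randint(act_keys[a], (), 0, env.action_space(0).n)
               for a in range(num_agents)]
    obs, state, reward, done, info = env.step_env(key, state, actions)
    return key, state

# Reset a batch of environments, then jit/vmap the step across the batch.
keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
obs, states = jax.vmap(env.reset)(keys)
batched_step = jax.jit(jax.vmap(step_fn))
keys, states = batched_step(keys, states)
```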
@article{guo2025socialjax,
title={SocialJax: An Evaluation Suite for Multi-agent Reinforcement Learning in Sequential Social Dilemmas},
author={Guo, Zihao and Willis, Richard and Shi, Shuqing and Tomilin, Tristan and Leibo, Joel Z and Du, Yali},
journal={arXiv preprint arXiv:2503.14576},
year={2025}
}
JaxMARL: accelerated MARL environments with baselines in JAX.
PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.