
SocialJax


A suite of sequential social dilemma environments for multi-agent reinforcement learning in JAX

(Animated previews: Coins, Commons Harvest: Open, Commons Harvest: Closed, Clean Up, and Coop Mining under common rewards)

Common Rewards: a scenario in which all agents share a single, unified reward signal. This aligns every agent toward the same objective, promoting collaboration and coordination among them.

(Animated previews: Coins, Commons Harvest: Open, Commons Harvest: Closed, Clean Up, and Coop Mining under individual rewards)

Individual Rewards: each agent is assigned its own reward, which inherently encourages selfish behavior.
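
To make the distinction concrete, here is a minimal sketch of how the two reward schemes could be derived from per-agent rewards. The `per_agent_rewards` values and the mean-sharing rule are illustrative assumptions, not SocialJax's internal implementation.

import jax.numpy as jnp

# Hypothetical per-agent rewards for one timestep (illustrative values only).
per_agent_rewards = jnp.array([1.0, 0.0, 3.0, 0.0, 1.0, 0.0, 2.0])

# Individual rewards: each agent keeps its own signal.
individual = per_agent_rewards

# Common rewards: every agent receives the same shared signal,
# here the mean of all per-agent rewards (a sum would work similarly).
common = jnp.full_like(per_agent_rewards, per_agent_rewards.mean())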

SocialJax leverages JAX's high-performance GPU capabilities to accelerate multi-agent reinforcement learning in sequential social dilemmas. We are committed to providing a more efficient and diverse suite of environments for studying social dilemmas. We provide JAX implementations of the following environments: Coins, Commons Harvest: Open, Commons Harvest: Closed, Clean Up, Territory, and Coop Mining, which are derived from Melting Pot 2.0 and feature commonly studied mixed incentives.

Our blog presents more details and analysis of the agents' policies and performance.

Update

[2025/05/28] ✨ Updated SVO algorithm for all environments.

[2025/04/29] 🚀 Updated Mushrooms environment.

[2025/04/28] 🚀 Updated Gift Refinement environment.

[2025/04/16] ✨ Added MAPPO algorithm for all environments.

Installation

First: Clone the repository

Second: Set up the environment using one of the options below.

Option one: Using Poetry (make sure you have Python 3.10)

  1. Install Poetry

    curl -sSL https://install.python-poetry.org | python3 -
    export PATH="$HOME/.local/bin:$PATH"
  2. Install requirements

    poetry install --no-root
    poetry run pip install jaxlib==0.4.23+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    export PYTHONPATH=./socialjax:$PYTHONPATH
  3. Run code

    poetry run python algorithms/IPPO/ippo_cnn_coins.py 

Option two: conda with requirements.txt

  1. Create and activate a conda environment

    conda create -n SocialJax python=3.10
    conda activate SocialJax
  2. Install requirements

    pip install -r requirements.txt
    pip install jaxlib==0.4.23+cuda11.cudnn86 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
    export PYTHONPATH=./socialjax:$PYTHONPATH
  3. Run code

    python algorithms/IPPO/ippo_cnn_coins.py 

Option three: conda with environment.yml

  1. Install requirements

    conda env create -f environment.yml
    export PYTHONPATH=./socialjax:$PYTHONPATH
  2. Run code

    python algorithms/IPPO/ippo_cnn_coins.py 
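
Whichever option you choose, it is worth verifying that JAX can see your GPU before launching a long training run. This check is a generic JAX snippet, not part of SocialJax itself:

import jax

# Should list GPU devices (e.g. cuda devices) if the CUDA wheels installed correctly;
# a CPU-only result means JAX will still run, but without GPU acceleration.
print(jax.devices())
print(jax.default_backend())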

Environments

We introduce the environments and use Schelling diagrams to demonstrate whether each one is a social dilemma.

Environment                      Schelling Diagram Proof
Coins                            Link
Commons Harvest: Open            Link
Commons Harvest: Closed          Link
Commons Harvest: Partnership     Link
Clean Up                         Link
Territory                        Link
Coop Mining                      Link
Mushrooms                        Link
Gift Refinement                  Link
Prisoners Dilemma: Arena         Link

Important Notes:

  • Due to algorithmic limitations, agents may not always learn the optimal actions. As a result, a Schelling diagram can prove that an environment is a social dilemma, but it cannot definitively prove that an environment is not one (see the sketch after this list).

  • Territory might not be a social dilemma, but as long as the agents' behaviors are interesting, it holds intrinsic value.
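
For readers who want to apply the same test to their own results, the sketch below checks one common formalization of the social-dilemma conditions read off a Schelling diagram (after Hughes et al., 2018). The function name, array layout, and the specific inequalities are our assumptions, not code from this repository.

import jax.numpy as jnp

def is_social_dilemma(R_c, R_d):
    """One common set of Schelling-diagram conditions (after Hughes et al., 2018).

    R_c[i] / R_d[i]: average return of a cooperator / defector when i of the
    other N-1 players cooperate, for i = 0, ..., N-1.
    """
    mutual_coop_beats_mutual_defect = R_c[-1] > R_d[0]  # condition 1
    coop_beats_being_exploited = R_c[-1] > R_c[0]       # condition 2
    # Condition 3: some incentive to defect exists, either when many others
    # cooperate ("greed") or when few others cooperate ("fear").
    greed = R_d[-1] > R_c[-1]
    fear = R_d[0] > R_c[0]
    return bool(mutual_coop_beats_mutual_defect
                and coop_beats_being_exploited
                and (greed or fear))

# Toy example (illustrative numbers, not measured SocialJax returns):
R_c = jnp.array([0.0, 2.0, 4.0, 6.0])
R_d = jnp.array([1.0, 3.0, 5.0, 5.5])
print(is_social_dilemma(R_c, R_d))  # True: fear holds and cooperation pays off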

Quick Start

SocialJax's interface follows JaxMARL, which takes inspiration from PettingZoo and Gymnax.

Make an Environment

You can create an environment using the make function:

import jax
from socialjax import make

env = make('clean_up')
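
From there, a rollout typically begins by resetting the environment with a PRNG key. This is a minimal sketch assuming a JaxMARL-style reset signature:

import jax

rng = jax.random.PRNGKey(0)
# JaxMARL-style reset: returns the initial observations and the environment state.
obs, state = env.reset(rng)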

Example

Find more fixed-policy examples in the repository.

import jax
import jax.numpy as jnp
from socialjax import make

num_agents = 7
env = make('clean_up', num_agents=num_agents)
rng = jax.random.PRNGKey(259)
rng, _rng = jax.random.split(rng)
obs, state = env.reset(_rng)  # initial observations and environment state

for t in range(100):
     rng, *rngs = jax.random.split(rng, num_agents+1)
     # Sample one random action per agent, uniform over the discrete actions.
     actions = [jax.random.choice(
          rngs[a],
          a=env.action_space(0).n,
          p=jnp.full(env.action_space(0).n, 1.0 / env.action_space(0).n)
     ) for a in range(num_agents)]

     obs, state, reward, done, info = env.step_env(rng, state, actions)

Speed test

You can test the speed of our environments by running speed_test_random.py or using the Colab notebook.
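
If you want a quick sense of what such a speed test measures, the sketch below times a simple random-action rollout. It reuses the reset/step_env calls from the example above; the timing loop itself is our simplification, not the actual methodology of speed_test_random.py.

import time
import jax
from socialjax import make

num_agents = 7
env = make('clean_up', num_agents=num_agents)
rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
obs, state = env.reset(_rng)

n_actions = env.action_space(0).n
steps = 1000
start = time.perf_counter()
for _ in range(steps):
    rng, key = jax.random.split(rng)
    # One uniform random action per agent.
    actions = list(jax.random.randint(key, (num_agents,), 0, n_actions))
    obs, state, reward, done, info = env.step_env(rng, state, actions)
jax.block_until_ready(obs)  # ensure async dispatch has finished before stopping the clock
elapsed = time.perf_counter() - start
print(f"{steps / elapsed:.0f} env steps/sec")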

See Also

JaxMARL: accelerated MARL environments with baselines in JAX.

PureJaxRL: JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
