diff --git a/Dockerfile b/Dockerfile index b3299296..c4cb6440 100644 --- a/Dockerfile +++ b/Dockerfile @@ -17,7 +17,7 @@ RUN apt-get update && \ apt-get install -y tmux #jaxmarl from source if needed, all the requirements -RUN pip install -e . +RUN pip install -e .[algs,dev] USER ${MYUSER} diff --git a/README.md b/README.md index 22f14690..aa1268c9 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,8 @@ ## Multi-Agent Reinforcement Learning in JAX +🎉 **Update: JaxMARL was accepted at NeurIPS 2024 in the Datasets and Benchmarks Track. See you in Vancouver!** + JaxMARL combines ease-of-use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is for one library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine. For more details, take a look at our [blog post](https://blog.foersterlab.com/jaxmarl/) or our [Colab notebook](https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb), which walks through the basic usage. @@ -72,7 +74,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca

Installation 🧗

-**Environments** - Before installing, ensure you have the correct [JAX version](https://github.com/google/jax#installation) for your hardware accelerator. The JaxMARL environments can be installed directly from PyPi: +**Environments** - Before installing, ensure you have the correct [JAX installation](https://github.com/google/jax#installation) for your hardware accelerator. We have tested up to JAX version 0.4.25. The JaxMARL environments can be installed directly from PyPi: ``` pip install jaxmarl @@ -84,11 +86,15 @@ pip install jaxmarl ``` git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL ``` -2. The requirements for IPPO & MAPPO can be installed with: +2. Install requirements: ``` - pip install -e . + pip install -e .[algs] export PYTHONPATH=./JaxMARL:$PYTHONPATH ``` +3. For the fastest start, we recommend using our Dockerfile, the usage of which is outlined below. + +**Development** - If you would like to run our test suite, install the additional dependencies with: + `pip install -e .[dev]`, after cloning the repository.

Quick Start 🚀

@@ -151,6 +157,7 @@ JAX-native algorithms: - [Mava](https://github.com/instadeepai/Mava): JAX implementations of IPPO and MAPPO, two popular MARL algorithms. - [PureJaxRL](https://github.com/luchris429/purejaxrl): JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training. - [Minimax](https://github.com/facebookresearch/minimax/): JAX implementations of autocurricula baselines for RL. +- [JaxIRL](https://github.com/FLAIROx/jaxirl?tab=readme-ov-file): JAX implementation of algorithms for inverse reinforcement learning. JAX-native environments: - [Gymnax](https://github.com/RobertTLange/gymnax): Implementations of classic RL tasks including classic control, bsuite and MinAtar. @@ -158,3 +165,4 @@ JAX-native environments: - [Pgx](https://github.com/sotetsuk/pgx): JAX implementations of classic board games, such as Chess, Go and Shogi. - [Brax](https://github.com/google/brax): A fully differentiable physics engine written in JAX, features continuous control tasks. - [XLand-MiniGrid](https://github.com/corl-team/xland-minigrid): Meta-RL gridworld environments inspired by XLand and MiniGrid. +- [Craftax](https://github.com/MichaelTMatthews/Craftax): (Crafter + NetHack) in JAX. 
\ No newline at end of file diff --git a/jaxmarl/__init__.py b/jaxmarl/__init__.py index de0658fe..962e720b 100644 --- a/jaxmarl/__init__.py +++ b/jaxmarl/__init__.py @@ -1,4 +1,4 @@ from .registration import make, registered_envs __all__ = ["make", "registered_envs"] -__version__ = "0.0.5" +__version__ = "0.0.6" diff --git a/jaxmarl/environments/hanabi/hanabi.py b/jaxmarl/environments/hanabi/hanabi.py index aaf2e799..bc47d88c 100644 --- a/jaxmarl/environments/hanabi/hanabi.py +++ b/jaxmarl/environments/hanabi/hanabi.py @@ -9,7 +9,7 @@ import chex from typing import Tuple, Dict from functools import partial -from gymnax.environments.spaces import Discrete +from jaxmarl.environments.spaces import Discrete from .hanabi_game import HanabiGame, State diff --git a/jaxmarl/environments/mabrax/mabrax_env.py b/jaxmarl/environments/mabrax/mabrax_env.py index ce170927..f4edb783 100644 --- a/jaxmarl/environments/mabrax/mabrax_env.py +++ b/jaxmarl/environments/mabrax/mabrax_env.py @@ -1,7 +1,7 @@ from typing import Dict, Literal, Optional, Tuple import chex from jaxmarl.environments.multi_agent_env import MultiAgentEnv -from gymnax.environments import spaces +from jaxmarl.environments import spaces from brax import envs import jax import jax.numpy as jnp diff --git a/jaxmarl/environments/mpe/simple.py b/jaxmarl/environments/mpe/simple.py index 59a32c7d..426e3659 100644 --- a/jaxmarl/environments/mpe/simple.py +++ b/jaxmarl/environments/mpe/simple.py @@ -10,7 +10,7 @@ from jaxmarl.environments.multi_agent_env import MultiAgentEnv from jaxmarl.environments.mpe.default_params import * import chex -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete from flax import struct from typing import Tuple, Optional, Dict from functools import partial diff --git a/jaxmarl/environments/mpe/simple_adversary.py b/jaxmarl/environments/mpe/simple_adversary.py index 5273602b..f2706629 100644 --- 
a/jaxmarl/environments/mpe/simple_adversary.py +++ b/jaxmarl/environments/mpe/simple_adversary.py @@ -5,7 +5,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import State, SimpleMPE from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box class SimpleAdversaryMPE(SimpleMPE): diff --git a/jaxmarl/environments/mpe/simple_crypto.py b/jaxmarl/environments/mpe/simple_crypto.py index 1ce4f09a..d1da2d5d 100644 --- a/jaxmarl/environments/mpe/simple_crypto.py +++ b/jaxmarl/environments/mpe/simple_crypto.py @@ -6,7 +6,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete SPEAKER = "alice_0" LISTENER = "bob_0" diff --git a/jaxmarl/environments/mpe/simple_facmac.py b/jaxmarl/environments/mpe/simple_facmac.py index e7b18970..1c655ecf 100644 --- a/jaxmarl/environments/mpe/simple_facmac.py +++ b/jaxmarl/environments/mpe/simple_facmac.py @@ -4,7 +4,7 @@ from typing import Tuple, Dict from functools import partial from jaxmarl.environments.mpe.simple import State, SimpleMPE -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box from jaxmarl.environments.mpe.default_params import * diff --git a/jaxmarl/environments/mpe/simple_push.py b/jaxmarl/environments/mpe/simple_push.py index bbfa37ce..72d23ea0 100644 --- a/jaxmarl/environments/mpe/simple_push.py +++ b/jaxmarl/environments/mpe/simple_push.py @@ -5,7 +5,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box # Obstacle Colours COLOUR_1 = jnp.array([0.1, 0.9, 0.1]) diff --git 
a/jaxmarl/environments/mpe/simple_reference.py b/jaxmarl/environments/mpe/simple_reference.py index b86314bb..ae70e482 100644 --- a/jaxmarl/environments/mpe/simple_reference.py +++ b/jaxmarl/environments/mpe/simple_reference.py @@ -5,7 +5,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete # Obstacle Colours OBS_COLOUR = [(191, 64, 64), (64, 191, 64), (64, 64, 191)] diff --git a/jaxmarl/environments/mpe/simple_speaker_listener.py b/jaxmarl/environments/mpe/simple_speaker_listener.py index 8d1ee78e..c3ddeacb 100644 --- a/jaxmarl/environments/mpe/simple_speaker_listener.py +++ b/jaxmarl/environments/mpe/simple_speaker_listener.py @@ -4,7 +4,7 @@ from typing import Tuple, Dict from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete +from jaxmarl.environments.spaces import Box, Discrete SPEAKER = "speaker_0" LISTENER = "listener_0" diff --git a/jaxmarl/environments/mpe/simple_spread.py b/jaxmarl/environments/mpe/simple_spread.py index ebabe61d..222c5818 100644 --- a/jaxmarl/environments/mpe/simple_spread.py +++ b/jaxmarl/environments/mpe/simple_spread.py @@ -5,7 +5,7 @@ from functools import partial from jaxmarl.environments.mpe.simple import SimpleMPE, State from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box class SimpleSpreadMPE(SimpleMPE): diff --git a/jaxmarl/environments/mpe/simple_tag.py b/jaxmarl/environments/mpe/simple_tag.py index bf9c0869..813b032e 100644 --- a/jaxmarl/environments/mpe/simple_tag.py +++ b/jaxmarl/environments/mpe/simple_tag.py @@ -4,7 +4,7 @@ from typing import Tuple, Dict from functools import partial from jaxmarl.environments.mpe.simple 
import SimpleMPE, State -from gymnax.environments.spaces import Box +from jaxmarl.environments.spaces import Box from jaxmarl.environments.mpe.default_params import * diff --git a/jaxmarl/environments/mpe/simple_world_comm.py b/jaxmarl/environments/mpe/simple_world_comm.py index 4343e212..e22a2643 100644 --- a/jaxmarl/environments/mpe/simple_world_comm.py +++ b/jaxmarl/environments/mpe/simple_world_comm.py @@ -11,8 +11,7 @@ OBS_COLOUR, ) from jaxmarl.environments.mpe.default_params import * -from gymnax.environments.spaces import Box, Discrete - +from jaxmarl.environments.spaces import Box, Discrete # NOTE food and forests are part of world.landmarks diff --git a/jaxmarl/environments/spaces.py b/jaxmarl/environments/spaces.py index 4320bd46..8f97d6c9 100644 --- a/jaxmarl/environments/spaces.py +++ b/jaxmarl/environments/spaces.py @@ -1,3 +1,4 @@ +""" Built off Gymnax spaces.py, this module contains jittable classes for action and observation spaces. """ from typing import Tuple, Union, Sequence from collections import OrderedDict import chex diff --git a/pyproject.toml b/pyproject.toml index 3f5bfbba..a7c5a865 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,6 @@ include = ['jaxmarl*'] [tool.setuptools.dynamic] version = {attr = "jaxmarl.__version__"} -dependencies = {file = ["requirements/requirements.txt"]} [project] name = "jaxmarl" @@ -17,7 +16,7 @@ description = "Multi-Agent Reinforcement Learning with JAX" authors = [ {name = "Foerster Lab for AI Research", email = "arutherford@robots.ox.ac.uk"}, ] -dynamic = ["version", "dependencies"] +dynamic = ["version"] license = {file = "LICENSE"} requires-python = ">=3.10" classifiers = [ @@ -31,6 +30,35 @@ classifiers = [ "Topic :: Software Development :: Libraries :: Python Modules", "License :: OSI Approved :: Apache Software License", ] +dependencies = [ + "jax>=0.4.16.0,<=0.4.25", + "jaxlib>=0.4.16.0,<=0.4.25", + "flax", + "safetensors", + "chex", + "brax==0.10.3", + "mujoco==3.1.3", + "matplotlib", + 
"pillow", + "scipy<=1.12", + "gymnax", +] + +[project.optional-dependencies] +algs = [ + "optax", + "distrax", + "flashbax==0.1.0", + "wandb", + "hydra-core>=1.3.2", + "omegaconf>=2.3.0", + "pettingzoo>=1.24.3", + "tqdm>=4.66.0", +] +dev = [ + "pytest", + "pygame", +] [project.urls] "Homepage" = "https://github.com/FLAIROx/JaxMARL" diff --git a/requirements/requirements.txt b/requirements/requirements.txt deleted file mode 100644 index 2e0fd4b8..00000000 --- a/requirements/requirements.txt +++ /dev/null @@ -1,26 +0,0 @@ -# requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image -jax>=0.4.16.0,<=0.4.25 -jaxlib>=0.4.16.0,<=0.4.25 -flax==0.7.4 -chex==0.1.84 -optax==0.1.7 -dotmap==1.3.30 -evosax==0.1.5 -distrax==0.1.5 -brax==0.10.3 -mujoco==3.1.3 -gymnax==0.0.6 -safetensors==0.4.2 -flashbax==0.1.0 -# less sensitive libs -wandb -pytest -pygame -numpy>=1.26.1 -hydra-core>=1.3.2 -omegaconf>=2.3.0 -matplotlib>=3.8.3 -pillow>=10.2.0 -pettingzoo>=1.24.3 -tqdm>=4.66.0 -scipy<=1.12 diff --git a/tests/hanabi/test_hanabi.py b/tests/hanabi/test_hanabi.py index 4c76e1a5..b6fcfb62 100644 --- a/tests/hanabi/test_hanabi.py +++ b/tests/hanabi/test_hanabi.py @@ -4,7 +4,6 @@ import jax from jax import numpy as jnp from jaxmarl import make -from jaxmarl.wrappers.baselines import LogWrapper env = make("hanabi") dir_path = os.path.dirname(os.path.realpath(__file__))