diff --git a/Dockerfile b/Dockerfile
index b3299296..c4cb6440 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -17,7 +17,7 @@ RUN apt-get update && \
apt-get install -y tmux
#jaxmarl from source if needed, all the requirements
-RUN pip install -e .
+RUN pip install -e .[algs,dev]
USER ${MYUSER}
diff --git a/README.md b/README.md
index 22f14690..aa1268c9 100644
--- a/README.md
+++ b/README.md
@@ -36,6 +36,8 @@
## Multi-Agent Reinforcement Learning in JAX
+🎉 **Update: JaxMARL was accepted to the NeurIPS 2024 Datasets and Benchmarks Track. See you in Vancouver!**
+
JaxMARL combines ease-of-use with GPU-enabled efficiency, and supports a wide range of commonly used MARL environments as well as popular baseline algorithms. Our aim is for one library that enables thorough evaluation of MARL methods across a wide range of tasks and against relevant baselines. We also introduce SMAX, a vectorised, simplified version of the popular StarCraft Multi-Agent Challenge, which removes the need to run the StarCraft II game engine.
For more details, take a look at our [blog post](https://blog.foersterlab.com/jaxmarl/) or our [Colab notebook](https://colab.research.google.com/github/FLAIROx/JaxMARL/blob/main/jaxmarl/tutorials/JaxMARL_Walkthrough.ipynb), which walks through the basic usage.
@@ -72,7 +74,7 @@ We follow CleanRL's philosophy of providing single file implementations which ca
Installation 🧗
-**Environments** - Before installing, ensure you have the correct [JAX version](https://github.com/google/jax#installation) for your hardware accelerator. The JaxMARL environments can be installed directly from PyPi:
+**Environments** - Before installing, ensure you have the correct [JAX installation](https://github.com/google/jax#installation) for your hardware accelerator. We have tested up to JAX version 0.4.25. The JaxMARL environments can be installed directly from PyPi:
```
pip install jaxmarl
@@ -84,11 +86,15 @@ pip install jaxmarl
```
git clone https://github.com/FLAIROx/JaxMARL.git && cd JaxMARL
```
-2. The requirements for IPPO & MAPPO can be installed with:
+2. Install requirements:
```
- pip install -e .
+ pip install -e .[algs]
export PYTHONPATH=./JaxMARL:$PYTHONPATH
```
+3. For the fastest start, we recommend using our Dockerfile, the usage of which is outlined below.
+
+**Development** - If you would like to run our test suite, install the additional dependencies with:
+ `pip install -e .[dev]`, after cloning the repository.
Quick Start 🚀
@@ -151,6 +157,7 @@ JAX-native algorithms:
- [Mava](https://github.com/instadeepai/Mava): JAX implementations of IPPO and MAPPO, two popular MARL algorithms.
- [PureJaxRL](https://github.com/luchris429/purejaxrl): JAX implementation of PPO, and demonstration of end-to-end JAX-based RL training.
- [Minimax](https://github.com/facebookresearch/minimax/): JAX implementations of autocurricula baselines for RL.
+- [JaxIRL](https://github.com/FLAIROx/jaxirl?tab=readme-ov-file): JAX implementation of algorithms for inverse reinforcement learning.
JAX-native environments:
- [Gymnax](https://github.com/RobertTLange/gymnax): Implementations of classic RL tasks including classic control, bsuite and MinAtar.
@@ -158,3 +165,4 @@ JAX-native environments:
- [Pgx](https://github.com/sotetsuk/pgx): JAX implementations of classic board games, such as Chess, Go and Shogi.
- [Brax](https://github.com/google/brax): A fully differentiable physics engine written in JAX, features continuous control tasks.
- [XLand-MiniGrid](https://github.com/corl-team/xland-minigrid): Meta-RL gridworld environments inspired by XLand and MiniGrid.
+- [Craftax](https://github.com/MichaelTMatthews/Craftax): (Crafter + NetHack) in JAX.
\ No newline at end of file
diff --git a/jaxmarl/__init__.py b/jaxmarl/__init__.py
index de0658fe..962e720b 100644
--- a/jaxmarl/__init__.py
+++ b/jaxmarl/__init__.py
@@ -1,4 +1,4 @@
from .registration import make, registered_envs
__all__ = ["make", "registered_envs"]
-__version__ = "0.0.5"
+__version__ = "0.0.6"
diff --git a/jaxmarl/environments/hanabi/hanabi.py b/jaxmarl/environments/hanabi/hanabi.py
index aaf2e799..bc47d88c 100644
--- a/jaxmarl/environments/hanabi/hanabi.py
+++ b/jaxmarl/environments/hanabi/hanabi.py
@@ -9,7 +9,7 @@
import chex
from typing import Tuple, Dict
from functools import partial
-from gymnax.environments.spaces import Discrete
+from jaxmarl.environments.spaces import Discrete
from .hanabi_game import HanabiGame, State
diff --git a/jaxmarl/environments/mabrax/mabrax_env.py b/jaxmarl/environments/mabrax/mabrax_env.py
index ce170927..f4edb783 100644
--- a/jaxmarl/environments/mabrax/mabrax_env.py
+++ b/jaxmarl/environments/mabrax/mabrax_env.py
@@ -1,7 +1,7 @@
from typing import Dict, Literal, Optional, Tuple
import chex
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
-from gymnax.environments import spaces
+from jaxmarl.environments import spaces
from brax import envs
import jax
import jax.numpy as jnp
diff --git a/jaxmarl/environments/mpe/simple.py b/jaxmarl/environments/mpe/simple.py
index 59a32c7d..426e3659 100644
--- a/jaxmarl/environments/mpe/simple.py
+++ b/jaxmarl/environments/mpe/simple.py
@@ -10,7 +10,7 @@
from jaxmarl.environments.multi_agent_env import MultiAgentEnv
from jaxmarl.environments.mpe.default_params import *
import chex
-from gymnax.environments.spaces import Box, Discrete
+from jaxmarl.environments.spaces import Box, Discrete
from flax import struct
from typing import Tuple, Optional, Dict
from functools import partial
diff --git a/jaxmarl/environments/mpe/simple_adversary.py b/jaxmarl/environments/mpe/simple_adversary.py
index 5273602b..f2706629 100644
--- a/jaxmarl/environments/mpe/simple_adversary.py
+++ b/jaxmarl/environments/mpe/simple_adversary.py
@@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import State, SimpleMPE
from jaxmarl.environments.mpe.default_params import *
-from gymnax.environments.spaces import Box
+from jaxmarl.environments.spaces import Box
class SimpleAdversaryMPE(SimpleMPE):
diff --git a/jaxmarl/environments/mpe/simple_crypto.py b/jaxmarl/environments/mpe/simple_crypto.py
index 1ce4f09a..d1da2d5d 100644
--- a/jaxmarl/environments/mpe/simple_crypto.py
+++ b/jaxmarl/environments/mpe/simple_crypto.py
@@ -6,7 +6,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
-from gymnax.environments.spaces import Box, Discrete
+from jaxmarl.environments.spaces import Box, Discrete
SPEAKER = "alice_0"
LISTENER = "bob_0"
diff --git a/jaxmarl/environments/mpe/simple_facmac.py b/jaxmarl/environments/mpe/simple_facmac.py
index e7b18970..1c655ecf 100644
--- a/jaxmarl/environments/mpe/simple_facmac.py
+++ b/jaxmarl/environments/mpe/simple_facmac.py
@@ -4,7 +4,7 @@
from typing import Tuple, Dict
from functools import partial
from jaxmarl.environments.mpe.simple import State, SimpleMPE
-from gymnax.environments.spaces import Box
+from jaxmarl.environments.spaces import Box
from jaxmarl.environments.mpe.default_params import *
diff --git a/jaxmarl/environments/mpe/simple_push.py b/jaxmarl/environments/mpe/simple_push.py
index bbfa37ce..72d23ea0 100644
--- a/jaxmarl/environments/mpe/simple_push.py
+++ b/jaxmarl/environments/mpe/simple_push.py
@@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
-from gymnax.environments.spaces import Box
+from jaxmarl.environments.spaces import Box
# Obstacle Colours
COLOUR_1 = jnp.array([0.1, 0.9, 0.1])
diff --git a/jaxmarl/environments/mpe/simple_reference.py b/jaxmarl/environments/mpe/simple_reference.py
index b86314bb..ae70e482 100644
--- a/jaxmarl/environments/mpe/simple_reference.py
+++ b/jaxmarl/environments/mpe/simple_reference.py
@@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
-from gymnax.environments.spaces import Box, Discrete
+from jaxmarl.environments.spaces import Box, Discrete
# Obstacle Colours
OBS_COLOUR = [(191, 64, 64), (64, 191, 64), (64, 64, 191)]
diff --git a/jaxmarl/environments/mpe/simple_speaker_listener.py b/jaxmarl/environments/mpe/simple_speaker_listener.py
index 8d1ee78e..c3ddeacb 100644
--- a/jaxmarl/environments/mpe/simple_speaker_listener.py
+++ b/jaxmarl/environments/mpe/simple_speaker_listener.py
@@ -4,7 +4,7 @@
from typing import Tuple, Dict
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
-from gymnax.environments.spaces import Box, Discrete
+from jaxmarl.environments.spaces import Box, Discrete
SPEAKER = "speaker_0"
LISTENER = "listener_0"
diff --git a/jaxmarl/environments/mpe/simple_spread.py b/jaxmarl/environments/mpe/simple_spread.py
index ebabe61d..222c5818 100644
--- a/jaxmarl/environments/mpe/simple_spread.py
+++ b/jaxmarl/environments/mpe/simple_spread.py
@@ -5,7 +5,7 @@
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
from jaxmarl.environments.mpe.default_params import *
-from gymnax.environments.spaces import Box
+from jaxmarl.environments.spaces import Box
class SimpleSpreadMPE(SimpleMPE):
diff --git a/jaxmarl/environments/mpe/simple_tag.py b/jaxmarl/environments/mpe/simple_tag.py
index bf9c0869..813b032e 100644
--- a/jaxmarl/environments/mpe/simple_tag.py
+++ b/jaxmarl/environments/mpe/simple_tag.py
@@ -4,7 +4,7 @@
from typing import Tuple, Dict
from functools import partial
from jaxmarl.environments.mpe.simple import SimpleMPE, State
-from gymnax.environments.spaces import Box
+from jaxmarl.environments.spaces import Box
from jaxmarl.environments.mpe.default_params import *
diff --git a/jaxmarl/environments/mpe/simple_world_comm.py b/jaxmarl/environments/mpe/simple_world_comm.py
index 4343e212..e22a2643 100644
--- a/jaxmarl/environments/mpe/simple_world_comm.py
+++ b/jaxmarl/environments/mpe/simple_world_comm.py
@@ -11,8 +11,7 @@
OBS_COLOUR,
)
from jaxmarl.environments.mpe.default_params import *
-from gymnax.environments.spaces import Box, Discrete
-
+from jaxmarl.environments.spaces import Box, Discrete
# NOTE food and forests are part of world.landmarks
diff --git a/jaxmarl/environments/spaces.py b/jaxmarl/environments/spaces.py
index 4320bd46..8f97d6c9 100644
--- a/jaxmarl/environments/spaces.py
+++ b/jaxmarl/environments/spaces.py
@@ -1,3 +1,4 @@
+""" Built off Gymnax spaces.py, this module contains jittable classes for action and observation spaces. """
from typing import Tuple, Union, Sequence
from collections import OrderedDict
import chex
diff --git a/pyproject.toml b/pyproject.toml
index 3f5bfbba..a7c5a865 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -8,7 +8,6 @@ include = ['jaxmarl*']
[tool.setuptools.dynamic]
version = {attr = "jaxmarl.__version__"}
-dependencies = {file = ["requirements/requirements.txt"]}
[project]
name = "jaxmarl"
@@ -17,7 +16,7 @@ description = "Multi-Agent Reinforcement Learning with JAX"
authors = [
{name = "Foerster Lab for AI Research", email = "arutherford@robots.ox.ac.uk"},
]
-dynamic = ["version", "dependencies"]
+dynamic = ["version"]
license = {file = "LICENSE"}
requires-python = ">=3.10"
classifiers = [
@@ -31,6 +30,35 @@ classifiers = [
"Topic :: Software Development :: Libraries :: Python Modules",
"License :: OSI Approved :: Apache Software License",
]
+dependencies = [
+ "jax>=0.4.16.0,<=0.4.25",
+ "jaxlib>=0.4.16.0,<=0.4.25",
+ "flax",
+ "safetensors",
+ "chex",
+ "brax==0.10.3",
+ "mujoco==3.1.3",
+ "matplotlib",
+ "pillow",
+ "scipy<=1.12",
+ "gymnax",
+]
+
+[project.optional-dependencies]
+algs = [
+ "optax",
+ "distrax",
+ "flashbax==0.1.0",
+ "wandb",
+ "hydra-core>=1.3.2",
+ "omegaconf>=2.3.0",
+ "pettingzoo>=1.24.3",
+ "tqdm>=4.66.0",
+]
+dev = [
+ "pytest",
+ "pygame",
+]
[project.urls]
"Homepage" = "https://github.com/FLAIROx/JaxMARL"
diff --git a/requirements/requirements.txt b/requirements/requirements.txt
deleted file mode 100644
index 2e0fd4b8..00000000
--- a/requirements/requirements.txt
+++ /dev/null
@@ -1,26 +0,0 @@
-# requirements are alligned with nvcr.io/nvidia/jax:23.10-py3 image
-jax>=0.4.16.0,<=0.4.25
-jaxlib>=0.4.16.0,<=0.4.25
-flax==0.7.4
-chex==0.1.84
-optax==0.1.7
-dotmap==1.3.30
-evosax==0.1.5
-distrax==0.1.5
-brax==0.10.3
-mujoco==3.1.3
-gymnax==0.0.6
-safetensors==0.4.2
-flashbax==0.1.0
-# less sensitive libs
-wandb
-pytest
-pygame
-numpy>=1.26.1
-hydra-core>=1.3.2
-omegaconf>=2.3.0
-matplotlib>=3.8.3
-pillow>=10.2.0
-pettingzoo>=1.24.3
-tqdm>=4.66.0
-scipy<=1.12
diff --git a/tests/hanabi/test_hanabi.py b/tests/hanabi/test_hanabi.py
index 4c76e1a5..b6fcfb62 100644
--- a/tests/hanabi/test_hanabi.py
+++ b/tests/hanabi/test_hanabi.py
@@ -4,7 +4,6 @@
import jax
from jax import numpy as jnp
from jaxmarl import make
-from jaxmarl.wrappers.baselines import LogWrapper
env = make("hanabi")
dir_path = os.path.dirname(os.path.realpath(__file__))