MichaelTMatthews/Craftax

Update: Craftax was accepted at ICML 2024 as a spotlight!

⛏️ Craftax

Craftax is an RL environment written entirely in JAX. Craftax reimplements and significantly extends the game mechanics of Crafter, taking inspiration from roguelike games such as NetHack. Craftax conforms to the gymnax interface, allowing easy integration with existing JAX-based frameworks like PureJaxRL and JaxUED.

📜 Basic Usage

Craftax conforms to the gymnax interface. The following example creates an environment, resets it, and takes a single random step:

import jax
from craftax.craftax_env import make_craftax_env_from_name

rng = jax.random.PRNGKey(0)
rng, _rng = jax.random.split(rng)
rngs = jax.random.split(_rng, 3)

# Create environment
env = make_craftax_env_from_name("Craftax-Symbolic-v1", auto_reset=True)
env_params = env.default_params

# Get an initial state and observation
obs, state = env.reset(rngs[0], env_params)

# Pick random action
action = env.action_space(env_params).sample(rngs[1])

# Step environment
obs, state, reward, done, info = env.step(rngs[2], state, action, env_params)
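The example above takes a single step. For longer rollouts it is usually worth putting the whole loop inside jax.lax.scan so that JAX compiles it once. The sketch below is a hypothetical helper, random_rollout, written against the same gymnax-style reset/step/action_space calls. So that it runs without Craftax installed, it is exercised on an equally hypothetical ToyCounterEnv. It assumes the environment auto-resets, because the loop never calls reset after the first step.

```python
import jax
import jax.numpy as jnp


def random_rollout(env, env_params, rng, num_steps):
    """Roll out uniformly random actions for num_steps inside jax.lax.scan.

    Hypothetical helper. Assumes env follows the gymnax-style calls used above
    and auto-resets on episode end, since reset is only called once here.
    """
    rng, reset_rng = jax.random.split(rng)
    _, state = env.reset(reset_rng, env_params)

    def one_step(carry, _):
        rng, state = carry
        rng, action_rng, step_rng = jax.random.split(rng, 3)
        action = env.action_space(env_params).sample(action_rng)
        _, state, reward, done, _ = env.step(step_rng, state, action, env_params)
        return (rng, state), (reward, done)

    _, (rewards, dones) = jax.lax.scan(one_step, (rng, state), None, length=num_steps)
    return rewards, dones


# --- Hypothetical toy environment, only here so the sketch runs without Craftax ---
class _ToyDiscrete:
    def __init__(self, n):
        self.n = n

    def sample(self, rng):
        return jax.random.randint(rng, (), 0, self.n, dtype=jnp.int32)


class ToyCounterEnv:
    """Counts up by the chosen action (0 or 1); the episode ends at 5 and auto-resets."""

    default_params = None

    def reset(self, rng, params):
        state = jnp.zeros((), dtype=jnp.int32)
        return state.astype(jnp.float32), state

    def step(self, rng, state, action, params):
        state = state + action
        done = state >= 5
        state = jnp.where(done, 0, state)  # auto-reset back to 0
        reward = action.astype(jnp.float32)
        return state.astype(jnp.float32), state, reward, done, {}

    def action_space(self, params):
        return _ToyDiscrete(2)


env = ToyCounterEnv()
# env and num_steps are static: env is a Python object, num_steps sets the scan length.
rewards, dones = jax.jit(random_rollout, static_argnums=(0, 3))(
    env, env.default_params, jax.random.PRNGKey(0), 100
)
print(rewards.sum(), dones.sum())
```

To try it on Craftax itself, pass the env and env_params created above (with auto_reset=True); the first call will spend a while compiling.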

⬇️ Installation

The latest Craftax release can be installed from PyPI:

pip install craftax

If you want the most recent commit instead, use:

pip install git+https://github.com/MichaelTMatthews/Craftax.git@main

Extending Craftax

To extend Craftax, install it in editable mode with the development dependencies (this requires pip>=23.0):

git clone https://github.com/MichaelTMatthews/Craftax.git
cd Craftax
pip install -e ".[dev]"
pre-commit install

GPU-Enabled JAX

By default, both of the above methods install the CPU-only version of JAX. To run JAX on a GPU or TPU, you'll need to install the correct JAX wheel for your system by following the JAX installation instructions. For NVIDIA GPUs, the command is:

pip install -U "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
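To confirm from Python which backend JAX actually picked up, a quick check using jax.default_backend() and jax.devices() is sketched below. On a working CUDA install the backend should report a GPU rather than cpu.

```python
import jax

# Name of the platform JAX will use by default (e.g. "cpu" or "gpu").
backend = jax.default_backend()
print("Default backend:", backend)
# All devices visible to JAX on that backend.
print("Devices:", jax.devices())
```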

🎮 Play

To play Craftax, run:

play_craftax

or, to play Craftax-Classic, run:

play_craftax_classic

Since Craftax runs entirely in JAX, the rendering and step functions must be compiled before first use: expect around 30s before the first frame is rendered and another 20s before the first action is taken. After this, play should be very quick. A tutorial on how to beat the game is available in tutorial.md, and the controls are printed at the start of play.
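The compile-then-fast behaviour comes from jax.jit: the first call to a jitted function traces and compiles it, and later calls with the same argument shapes and dtypes reuse the cached executable. The sketch below uses a hypothetical toy_step function as a stand-in for Craftax's step and render functions and simply times three calls. Actual timings depend on your hardware.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def toy_step(state, action):
    # Hypothetical stand-in for an environment step: a few array operations.
    return jnp.tanh(state @ state.T).sum() + action


state = jnp.ones((256, 256))
for call in range(3):
    start = time.perf_counter()
    # block_until_ready waits for the asynchronous computation to finish before timing.
    out = toy_step(state, 1.0).block_until_ready()
    # Call 0 includes tracing and compilation; calls 1 and 2 reuse the compiled code.
    print(f"call {call}: {time.perf_counter() - start:.4f}s")
```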

📈 Experiment

To run experiments, see the Craftax Baselines repository.

🔪 Gotchas

Optimistic Resets

To improve performance, Craftax supports optimistic resets, which means it also exposes environments that do not auto-reset. Environments obtained from make_craftax_env_from_name or make_craftax_env_from_args with auto_reset=False do not reset automatically and, if not handled properly, will continue episodes into invalid states. Always wrap these environments in either OptimisticResetVecEnvWrapper (for efficient resets) or AutoResetEnvWrapper (to recover the default gymnax auto-reset behaviour); see ppo.py in Craftax Baselines for correct usage.

Using auto_reset=True returns a regular auto-reset environment, which can be treated like any other gymnax environment.
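The idea behind optimistic resets can be illustrated with a small, self-contained sketch. This is not the implementation of OptimisticResetVecEnvWrapper. Under jax.vmap, a naive auto-reset evaluates the reset branch for every environment on every step, and generating a fresh world can be costly. The sketch instead generates a smaller pool of num_resets fresh states and hands them out to the environments that finished. The function name optimistic_select and its argument layout are hypothetical.

```python
import jax
import jax.numpy as jnp


def optimistic_select(done, stepped_state, fresh_state):
    """Replace the state of every finished env with one from a small pool of fresh states.

    done:          bool[num_envs]
    stepped_state: pytree whose leaves have leading axis num_envs (result of step)
    fresh_state:   pytree whose leaves have leading axis num_resets (<= num_envs)
    """
    num_resets = jax.tree_util.tree_leaves(fresh_state)[0].shape[0]
    # The k-th finished env takes fresh state k (mod num_resets). If more than
    # num_resets envs finish on the same step, some of them share an initial state.
    slot = (jnp.cumsum(done.astype(jnp.int32)) - 1) % num_resets

    def pick(stepped, fresh):
        chosen = fresh[slot]
        # Broadcast the done mask over any trailing state dimensions.
        mask = done.reshape(done.shape + (1,) * (stepped.ndim - 1))
        return jnp.where(mask, chosen, stepped)

    return jax.tree_util.tree_map(pick, stepped_state, fresh_state)


done = jnp.array([True, False, True, True])
stepped = {"pos": jnp.arange(4) * 10}
fresh = {"pos": jnp.array([100, 101])}
result = optimistic_select(done, stepped, fresh)
print(result["pos"].tolist())  # [100, 10, 101, 100]
```

The trade-off is that occasional duplicate initial states are accepted in exchange for computing far fewer resets per step.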

Texture Caching

We use a texture cache to avoid recreating the texture atlas every time Craftax is imported. If you are just running Craftax as a benchmark, this will not affect you. However, if you are editing the game (e.g. adding new blocks or entities), a stale cache could cause errors. You can export the following environment variable to force textures to be recreated from scratch on every run:

export CRAFTAX_RELOAD_TEXTURES=true
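When working in a notebook or script, the same variable can be set from Python instead. This is a sketch under the assumption that Craftax reads the variable when it is imported, so it has to be set before the first import of Craftax in the process.

```python
import os

# Force Craftax to rebuild its texture atlas instead of using the cache.
# Assumption: the variable is read at import time, so set it before importing Craftax.
os.environ["CRAFTAX_RELOAD_TEXTURES"] = "true"

# import craftax  # import only after the variable has been set
```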

📋 Scoreboard

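If you would like to add an algorithm, please open a PR and provide a reference to the source of the results. Rewards are reported as a percentage of the maximum achievable reward (226). Craftax-1B and Craftax-1M refer to training budgets of 1 billion and 1 million environment interactions, respectively.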

Craftax-1B

| Algorithm | Reward (% max) | Code | Paper |
|---|---|---|---|
| PPO-GTrXL | 18.3 | TransformerXL_PPO_JAX | GTrXL |
| PPO-RNN | 15.3 | Craftax_Baselines | PPO |
| RND | 12.0 | Craftax_Baselines | RND |
| PPO | 11.9 | Craftax_Baselines | PPO |
| ICM | 11.9 | Craftax_Baselines | ICM |
| E3B | 11.0 | Craftax_Baselines | E3B |

Craftax-1M

| Algorithm | Reward (% max) | Code | Paper |
|---|---|---|---|
| PPO-RNN | 2.3 | Craftax_Baselines | PPO |
| RND | 2.2 | Craftax_Baselines | RND |
| PPO | 2.2 | Craftax_Baselines | PPO |
| ICM | 2.2 | Craftax_Baselines | ICM |
| E3B | 2.2 | Craftax_Baselines | E3B |
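To compare a scoreboard entry with raw episode rewards, multiply the percentage by the maximum of 226. For example, the 18.3% for PPO-GTrXL corresponds to roughly 41.4. The helper name percent_to_raw below is hypothetical.

```python
MAX_REWARD = 226  # maximum achievable reward, as stated above


def percent_to_raw(percent_of_max):
    """Convert a scoreboard entry (% of max) into a raw episode reward."""
    return percent_of_max / 100 * MAX_REWARD


for name, pct in [("PPO-GTrXL (Craftax-1B)", 18.3), ("PPO-RNN (Craftax-1M)", 2.3)]:
    print(f"{name}: {percent_to_raw(pct):.1f}")
```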

💾 Offline Dataset

A small dataset of mixed-skill human trajectories is available here. Once the zip file has been extracted, the trajectories can be loaded with the load_compressed_pickle function. They were gathered on an earlier version of Craftax, so we recommend using v1.1.0 or earlier to investigate them. run1 is the only trajectory that completes the game.

🔎 See Also

  • ⛏️ Crafter: The original Crafter benchmark.
  • ⚔️ NLE: NetHack as an RL environment.
  • PureJaxRL: End-to-end RL implementations in JAX.
  • 🌎 JaxUED: CleanRL-style UED implementations in JAX.
  • 🌍 Minimax: Modular UED implementations in JAX.
  • 🏋️ Gymnax: Standard JAX RL interface with classic environments.
  • 🧑‍🤝‍🧑 JaxMARL: Multi-agent RL in JAX.

📚 Citation

If you use Craftax in your work, please cite it as follows:

@inproceedings{matthews2024craftax,
    author={Michael Matthews and Michael Beukman and Benjamin Ellis and Mikayel Samvelyan and Matthew Jackson and Samuel Coward and Jakob Foerster},
    title = {Craftax: A Lightning-Fast Benchmark for Open-Ended Reinforcement Learning},
    booktitle = {International Conference on Machine Learning ({ICML})},
    year = {2024}
}