Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting gym version 0.21.0 #31

Merged
merged 3 commits into from
Nov 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions minihack/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import numpy as np
import pkg_resources
from nle import _pynethack, nethack
from nle.nethack.nethack import SCREEN_DESCRIPTIONS_SHAPE
from nle.nethack.nethack import SCREEN_DESCRIPTIONS_SHAPE, OBSERVATION_DESC
from nle.env.base import FULL_ACTIONS, NLE_SPACE_ITEMS
from nle.env.tasks import NetHackStaircase
from minihack.wiki import NetHackWiki
Expand Down Expand Up @@ -50,28 +50,46 @@

MINIHACK_SPACE_FUNCS = {
"glyphs_crop": lambda x, y: gym.spaces.Box(
low=0, high=nethack.MAX_GLYPH, shape=(x, y), dtype=np.uint16
low=0,
high=nethack.MAX_GLYPH,
shape=(x, y),
dtype=OBSERVATION_DESC["glyphs"]["dtype"],
),
"chars_crop": lambda x, y: gym.spaces.Box(
low=0, high=255, shape=(x, y), dtype=np.uint8
low=0,
high=255,
shape=(x, y),
dtype=OBSERVATION_DESC["chars"]["dtype"],
),
"colors_crop": lambda x, y: gym.spaces.Box(
low=0, high=15, shape=(x, y), dtype=np.uint8
low=0,
high=15,
shape=(x, y),
dtype=OBSERVATION_DESC["colors"]["dtype"],
),
"specials_crop": lambda x, y: gym.spaces.Box(
low=0, high=255, shape=(x, y), dtype=np.uint8
low=0,
high=255,
shape=(x, y),
dtype=OBSERVATION_DESC["specials"]["dtype"],
),
"tty_chars_crop": lambda x, y: gym.spaces.Box(
low=0, high=255, shape=(x, y), dtype=np.uint8
low=0,
high=255,
shape=(x, y),
dtype=OBSERVATION_DESC["tty_chars"]["dtype"],
),
"tty_colors_crop": lambda x, y: gym.spaces.Box(
low=0, high=31, shape=(x, y), dtype=np.uint8
low=0,
high=31,
shape=(x, y),
dtype=OBSERVATION_DESC["tty_colors"]["dtype"],
),
"screen_descriptions_crop": lambda x, y: gym.spaces.Box(
low=0,
high=127,
shape=(x, y, _pynethack.nethack.NLE_SCREEN_DESCRIPTION_LENGTH),
dtype=np.uint8,
dtype=OBSERVATION_DESC["screen_descriptions"]["dtype"],
),
"pixel_crop": lambda x, y: gym.spaces.Box(
low=0,
Expand Down Expand Up @@ -270,7 +288,7 @@ def _get_obs_space_dict(self, space_dict):
)
else:
if "pixel" in self._minihack_obs_keys:
d_shape = nethack.OBSERVATION_DESC["glyphs"]["shape"]
d_shape = OBSERVATION_DESC["glyphs"]["shape"]
shape = (
d_shape[0] * N_TILE_PIXEL,
d_shape[1] * N_TILE_PIXEL,
Expand Down
13 changes: 13 additions & 0 deletions minihack/envs/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,14 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import gym
from gym.envs import registration


def register(id, **kwargs):
if gym.__version__ >= "0.21":
# Starting with version 0.21, gym wraps everything by the
# OrderEnforcing wrapper by default (which isn't in gym.wrappers).
# This breaks our seed() calls and some other code. Disable.
kwargs["order_enforce"] = False

registration.register(id, **kwargs)
8 changes: 4 additions & 4 deletions minihack/envs/boxohack.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pkg_resources
from nle import nethack
from gym.envs import registration
from minihack.envs import register
from minihack import LevelGenerator, MiniHackNavigation

LEVELS_PATH = os.path.join(
Expand Down Expand Up @@ -161,15 +161,15 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)


registration.register(
register(
id="MiniHack-Boxoban-Unfiltered-v0",
entry_point="minihack.envs.boxohack:MiniHackBoxobanUnfiltered",
)
registration.register(
register(
id="MiniHack-Boxoban-Medium-v0",
entry_point="minihack.envs.boxohack:MiniHackBoxobanMedium",
)
registration.register(
register(
id="MiniHack-Boxoban-Hard-v0",
entry_point="minihack.envs.boxohack:MiniHackBoxobanHard",
)
8 changes: 4 additions & 4 deletions minihack/envs/corridor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from minihack import MiniHackNavigation
from gym.envs import registration
from minihack.envs import register
from nle import nethack

MOVE_ACTIONS = tuple(nethack.CompassDirection)
Expand Down Expand Up @@ -39,15 +39,15 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, des_file="corridor5.des", **kwargs)


registration.register(
register(
id="MiniHack-Corridor-R2-v0",
entry_point="minihack.envs.corridor:MiniHackCorridor2",
)
registration.register(
register(
id="MiniHack-Corridor-R3-v0",
entry_point="minihack.envs.corridor:MiniHackCorridor3",
)
registration.register(
register(
id="MiniHack-Corridor-R5-v0",
entry_point="minihack.envs.corridor:MiniHackCorridor5",
)
10 changes: 5 additions & 5 deletions minihack/envs/exploremaze.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from gym.envs import registration
from minihack.envs import register
from minihack import MiniHackNavigation
from minihack.envs.corridor import NAVIGATE_ACTIONS
from minihack.reward_manager import RewardManager
Expand Down Expand Up @@ -70,19 +70,19 @@ def __init__(self, *args, **kwargs):
)


registration.register(
register(
id="MiniHack-ExploreMaze-Easy-v0",
entry_point="minihack.envs.exploremaze:MiniHackExploreMazeEasy",
)
registration.register(
register(
id="MiniHack-ExploreMaze-Hard-v0",
entry_point="minihack.envs.exploremaze:MiniHackExploreMazeHard",
)
registration.register(
register(
id="MiniHack-ExploreMaze-Easy-Mapped-v0",
entry_point="minihack.envs.exploremaze:MiniHackExploreMazeEasyMapped",
)
registration.register(
register(
id="MiniHack-ExploreMaze-Hard-Mapped-v0",
entry_point="minihack.envs.exploremaze:MiniHackExploreMazeHardMapped",
)
6 changes: 3 additions & 3 deletions minihack/envs/fightcorridor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from minihack import MiniHackNavigation, LevelGenerator
from gym.envs import registration
from minihack.envs import register


class MiniHackFightCorridor(MiniHackNavigation):
Expand Down Expand Up @@ -33,12 +33,12 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, lit=False, **kwargs)


registration.register(
register(
id="MiniHack-CorridorBattle-v0",
entry_point="minihack.envs.fightcorridor:MiniHackFightCorridor",
)

registration.register(
register(
id="MiniHack-CorridorBattle-Dark-v0",
entry_point="minihack.envs.fightcorridor:MiniHackFightCorridorDark",
)
10 changes: 5 additions & 5 deletions minihack/envs/hidenseek.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from minihack import MiniHackNavigation
from gym.envs import registration
from minihack.envs import register


class MiniHackHideAndSeekMapped(MiniHackNavigation):
Expand All @@ -27,19 +27,19 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, des_file="hidenseek_big.des", **kwargs)


registration.register(
register(
id="MiniHack-HideNSeek-Mapped-v0",
entry_point="minihack.envs.hidenseek:MiniHackHideAndSeekMapped",
)
registration.register(
register(
id="MiniHack-HideNSeek-v0",
entry_point="minihack.envs.hidenseek:MiniHackHideAndSeek",
)
registration.register(
register(
id="MiniHack-HideNSeek-Lava-v0",
entry_point="minihack.envs.hidenseek:MiniHackHideAndSeekLava",
)
registration.register(
register(
id="MiniHack-HideNSeek-Big-v0",
entry_point="minihack.envs.hidenseek:MiniHackHideAndSeekBig",
)
12 changes: 6 additions & 6 deletions minihack/envs/keyroom.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from minihack import MiniHackNavigation
from minihack.level_generator import PATH_DAT_DIR
from gym.envs import registration
from minihack.envs import register
from nle.nethack import Command
from nle import nethack
import os
Expand Down Expand Up @@ -112,23 +112,23 @@ def __init__(self, *args, **kwargs):
)


registration.register(
register(
id="MiniHack-KeyRoom-Fixed-S5-v0",
entry_point="minihack.envs.keyroom:MiniHackKeyRoom5x5Fixed",
)
registration.register(
register(
id="MiniHack-KeyRoom-S5-v0",
entry_point="minihack.envs.keyroom:MiniHackKeyRoom5x5",
)
registration.register(
register(
id="MiniHack-KeyRoom-S15-v0",
entry_point="minihack.envs.keyroom:MiniHackKeyRoom15x15",
)
registration.register(
register(
id="MiniHack-KeyRoom-Dark-S5-v0",
entry_point="minihack.envs.keyroom:MiniHackKeyRoom5x5Dark",
)
registration.register(
register(
id="MiniHack-KeyRoom-Dark-S15-v0",
entry_point="minihack.envs.keyroom:MiniHackKeyRoom15x15Dark",
)
6 changes: 3 additions & 3 deletions minihack/envs/lab.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from minihack import MiniHackNavigation, LevelGenerator
from gym.envs import registration
from minihack.envs import register


class MiniHackLabyrinth(MiniHackNavigation):
Expand Down Expand Up @@ -71,12 +71,12 @@ def __init__(self, *args, **kwargs):
)


registration.register(
register(
id="MiniHack-Labyrinth-Big-v0",
entry_point="minihack.envs.lab:MiniHackLabyrinth",
)

registration.register(
register(
id="MiniHack-Labyrinth-Small-v0",
entry_point="minihack.envs.lab:MiniHackLabyrinthSmall",
)
14 changes: 7 additions & 7 deletions minihack/envs/mazewalk.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from minihack import MiniHackNavigation
from minihack.level_generator import LevelGenerator
from gym.envs import registration
from minihack.envs import register

DUNGEON_SHAPE = (76, 21)

Expand Down Expand Up @@ -63,27 +63,27 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, w=45, h=19, premapped=True, **kwargs)


registration.register(
register(
id="MiniHack-MazeWalk-9x9-v0",
entry_point="minihack.envs.mazewalk:MiniHackMazeWalk9x9",
)
registration.register(
register(
id="MiniHack-MazeWalk-Mapped-9x9-v0",
entry_point="minihack.envs.mazewalk:MiniHackMazeWalk9x9Premapped",
)
registration.register(
register(
id="MiniHack-MazeWalk-15x15-v0",
entry_point="minihack.envs.mazewalk:MiniHackMazeWalk15x15",
)
registration.register(
register(
id="MiniHack-MazeWalk-Mapped-15x15-v0",
entry_point="minihack.envs.mazewalk:MiniHackMazeWalk15x15Premapped",
)
registration.register(
register(
id="MiniHack-MazeWalk-45x19-v0",
entry_point="minihack.envs.mazewalk:MiniHackMazeWalk45x19",
)
registration.register(
register(
id="MiniHack-MazeWalk-Mapped-45x19-v0",
entry_point="minihack.envs.mazewalk:MiniHackMazeWalk45x19Premapped",
)
8 changes: 4 additions & 4 deletions minihack/envs/memento.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Copyright (c) Facebook, Inc. and its affiliates.
from gym.envs import registration
from minihack.envs import register
from minihack import MiniHackNavigation, RewardManager


Expand Down Expand Up @@ -41,18 +41,18 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, des_file="memento_hard.des", **kwargs)


registration.register(
register(
id="MiniHack-Memento-Short-F2-v0",
entry_point="minihack.envs.memento:MiniHackMementoShortF2",
)


registration.register(
register(
id="MiniHack-Memento-F2-v0",
entry_point="minihack.envs.memento:MiniHackMementoF2",
)

registration.register(
register(
id="MiniHack-Memento-F4-v0",
entry_point="minihack.envs.memento:MiniHackMementoF4",
)
Loading