From 6956efd1ae7d1bd0c6694e42059951536c44108d Mon Sep 17 00:00:00 2001 From: Stewy Slocum Date: Fri, 7 Oct 2022 17:06:51 -0400 Subject: [PATCH 1/9] Add score masking to seven atari environments --- src/seals/atari.py | 43 +++++++++++++++++++++++++++++++++++++------ src/seals/util.py | 43 ++++++++++++++++++++++++++++++++++++++++++- tests/test_envs.py | 13 ++++++++++++- 3 files changed, 91 insertions(+), 8 deletions(-) diff --git a/src/seals/atari.py b/src/seals/atari.py index ff01024..37bf004 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -1,15 +1,44 @@ """Adaptation of Atari environments for specification learning algorithms.""" -from typing import Iterable +from typing import Dict, Iterable, List, Optional import gym -from seals.util import AutoResetWrapper, get_gym_max_episode_steps +from seals.util import AutoResetWrapper, MaskScoreWrapper, get_gym_max_episode_steps +SCORE_REGIONS: Dict[str, List[Dict[str, int]]] = { + "BeamRider": [ + dict(x0=5, x1=20, y0=45, y1=120), + dict(x0=28, x1=40, y0=15, y1=40), + ], + "Breakout": [dict(x0=0, x1=16, y0=35, y1=80)], + "Enduro": [ + dict(x0=163, x1=173, y0=55, y1=110), + dict(x0=177, x1=188, y0=68, y1=107), + ], + "Pong": [dict(x0=0, x1=24, y0=0, y1=160)], + "Qbert": [dict(x0=6, x1=15, y0=33, y1=71)], + "Seaquest": [dict(x0=7, x1=19, y0=80, y1=110)], + "SpaceInvaders": [dict(x0=10, x1=20, y0=0, y1=160)], +} -def fixed_length_atari(atari_env_id: str) -> gym.Env: - """Fixed-length variant of a given Atari environment.""" - return AutoResetWrapper(gym.make(atari_env_id)) + +def _get_score_region(atari_env_id: str) -> Optional[List[Dict[str, int]]]: + basename = atari_env_id.split("/")[-1].split("-")[0] + return SCORE_REGIONS.get(basename) + + +def make_atari_env(atari_env_id: str) -> gym.Env: + """Fixed-length, masked-score variant of a given Atari environment.""" + score_region = _get_score_region(atari_env_id) + if score_region is None: + raise ValueError( + "Requested environment not supported. " + + "See https://github.com/HumanCompatibleAI/seals/issues/61.", + ) + + env = MaskScoreWrapper(gym.make(atari_env_id), score_region) + return AutoResetWrapper(env) def _not_ram_or_det(env_id: str) -> bool: @@ -30,10 +59,12 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: is_atari = gym_spec.entry_point == "gym.envs.atari:AtariEnv" v5_and_plain = gym_spec.id.endswith("-v5") and not ("NoFrameskip" in gym_spec.id) v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id + score_regions_available = _get_score_region(gym_spec.id) is not None return ( is_atari and _not_ram_or_det(gym_spec.id) and (v5_and_plain or v4_and_no_frameskip) + and score_regions_available ) @@ -50,7 +81,7 @@ def register_atari_envs( for gym_spec in gym_atari_env_specs: gym.register( id=_seals_name(gym_spec), - entry_point="seals.atari:fixed_length_atari", + entry_point="seals.atari:make_atari_env", max_episode_steps=get_gym_max_episode_steps(gym_spec.id), kwargs=dict(atari_env_id=gym_spec.id), ) diff --git a/src/seals/util.py b/src/seals/util.py index e80ceca..5a2d226 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -1,6 +1,6 @@ """Miscellaneous utilities.""" -from typing import Optional, Tuple +from typing import Dict, List, Optional, Sequence, Tuple, Union import gym import numpy as np @@ -23,6 +23,47 @@ def step(self, action): return obs, rew, False, info +class MaskScoreWrapper(gym.Wrapper): + """Mask a list of box-shaped regions in the observation to hide reward info. + + Intended for environments whose observations are raw pixels (like atari + environments). Used to mask regions of the observation that include information + that could be used to infer the reward, like the game score or enemy ship count. + """ + + def __init__( + self, + env: gym.Env, + score_regions: List[Dict[str, int]], + fill_value: Union[float, Sequence[float]] = 0, + ): + """Builds MaskScoreWrapper. + + Args: + env: The environment to wrap. + score_regions: A list of box-shaped regions to mask, each denoted by + a dictionary `{"x0": x0, "x1": x1, "y0": y0, "y1": y1}`. + fill_value: The fill_value for the masked region. By default is black. + Can support RGB colors by being a sequence of values [r, g, b]. + """ + super().__init__(env) + self.fill_value = np.array(fill_value, env.observation_space.dtype) + + self.mask = np.ones(env.observation_space.shape, dtype=bool) + for r in score_regions: + self.mask[r["x0"] : r["x1"], r["y0"] : r["y1"]] = 0 + + def step(self, action): + """Returns (obs, rew, done, info) with masked obs.""" + obs, rew, done, info = self.env.step(action) + return np.where(self.mask, obs, self.fill_value), rew, done, info + + def reset(self, **kwargs): + """Returns masked reset observation.""" + obs = self.env.reset(**kwargs) + return np.where(self.mask, obs, self.fill_value) + + class ObsCastWrapper(gym.Wrapper): """Cast observations to specified dtype. diff --git a/tests/test_envs.py b/tests/test_envs.py index c33e176..f729195 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,7 +7,7 @@ import pytest import seals # noqa: F401 required for env registration -from seals.atari import _seals_name +from seals.atari import _get_score_region, _seals_name from seals.testing import envs ENV_NAMES: List[str] = [ @@ -56,6 +56,17 @@ def test_atari_space_invaders(): assert len(space_invader_environments) > 0 +def test_no_atari_unmasked(): + """Tests that we only load Atari envs with score masking implemented.""" + non_masked_environments = list( + filter( + lambda name: _get_score_region(name) is None, + ATARI_ENVS, + ), + ) + assert len(non_masked_environments) == 0 + + @pytest.mark.parametrize("env_name", ENV_NAMES) class TestEnvs: """Battery of simple tests for environments.""" From a97839fe47c8a1ad3421b02e14f7679bfe8c96db Mon Sep 17 00:00:00 2001 From: Stewy Slocum Date: Mon, 10 Oct 2022 18:19:14 -0400 Subject: [PATCH 2/9] Add option of masked or unmasked atari envs --- src/seals/atari.py | 72 +++++++++++++++++++++++++++------------------- src/seals/util.py | 15 ++++++---- tests/test_envs.py | 32 ++++++++++++++------- 3 files changed, 74 insertions(+), 45 deletions(-) diff --git a/src/seals/atari.py b/src/seals/atari.py index 37bf004..739e5be 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -1,44 +1,48 @@ """Adaptation of Atari environments for specification learning algorithms.""" -from typing import Dict, Iterable, List, Optional +from typing import Dict, Iterable, List, Optional, Tuple import gym from seals.util import AutoResetWrapper, MaskScoreWrapper, get_gym_max_episode_steps -SCORE_REGIONS: Dict[str, List[Dict[str, int]]] = { +SCORE_REGIONS: Dict[str, List[Dict[str, Tuple[int, int]]]] = { "BeamRider": [ - dict(x0=5, x1=20, y0=45, y1=120), - dict(x0=28, x1=40, y0=15, y1=40), + dict(x=(5, 20), y=(45, 120)), + dict(x=(28, 40), y=(15, 40)), ], - "Breakout": [dict(x0=0, x1=16, y0=35, y1=80)], + "Breakout": [dict(x=(0, 16), y=(35, 80))], "Enduro": [ - dict(x0=163, x1=173, y0=55, y1=110), - dict(x0=177, x1=188, y0=68, y1=107), + dict(x=(163, 173), y=(55, 110)), + dict(x=(177, 188), y=(68, 107)), ], - "Pong": [dict(x0=0, x1=24, y0=0, y1=160)], - "Qbert": [dict(x0=6, x1=15, y0=33, y1=71)], - "Seaquest": [dict(x0=7, x1=19, y0=80, y1=110)], - "SpaceInvaders": [dict(x0=10, x1=20, y0=0, y1=160)], + "Pong": [dict(x=(0, 24), y=(0, 160))], + "Qbert": [dict(x=(6, 15), y=(33, 71))], + "Seaquest": [dict(x=(7, 19), y=(80, 110))], + "SpaceInvaders": [dict(x=(10, 20), y=(0, 160))], } -def _get_score_region(atari_env_id: str) -> Optional[List[Dict[str, int]]]: +def _get_score_region(atari_env_id: str) -> Optional[List[Dict[str, Tuple[int, int]]]]: basename = atari_env_id.split("/")[-1].split("-")[0] + basename = basename.replace("NoFrameskip", "") return SCORE_REGIONS.get(basename) -def make_atari_env(atari_env_id: str) -> gym.Env: - """Fixed-length, masked-score variant of a given Atari environment.""" - score_region = _get_score_region(atari_env_id) - if score_region is None: - raise ValueError( - "Requested environment not supported. " - + "See https://github.com/HumanCompatibleAI/seals/issues/61.", - ) +def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env: + """Fixed-length, optionally masked-score variant of a given Atari environment.""" + env = AutoResetWrapper(gym.make(atari_env_id)) + + if masked: + score_region = _get_score_region(atari_env_id) + if score_region is None: + raise ValueError( + "Requested environment does not yet support masking. " + + "See https://github.com/HumanCompatibleAI/seals/issues/61.", + ) + env = MaskScoreWrapper(env, score_region) - env = MaskScoreWrapper(gym.make(atari_env_id), score_region) - return AutoResetWrapper(env) + return env def _not_ram_or_det(env_id: str) -> bool: @@ -59,29 +63,39 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: is_atari = gym_spec.entry_point == "gym.envs.atari:AtariEnv" v5_and_plain = gym_spec.id.endswith("-v5") and not ("NoFrameskip" in gym_spec.id) v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id - score_regions_available = _get_score_region(gym_spec.id) is not None return ( is_atari and _not_ram_or_det(gym_spec.id) and (v5_and_plain or v4_and_no_frameskip) - and score_regions_available ) -def _seals_name(gym_spec: gym.envs.registration.EnvSpec) -> str: +def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str: """Makes a Gym ID for an Atari environment in the seals namespace.""" slash_separated = gym_spec.id.split("/") - return "seals/" + slash_separated[-1] + name = "seals/" + slash_separated[-1] + + if not masked: + last_hyphen_idx = name.rfind("-") + name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:] + return name def register_atari_envs( gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec], ) -> None: - """Register wrapped gym Atari environments.""" + """Register masked and unmasked wrapped gym Atari environments.""" for gym_spec in gym_atari_env_specs: gym.register( - id=_seals_name(gym_spec), + id=_seals_name(gym_spec, masked=False), entry_point="seals.atari:make_atari_env", max_episode_steps=get_gym_max_episode_steps(gym_spec.id), - kwargs=dict(atari_env_id=gym_spec.id), + kwargs=dict(atari_env_id=gym_spec.id, masked=False), ) + if _get_score_region(gym_spec.id) is not None: + gym.register( + id=_seals_name(gym_spec, masked=True), + entry_point="seals.atari:make_atari_env", + max_episode_steps=get_gym_max_episode_steps(gym_spec.id), + kwargs=dict(atari_env_id=gym_spec.id, masked=True), + ) diff --git a/src/seals/util.py b/src/seals/util.py index 5a2d226..eba279c 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -34,7 +34,7 @@ class MaskScoreWrapper(gym.Wrapper): def __init__( self, env: gym.Env, - score_regions: List[Dict[str, int]], + score_regions: List[Dict[str, Tuple[int, int]]], fill_value: Union[float, Sequence[float]] = 0, ): """Builds MaskScoreWrapper. @@ -42,7 +42,8 @@ def __init__( Args: env: The environment to wrap. score_regions: A list of box-shaped regions to mask, each denoted by - a dictionary `{"x0": x0, "x1": x1, "y0": y0, "y1": y1}`. + a dictionary `{"x": (x0, x1), "y": (y0, y1)}`, where `x0 < x1` + and `y0 < y1`. fill_value: The fill_value for the masked region. By default is black. Can support RGB colors by being a sequence of values [r, g, b]. """ @@ -51,17 +52,21 @@ def __init__( self.mask = np.ones(env.observation_space.shape, dtype=bool) for r in score_regions: - self.mask[r["x0"] : r["x1"], r["y0"] : r["y1"]] = 0 + assert r["x"][0] < r["x"][1] and r["y"][0] < r["y"][1] + self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0 + + def _mask_obs(self, obs): + return np.where(self.mask, obs, self.fill_value) def step(self, action): """Returns (obs, rew, done, info) with masked obs.""" obs, rew, done, info = self.env.step(action) - return np.where(self.mask, obs, self.fill_value), rew, done, info + return self._mask_obs(obs), rew, done, info def reset(self, **kwargs): """Returns masked reset observation.""" obs = self.env.reset(**kwargs) - return np.where(self.mask, obs, self.fill_value) + return self._mask_obs(obs) class ObsCastWrapper(gym.Wrapper): diff --git a/tests/test_envs.py b/tests/test_envs.py index f729195..cbcf423 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -26,7 +26,11 @@ ] ATARI_ENVS: List[str] = [ - _seals_name(gym_spec) for gym_spec in seals.GYM_ATARI_ENV_SPECS + _seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS +] + [ + _seals_name(gym_spec, masked=True) + for gym_spec in seals.GYM_ATARI_ENV_SPECS + if _get_score_region(gym_spec.id) is not None ] ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS)) @@ -46,25 +50,31 @@ def test_some_atari_envs(): def test_atari_space_invaders(): - """Tests if there's an Atari environment called space invaders.""" - space_invader_environments = list( + """Tests for masked and unmasked Atari space invaders environments.""" + masked_space_invader_environments = list( filter( - lambda name: "SpaceInvaders" in name, + lambda name: "SpaceInvaders" in name and "Unmasked" not in name, ATARI_ENVS, ), ) - assert len(space_invader_environments) > 0 + assert len(masked_space_invader_environments) > 0 - -def test_no_atari_unmasked(): - """Tests that we only load Atari envs with score masking implemented.""" - non_masked_environments = list( + unmasked_space_invader_environments = list( filter( - lambda name: _get_score_region(name) is None, + lambda name: "SpaceInvaders" in name and "Unmasked" in name, ATARI_ENVS, ), ) - assert len(non_masked_environments) == 0 + assert len(unmasked_space_invader_environments) > 0 + + +def test_atari_unmasked_env_naming(): + """Tests that all unmasked Atari envs have the appropriate name qualifier.""" + noncompliant_envs = [ + (_get_score_region(name) is None and "Unmasked" not in name) + for name in ATARI_ENVS + ] + assert len(noncompliant_envs) == 0 @pytest.mark.parametrize("env_name", ENV_NAMES) From b2f7d96a033a4fb739a12fd87b0231274b08927c Mon Sep 17 00:00:00 2001 From: Stewy Slocum Date: Sun, 16 Oct 2022 21:16:08 -0400 Subject: [PATCH 3/9] add final tests and cosmetic changes for masked score atari environments --- src/seals/atari.py | 21 ++++++++++----------- src/seals/util.py | 6 +++++- tests/test_envs.py | 18 +++++++++++++++--- tests/test_util.py | 13 ++++++++++++- 4 files changed, 42 insertions(+), 16 deletions(-) diff --git a/src/seals/atari.py b/src/seals/atari.py index 739e5be..190fc2e 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -38,7 +38,7 @@ def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env: if score_region is None: raise ValueError( "Requested environment does not yet support masking. " - + "See https://github.com/HumanCompatibleAI/seals/issues/61.", + "See https://github.com/HumanCompatibleAI/seals/issues/61.", ) env = MaskScoreWrapper(env, score_region) @@ -76,7 +76,7 @@ def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str: name = "seals/" + slash_separated[-1] if not masked: - last_hyphen_idx = name.rfind("-") + last_hyphen_idx = name.rfind("-v") name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:] return name @@ -85,17 +85,16 @@ def register_atari_envs( gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec], ) -> None: """Register masked and unmasked wrapped gym Atari environments.""" - for gym_spec in gym_atari_env_specs: + + def register_gym(masked): gym.register( - id=_seals_name(gym_spec, masked=False), + id=_seals_name(gym_spec, masked=masked), entry_point="seals.atari:make_atari_env", max_episode_steps=get_gym_max_episode_steps(gym_spec.id), - kwargs=dict(atari_env_id=gym_spec.id, masked=False), + kwargs=dict(atari_env_id=gym_spec.id, masked=masked), ) + + for gym_spec in gym_atari_env_specs: + register_gym(masked=False) if _get_score_region(gym_spec.id) is not None: - gym.register( - id=_seals_name(gym_spec, masked=True), - entry_point="seals.atari:make_atari_env", - max_episode_steps=get_gym_max_episode_steps(gym_spec.id), - kwargs=dict(atari_env_id=gym_spec.id, masked=True), - ) + register_gym(masked=True) diff --git a/src/seals/util.py b/src/seals/util.py index eba279c..d0a3b46 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -46,13 +46,17 @@ def __init__( and `y0 < y1`. fill_value: The fill_value for the masked region. By default is black. Can support RGB colors by being a sequence of values [r, g, b]. + + Raises: + ValueError: If a score region does not conform to the spec. """ super().__init__(env) self.fill_value = np.array(fill_value, env.observation_space.dtype) self.mask = np.ones(env.observation_space.shape, dtype=bool) for r in score_regions: - assert r["x"][0] < r["x"][1] and r["y"][0] < r["y"][1] + if r["x"][0] >= r["x"][1] or r["y"][0] >= r["y"][1]: + raise ValueError('Invalid region: "x" and "y" must be increasing.') self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0 def _mask_obs(self, obs): diff --git a/tests/test_envs.py b/tests/test_envs.py index cbcf423..28e77cc 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,7 +7,7 @@ import pytest import seals # noqa: F401 required for env registration -from seals.atari import _get_score_region, _seals_name +from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name from seals.testing import envs ENV_NAMES: List[str] = [ @@ -25,13 +25,15 @@ "seals/InitShiftTest-v0", ] -ATARI_ENVS: List[str] = [ +UNMASKED_ATARI_ENVS: List[str] = [ _seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS -] + [ +] +MASKED_ATARI_ENVS: List[str] = [ _seals_name(gym_spec, masked=True) for gym_spec in seals.GYM_ATARI_ENV_SPECS if _get_score_region(gym_spec.id) is not None ] +ATARI_ENVS = UNMASKED_ATARI_ENVS + MASKED_ATARI_ENVS ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS)) ATARI_NO_FRAMESKIP_ENVS: List[str] = list( @@ -77,6 +79,16 @@ def test_atari_unmasked_env_naming(): assert len(noncompliant_envs) == 0 +def test_atari_masks_satisfy_spec(): + """Tests that all Atari masks satisfy the spec.""" + masks_satisfy_spec = [ + mask["x"][0] < mask["x"][1] and mask["y"][0] < mask["y"][1] + for env_regions in SCORE_REGIONS.values() + for mask in env_regions + ] + assert all(masks_satisfy_spec) + + @pytest.mark.parametrize("env_name", ENV_NAMES) class TestEnvs: """Battery of simple tests for environments.""" diff --git a/tests/test_util.py b/tests/test_util.py index 7ff126f..8b17695 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,9 +2,20 @@ import collections +import gym import numpy as np +import pytest -from seals import util +from seals import GYM_ATARI_ENV_SPECS, util + + +def test_mask_score_wrapper_enforces_spec(): + """Test that MaskScoreWrapper enforces the spec.""" + atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id) + with pytest.raises(): + util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))]) + with pytest.raises(): + util.MaskScoreWrapper(atari_env, [dict(x=(1, 0), y=(0, 1))]) def test_sample_distribution(): From 31af01a8fa6e4dab729dc8b42c78fb2522d59e5d Mon Sep 17 00:00:00 2001 From: Stewy Slocum Date: Thu, 27 Oct 2022 23:44:27 -0400 Subject: [PATCH 4/9] Add pytest match option to MaskScoreWrapper tests --- tests/test_util.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 8b17695..f7071a7 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -12,9 +12,10 @@ def test_mask_score_wrapper_enforces_spec(): """Test that MaskScoreWrapper enforces the spec.""" atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id) - with pytest.raises(): + desired_error_message = 'Invalid region: "x" and "y" must be increasing.' + with pytest.raises(ValueError, match=desired_error_message): util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))]) - with pytest.raises(): + with pytest.raises(ValueError, match=desired_error_message): util.MaskScoreWrapper(atari_env, [dict(x=(1, 0), y=(0, 1))]) From 013d9425cbba4563ebf4bbee0cd23cac67312c8c Mon Sep 17 00:00:00 2001 From: Stewy Slocum Date: Fri, 28 Oct 2022 16:53:41 -0400 Subject: [PATCH 5/9] Fix test error due to incorrectly named unmasked envs --- tests/test_envs.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index 28e77cc..af9ecf7 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -72,10 +72,12 @@ def test_atari_space_invaders(): def test_atari_unmasked_env_naming(): """Tests that all unmasked Atari envs have the appropriate name qualifier.""" - noncompliant_envs = [ - (_get_score_region(name) is None and "Unmasked" not in name) - for name in ATARI_ENVS - ] + noncompliant_envs = list( + filter( + lambda name: _get_score_region(name) is None and "Unmasked" not in name, + ATARI_ENVS, + ) + ) assert len(noncompliant_envs) == 0 @@ -103,11 +105,11 @@ def test_seed(self, env: gym.Env, env_name: str): if env_name in ATARI_ENVS: # these environments take a while for their non-determinism to show. slow_random_envs = [ - "seals/Bowling-v5", - "seals/Frogger-v5", - "seals/KingKong-v5", - "seals/Koolaid-v5", - "seals/NameThisGame-v5", + "seals/Bowling-Unmasked-v5", + "seals/Frogger-Unmasked-v5", + "seals/KingKong-Unmasked-v5", + "seals/Koolaid-Unmasked-v5", + "seals/NameThisGame-Unmasked-v5", ] rollout_len = 100 if env_name not in slow_random_envs else 400 num_seeds = 2 if env_name in ATARI_NO_FRAMESKIP_ENVS else 10 From 6cb74b4febc8ff5a038647e0f26ebc02297e13ff Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Fri, 28 Oct 2022 18:04:49 -0700 Subject: [PATCH 6/9] Fix lint in test_envs.py --- tests/test_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index af9ecf7..9b74d44 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -76,7 +76,7 @@ def test_atari_unmasked_env_naming(): filter( lambda name: _get_score_region(name) is None and "Unmasked" not in name, ATARI_ENVS, - ) + ), ) assert len(noncompliant_envs) == 0 From 15f85694f3fb8d1faa76c9f5afa649105eafa27c Mon Sep 17 00:00:00 2001 From: Stewy Slocum Date: Mon, 31 Oct 2022 14:27:37 -0400 Subject: [PATCH 7/9] Add test to make_atari_env to check for exception thrown on unavailable masked env --- tests/test_envs.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index 9b74d44..23c9fdc 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,7 +7,7 @@ import pytest import seals # noqa: F401 required for env registration -from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name +from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env from seals.testing import envs ENV_NAMES: List[str] = [ @@ -81,6 +81,16 @@ def test_atari_unmasked_env_naming(): assert len(noncompliant_envs) == 0 +def test_make_unsupported_masked_atari_env_throws_error(): + """Tests that making an unsupported masked Atari env throws an error.""" + match_str = ( + "Requested environment does not yet support masking. " + "See https://github.com/HumanCompatibleAI/seals/issues/61." + ) + with pytest.raises(ValueError, match=match_str): + make_atari_env("ALE/Bowling-v5", masked=True) + + def test_atari_masks_satisfy_spec(): """Tests that all Atari masks satisfy the spec.""" masks_satisfy_spec = [ From e60925e39226bf64c571faf8cb223e0ea7445d55 Mon Sep 17 00:00:00 2001 From: Stewy Slocum Date: Tue, 15 Nov 2022 17:35:47 -0500 Subject: [PATCH 8/9] Add type alias and dataclass for masked score region objects --- src/seals/atari.py | 32 +++++++++++++++++++------------- src/seals/util.py | 22 +++++++++++++++++----- tests/test_envs.py | 2 +- tests/test_util.py | 4 ++-- 4 files changed, 39 insertions(+), 21 deletions(-) diff --git a/src/seals/atari.py b/src/seals/atari.py index 190fc2e..a2d896c 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -1,29 +1,35 @@ """Adaptation of Atari environments for specification learning algorithms.""" -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Dict, Iterable, Optional import gym -from seals.util import AutoResetWrapper, MaskScoreWrapper, get_gym_max_episode_steps +from seals.util import ( + AutoResetWrapper, + BoxRegion, + MaskedRegionSpecifier, + MaskScoreWrapper, + get_gym_max_episode_steps, +) -SCORE_REGIONS: Dict[str, List[Dict[str, Tuple[int, int]]]] = { +SCORE_REGIONS: Dict[str, MaskedRegionSpecifier] = { "BeamRider": [ - dict(x=(5, 20), y=(45, 120)), - dict(x=(28, 40), y=(15, 40)), + BoxRegion(x=(5, 20), y=(45, 120)), + BoxRegion(x=(28, 40), y=(15, 40)), ], - "Breakout": [dict(x=(0, 16), y=(35, 80))], + "Breakout": [BoxRegion(x=(0, 16), y=(35, 80))], "Enduro": [ - dict(x=(163, 173), y=(55, 110)), - dict(x=(177, 188), y=(68, 107)), + BoxRegion(x=(163, 173), y=(55, 110)), + BoxRegion(x=(177, 188), y=(68, 107)), ], - "Pong": [dict(x=(0, 24), y=(0, 160))], - "Qbert": [dict(x=(6, 15), y=(33, 71))], - "Seaquest": [dict(x=(7, 19), y=(80, 110))], - "SpaceInvaders": [dict(x=(10, 20), y=(0, 160))], + "Pong": [BoxRegion(x=(0, 24), y=(0, 160))], + "Qbert": [BoxRegion(x=(6, 15), y=(33, 71))], + "Seaquest": [BoxRegion(x=(7, 19), y=(80, 110))], + "SpaceInvaders": [BoxRegion(x=(10, 20), y=(0, 160))], } -def _get_score_region(atari_env_id: str) -> Optional[List[Dict[str, Tuple[int, int]]]]: +def _get_score_region(atari_env_id: str) -> Optional[MaskedRegionSpecifier]: basename = atari_env_id.split("/")[-1].split("-")[0] basename = basename.replace("NoFrameskip", "") return SCORE_REGIONS.get(basename) diff --git a/src/seals/util.py b/src/seals/util.py index d0a3b46..ac367f9 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -1,6 +1,7 @@ """Miscellaneous utilities.""" -from typing import Dict, List, Optional, Sequence, Tuple, Union +from dataclasses import dataclass +from typing import List, Optional, Sequence, Tuple, Union import gym import numpy as np @@ -23,10 +24,21 @@ def step(self, action): return obs, rew, False, info +@dataclass +class BoxRegion: + """A rectangular region dataclass used by MaskScoreWrapper.""" + + x: Tuple + y: Tuple + + +MaskedRegionSpecifier = List[BoxRegion] + + class MaskScoreWrapper(gym.Wrapper): """Mask a list of box-shaped regions in the observation to hide reward info. - Intended for environments whose observations are raw pixels (like atari + Intended for environments whose observations are raw pixels (like Atari environments). Used to mask regions of the observation that include information that could be used to infer the reward, like the game score or enemy ship count. """ @@ -34,7 +46,7 @@ class MaskScoreWrapper(gym.Wrapper): def __init__( self, env: gym.Env, - score_regions: List[Dict[str, Tuple[int, int]]], + score_regions: MaskedRegionSpecifier, fill_value: Union[float, Sequence[float]] = 0, ): """Builds MaskScoreWrapper. @@ -55,9 +67,9 @@ def __init__( self.mask = np.ones(env.observation_space.shape, dtype=bool) for r in score_regions: - if r["x"][0] >= r["x"][1] or r["y"][0] >= r["y"][1]: + if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]: raise ValueError('Invalid region: "x" and "y" must be increasing.') - self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0 + self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = 0 def _mask_obs(self, obs): return np.where(self.mask, obs, self.fill_value) diff --git a/tests/test_envs.py b/tests/test_envs.py index 23c9fdc..b2e86d4 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -94,7 +94,7 @@ def test_make_unsupported_masked_atari_env_throws_error(): def test_atari_masks_satisfy_spec(): """Tests that all Atari masks satisfy the spec.""" masks_satisfy_spec = [ - mask["x"][0] < mask["x"][1] and mask["y"][0] < mask["y"][1] + mask.x[0] < mask.x[1] and mask.y[0] < mask.y[1] for env_regions in SCORE_REGIONS.values() for mask in env_regions ] diff --git a/tests/test_util.py b/tests/test_util.py index f7071a7..5829ebb 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -14,9 +14,9 @@ def test_mask_score_wrapper_enforces_spec(): atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id) desired_error_message = 'Invalid region: "x" and "y" must be increasing.' with pytest.raises(ValueError, match=desired_error_message): - util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))]) + util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(0, 1), y=(1, 0))]) with pytest.raises(ValueError, match=desired_error_message): - util.MaskScoreWrapper(atari_env, [dict(x=(1, 0), y=(0, 1))]) + util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(1, 0), y=(0, 1))]) def test_sample_distribution(): From f48083535624b933e24d64556215bceb14804584 Mon Sep 17 00:00:00 2001 From: Adam Gleave Date: Mon, 21 Nov 2022 17:48:57 -0800 Subject: [PATCH 9/9] Pin pyglet version to workaround bug --- setup.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index eca507e..342e5e4 100644 --- a/setup.py +++ b/setup.py @@ -115,7 +115,10 @@ def get_readme() -> str: "pytest-xdist", "pytype", "stable-baselines3>=0.9.0", - "pyglet>=1.4.0", + # TODO(adam): remove pyglet pin once Gym upgraded to >0.21 + # Workaround for https://github.com/openai/gym/issues/2986 + # Discussed in https://github.com/HumanCompatibleAI/imitation/pull/603 + "pyglet==1.5.27", "setuptools_scm~=7.0.5", *ATARI_REQUIRE, ]