-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add score masking to seven atari environments #62
Changes from 3 commits
6956efd
a97839f
b2f7d96
31af01a
aef96e5
013d942
6cb74b4
15f8569
e60925e
f480835
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,15 +1,48 @@ | ||
"""Adaptation of Atari environments for specification learning algorithms.""" | ||
|
||
from typing import Iterable | ||
from typing import Dict, Iterable, List, Optional, Tuple | ||
|
||
import gym | ||
|
||
from seals.util import AutoResetWrapper, get_gym_max_episode_steps | ||
from seals.util import AutoResetWrapper, MaskScoreWrapper, get_gym_max_episode_steps | ||
|
||
SCORE_REGIONS: Dict[str, List[Dict[str, Tuple[int, int]]]] = { | ||
"BeamRider": [ | ||
dict(x=(5, 20), y=(45, 120)), | ||
dict(x=(28, 40), y=(15, 40)), | ||
], | ||
"Breakout": [dict(x=(0, 16), y=(35, 80))], | ||
"Enduro": [ | ||
dict(x=(163, 173), y=(55, 110)), | ||
dict(x=(177, 188), y=(68, 107)), | ||
], | ||
"Pong": [dict(x=(0, 24), y=(0, 160))], | ||
"Qbert": [dict(x=(6, 15), y=(33, 71))], | ||
"Seaquest": [dict(x=(7, 19), y=(80, 110))], | ||
"SpaceInvaders": [dict(x=(10, 20), y=(0, 160))], | ||
} | ||
|
||
def fixed_length_atari(atari_env_id: str) -> gym.Env: | ||
"""Fixed-length variant of a given Atari environment.""" | ||
return AutoResetWrapper(gym.make(atari_env_id)) | ||
|
||
def _get_score_region(atari_env_id: str) -> Optional[List[Dict[str, Tuple[int, int]]]]: | ||
basename = atari_env_id.split("/")[-1].split("-")[0] | ||
basename = basename.replace("NoFrameskip", "") | ||
return SCORE_REGIONS.get(basename) | ||
|
||
|
||
def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env: | ||
"""Fixed-length, optionally masked-score variant of a given Atari environment.""" | ||
env = AutoResetWrapper(gym.make(atari_env_id)) | ||
|
||
if masked: | ||
score_region = _get_score_region(atari_env_id) | ||
if score_region is None: | ||
raise ValueError( | ||
"Requested environment does not yet support masking. " | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice, thanks for adding the informative error message :) |
||
"See https://github.com/HumanCompatibleAI/seals/issues/61.", | ||
) | ||
env = MaskScoreWrapper(env, score_region) | ||
|
||
return env | ||
|
||
|
||
def _not_ram_or_det(env_id: str) -> bool: | ||
|
@@ -37,20 +70,31 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: | |
) | ||
|
||
|
||
def _seals_name(gym_spec: gym.envs.registration.EnvSpec) -> str: | ||
def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str: | ||
"""Makes a Gym ID for an Atari environment in the seals namespace.""" | ||
slash_separated = gym_spec.id.split("/") | ||
return "seals/" + slash_separated[-1] | ||
name = "seals/" + slash_separated[-1] | ||
|
||
if not masked: | ||
last_hyphen_idx = name.rfind("-v") | ||
name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:] | ||
return name | ||
|
||
|
||
def register_atari_envs( | ||
gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec], | ||
) -> None: | ||
"""Register wrapped gym Atari environments.""" | ||
for gym_spec in gym_atari_env_specs: | ||
"""Register masked and unmasked wrapped gym Atari environments.""" | ||
|
||
def register_gym(masked): | ||
gym.register( | ||
id=_seals_name(gym_spec), | ||
entry_point="seals.atari:fixed_length_atari", | ||
id=_seals_name(gym_spec, masked=masked), | ||
entry_point="seals.atari:make_atari_env", | ||
max_episode_steps=get_gym_max_episode_steps(gym_spec.id), | ||
kwargs=dict(atari_env_id=gym_spec.id), | ||
kwargs=dict(atari_env_id=gym_spec.id, masked=masked), | ||
) | ||
|
||
for gym_spec in gym_atari_env_specs: | ||
register_gym(masked=False) | ||
if _get_score_region(gym_spec.id) is not None: | ||
AdamGleave marked this conversation as resolved.
Show resolved
Hide resolved
|
||
register_gym(masked=True) |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,6 +1,6 @@ | ||||||
"""Miscellaneous utilities.""" | ||||||
|
||||||
from typing import Optional, Tuple | ||||||
from typing import Dict, List, Optional, Sequence, Tuple, Union | ||||||
|
||||||
import gym | ||||||
import numpy as np | ||||||
|
@@ -23,6 +23,56 @@ def step(self, action): | |||||
return obs, rew, False, info | ||||||
|
||||||
|
||||||
class MaskScoreWrapper(gym.Wrapper): | ||||||
Rocamonde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
"""Mask a list of box-shaped regions in the observation to hide reward info. | ||||||
|
||||||
Intended for environments whose observations are raw pixels (like atari | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
environments). Used to mask regions of the observation that include information | ||||||
that could be used to infer the reward, like the game score or enemy ship count. | ||||||
""" | ||||||
|
||||||
def __init__( | ||||||
self, | ||||||
env: gym.Env, | ||||||
score_regions: List[Dict[str, Tuple[int, int]]], | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (If you did define a type alias this file would be the natural place to do it.) |
||||||
fill_value: Union[float, Sequence[float]] = 0, | ||||||
): | ||||||
"""Builds MaskScoreWrapper. | ||||||
|
||||||
Args: | ||||||
env: The environment to wrap. | ||||||
score_regions: A list of box-shaped regions to mask, each denoted by | ||||||
a dictionary `{"x": (x0, x1), "y": (y0, y1)}`, where `x0 < x1` | ||||||
and `y0 < y1`. | ||||||
fill_value: The fill_value for the masked region. By default is black. | ||||||
Can support RGB colors by being a sequence of values [r, g, b]. | ||||||
|
||||||
Raises: | ||||||
ValueError: If a score region does not conform to the spec. | ||||||
""" | ||||||
super().__init__(env) | ||||||
self.fill_value = np.array(fill_value, env.observation_space.dtype) | ||||||
Rocamonde marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||
|
||||||
self.mask = np.ones(env.observation_space.shape, dtype=bool) | ||||||
for r in score_regions: | ||||||
if r["x"][0] >= r["x"][1] or r["y"][0] >= r["y"][1]: | ||||||
raise ValueError('Invalid region: "x" and "y" must be increasing.') | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice input validation! |
||||||
self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0 | ||||||
|
||||||
def _mask_obs(self, obs): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! Thanks for adding this. Code looks cleaner now IMO. |
||||||
return np.where(self.mask, obs, self.fill_value) | ||||||
|
||||||
def step(self, action): | ||||||
"""Returns (obs, rew, done, info) with masked obs.""" | ||||||
obs, rew, done, info = self.env.step(action) | ||||||
return self._mask_obs(obs), rew, done, info | ||||||
|
||||||
def reset(self, **kwargs): | ||||||
"""Returns masked reset observation.""" | ||||||
obs = self.env.reset(**kwargs) | ||||||
return self._mask_obs(obs) | ||||||
|
||||||
|
||||||
class ObsCastWrapper(gym.Wrapper): | ||||||
"""Cast observations to specified dtype. | ||||||
|
||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,20 @@ | |
|
||
import collections | ||
|
||
import gym | ||
import numpy as np | ||
import pytest | ||
|
||
from seals import util | ||
from seals import GYM_ATARI_ENV_SPECS, util | ||
|
||
|
||
def test_mask_score_wrapper_enforces_spec(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be nice to add a test that actually checks I won't insist on it though, the |
||
"""Test that MaskScoreWrapper enforces the spec.""" | ||
atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id) | ||
with pytest.raises(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you specify the error that is raised and use the |
||
util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))]) | ||
with pytest.raises(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above (error type + error message match) |
||
util.MaskScoreWrapper(atari_env, [dict(x=(1, 0), y=(0, 1))]) | ||
|
||
|
||
def test_sample_distribution(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You use
List[Dict[str, Tuple[int, int]]]
in three places in your code -- consider defining it as a type? Like:I'd also consider using a named tuple instead of
dict
to enforce thatx
andy
are both present.