Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add score masking to seven atari environments #62

Merged
merged 10 commits into from
Nov 22, 2022
68 changes: 56 additions & 12 deletions src/seals/atari.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,48 @@
"""Adaptation of Atari environments for specification learning algorithms."""

from typing import Iterable
from typing import Dict, Iterable, List, Optional, Tuple

import gym

from seals.util import AutoResetWrapper, get_gym_max_episode_steps
from seals.util import AutoResetWrapper, MaskScoreWrapper, get_gym_max_episode_steps

SCORE_REGIONS: Dict[str, List[Dict[str, Tuple[int, int]]]] = {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You use List[Dict[str, Tuple[int, int]]] in three places in your code -- consider defining it as a type? Like:

MaskedRegionSpecifier = List[Dict[str, Tuple[int, int]]]

I'd also consider using a named tuple instead of dict to enforce that x and y are both present.

"BeamRider": [
dict(x=(5, 20), y=(45, 120)),
dict(x=(28, 40), y=(15, 40)),
],
"Breakout": [dict(x=(0, 16), y=(35, 80))],
"Enduro": [
dict(x=(163, 173), y=(55, 110)),
dict(x=(177, 188), y=(68, 107)),
],
"Pong": [dict(x=(0, 24), y=(0, 160))],
"Qbert": [dict(x=(6, 15), y=(33, 71))],
"Seaquest": [dict(x=(7, 19), y=(80, 110))],
"SpaceInvaders": [dict(x=(10, 20), y=(0, 160))],
}

def fixed_length_atari(atari_env_id: str) -> gym.Env:
"""Fixed-length variant of a given Atari environment."""
return AutoResetWrapper(gym.make(atari_env_id))

def _get_score_region(atari_env_id: str) -> Optional[List[Dict[str, Tuple[int, int]]]]:
basename = atari_env_id.split("/")[-1].split("-")[0]
basename = basename.replace("NoFrameskip", "")
return SCORE_REGIONS.get(basename)


def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env:
"""Fixed-length, optionally masked-score variant of a given Atari environment."""
env = AutoResetWrapper(gym.make(atari_env_id))

if masked:
score_region = _get_score_region(atari_env_id)
if score_region is None:
raise ValueError(
"Requested environment does not yet support masking. "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for adding the informative error message :)

"See https://github.com/HumanCompatibleAI/seals/issues/61.",
)
env = MaskScoreWrapper(env, score_region)

return env


def _not_ram_or_det(env_id: str) -> bool:
Expand Down Expand Up @@ -37,20 +70,31 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
)


def _seals_name(gym_spec: gym.envs.registration.EnvSpec) -> str:
def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str:
"""Makes a Gym ID for an Atari environment in the seals namespace."""
slash_separated = gym_spec.id.split("/")
return "seals/" + slash_separated[-1]
name = "seals/" + slash_separated[-1]

if not masked:
last_hyphen_idx = name.rfind("-v")
name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:]
return name


def register_atari_envs(
gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec],
) -> None:
"""Register wrapped gym Atari environments."""
for gym_spec in gym_atari_env_specs:
"""Register masked and unmasked wrapped gym Atari environments."""

def register_gym(masked):
gym.register(
id=_seals_name(gym_spec),
entry_point="seals.atari:fixed_length_atari",
id=_seals_name(gym_spec, masked=masked),
entry_point="seals.atari:make_atari_env",
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id, masked=masked),
)

for gym_spec in gym_atari_env_specs:
register_gym(masked=False)
if _get_score_region(gym_spec.id) is not None:
AdamGleave marked this conversation as resolved.
Show resolved Hide resolved
register_gym(masked=True)
52 changes: 51 additions & 1 deletion src/seals/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Miscellaneous utilities."""

from typing import Optional, Tuple
from typing import Dict, List, Optional, Sequence, Tuple, Union

import gym
import numpy as np
Expand All @@ -23,6 +23,56 @@ def step(self, action):
return obs, rew, False, info


class MaskScoreWrapper(gym.Wrapper):
Rocamonde marked this conversation as resolved.
Show resolved Hide resolved
"""Mask a list of box-shaped regions in the observation to hide reward info.

Intended for environments whose observations are raw pixels (like atari
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
Intended for environments whose observations are raw pixels (like atari
Intended for environments whose observations are raw pixels (like Atari

environments). Used to mask regions of the observation that include information
that could be used to infer the reward, like the game score or enemy ship count.
"""

def __init__(
self,
env: gym.Env,
score_regions: List[Dict[str, Tuple[int, int]]],
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(If you did define a type alias this file would be the natural place to do it.)

fill_value: Union[float, Sequence[float]] = 0,
):
"""Builds MaskScoreWrapper.

Args:
env: The environment to wrap.
score_regions: A list of box-shaped regions to mask, each denoted by
a dictionary `{"x": (x0, x1), "y": (y0, y1)}`, where `x0 < x1`
and `y0 < y1`.
fill_value: The fill_value for the masked region. By default is black.
Can support RGB colors by being a sequence of values [r, g, b].

Raises:
ValueError: If a score region does not conform to the spec.
"""
super().__init__(env)
self.fill_value = np.array(fill_value, env.observation_space.dtype)
Rocamonde marked this conversation as resolved.
Show resolved Hide resolved

self.mask = np.ones(env.observation_space.shape, dtype=bool)
for r in score_regions:
if r["x"][0] >= r["x"][1] or r["y"][0] >= r["y"][1]:
raise ValueError('Invalid region: "x" and "y" must be increasing.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice input validation!

self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0

def _mask_obs(self, obs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks for adding this. Code looks cleaner now IMO.

return np.where(self.mask, obs, self.fill_value)

def step(self, action):
"""Returns (obs, rew, done, info) with masked obs."""
obs, rew, done, info = self.env.step(action)
return self._mask_obs(obs), rew, done, info

def reset(self, **kwargs):
"""Returns masked reset observation."""
obs = self.env.reset(**kwargs)
return self._mask_obs(obs)


class ObsCastWrapper(gym.Wrapper):
"""Cast observations to specified dtype.

Expand Down
47 changes: 40 additions & 7 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import seals # noqa: F401 required for env registration
from seals.atari import _seals_name
from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name
from seals.testing import envs

ENV_NAMES: List[str] = [
Expand All @@ -25,9 +25,15 @@
"seals/InitShiftTest-v0",
]

ATARI_ENVS: List[str] = [
_seals_name(gym_spec) for gym_spec in seals.GYM_ATARI_ENV_SPECS
UNMASKED_ATARI_ENVS: List[str] = [
_seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS
stewy33 marked this conversation as resolved.
Show resolved Hide resolved
]
MASKED_ATARI_ENVS: List[str] = [
_seals_name(gym_spec, masked=True)
for gym_spec in seals.GYM_ATARI_ENV_SPECS
if _get_score_region(gym_spec.id) is not None
]
ATARI_ENVS = UNMASKED_ATARI_ENVS + MASKED_ATARI_ENVS

ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS))
ATARI_NO_FRAMESKIP_ENVS: List[str] = list(
Expand All @@ -46,14 +52,41 @@ def test_some_atari_envs():


def test_atari_space_invaders():
"""Tests if there's an Atari environment called space invaders."""
space_invader_environments = list(
"""Tests for masked and unmasked Atari space invaders environments."""
masked_space_invader_environments = list(
filter(
lambda name: "SpaceInvaders" in name and "Unmasked" not in name,
ATARI_ENVS,
),
)
assert len(masked_space_invader_environments) > 0

unmasked_space_invader_environments = list(
filter(
lambda name: "SpaceInvaders" in name,
lambda name: "SpaceInvaders" in name and "Unmasked" in name,
ATARI_ENVS,
),
)
assert len(space_invader_environments) > 0
assert len(unmasked_space_invader_environments) > 0


def test_atari_unmasked_env_naming():
"""Tests that all unmasked Atari envs have the appropriate name qualifier."""
noncompliant_envs = [
(_get_score_region(name) is None and "Unmasked" not in name)
for name in ATARI_ENVS
]
assert len(noncompliant_envs) == 0


def test_atari_masks_satisfy_spec():
"""Tests that all Atari masks satisfy the spec."""
masks_satisfy_spec = [
mask["x"][0] < mask["x"][1] and mask["y"][0] < mask["y"][1]
for env_regions in SCORE_REGIONS.values()
for mask in env_regions
]
assert all(masks_satisfy_spec)


@pytest.mark.parametrize("env_name", ENV_NAMES)
Expand Down
13 changes: 12 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@

import collections

import gym
import numpy as np
import pytest

from seals import util
from seals import GYM_ATARI_ENV_SPECS, util


def test_mask_score_wrapper_enforces_spec():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be nice to add a test that actually checks MaskScoreWrapper masks the observations -- e.g. you could have a dummy environment that returns all-ones, a dummy mask config, and then just check that region (and only that region) is zero.

I won't insist on it though, the MaskScoreWrapper implementation is simple and readable already so is unlikely to have a bug, and you've already done a lot of work in this PR!

"""Test that MaskScoreWrapper enforces the spec."""
atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id)
with pytest.raises():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you specify the error that is raised and use the match option to match the error message?

util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))])
with pytest.raises():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above (error type + error message match)

util.MaskScoreWrapper(atari_env, [dict(x=(1, 0), y=(0, 1))])


def test_sample_distribution():
Expand Down