Skip to content

Commit

Permalink
Add score masking to seven atari environments (#62)
Browse files Browse the repository at this point in the history
* Add score masking to seven atari environments

* Add option of masked or unmasked atari envs

* add final tests and cosmetic changes for masked score atari environments

* Add pytest match option to MaskScoreWrapper tests

* Fix test error due to incorrectly named unmasked envs

* Fix lint in test_envs.py

* Add test to make_atari_env to check for exception thrown on unavailable masked env

* Add type alias and dataclass for masked score region objects

* Pin pyglet version to workaround bug

Co-authored-by: Adam Gleave <[email protected]>
  • Loading branch information
stewy33 and AdamGleave authored Nov 22, 2022
1 parent 956fbb4 commit dc7a695
Show file tree
Hide file tree
Showing 5 changed files with 199 additions and 27 deletions.
5 changes: 4 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,10 @@ def get_readme() -> str:
"pytest-xdist",
"pytype",
"stable-baselines3>=0.9.0",
"pyglet>=1.4.0",
# TODO(adam): remove pyglet pin once Gym upgraded to >0.21
# Workaround for https://github.com/openai/gym/issues/2986
# Discussed in https://github.com/HumanCompatibleAI/imitation/pull/603
"pyglet==1.5.27",
"setuptools_scm~=7.0.5",
*ATARI_REQUIRE,
]
Expand Down
74 changes: 62 additions & 12 deletions src/seals/atari.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,54 @@
"""Adaptation of Atari environments for specification learning algorithms."""

from typing import Iterable
from typing import Dict, Iterable, Optional

import gym

from seals.util import AutoResetWrapper, get_gym_max_episode_steps
from seals.util import (
AutoResetWrapper,
BoxRegion,
MaskedRegionSpecifier,
MaskScoreWrapper,
get_gym_max_episode_steps,
)

SCORE_REGIONS: Dict[str, MaskedRegionSpecifier] = {
"BeamRider": [
BoxRegion(x=(5, 20), y=(45, 120)),
BoxRegion(x=(28, 40), y=(15, 40)),
],
"Breakout": [BoxRegion(x=(0, 16), y=(35, 80))],
"Enduro": [
BoxRegion(x=(163, 173), y=(55, 110)),
BoxRegion(x=(177, 188), y=(68, 107)),
],
"Pong": [BoxRegion(x=(0, 24), y=(0, 160))],
"Qbert": [BoxRegion(x=(6, 15), y=(33, 71))],
"Seaquest": [BoxRegion(x=(7, 19), y=(80, 110))],
"SpaceInvaders": [BoxRegion(x=(10, 20), y=(0, 160))],
}

def fixed_length_atari(atari_env_id: str) -> gym.Env:
"""Fixed-length variant of a given Atari environment."""
return AutoResetWrapper(gym.make(atari_env_id))

def _get_score_region(atari_env_id: str) -> Optional[MaskedRegionSpecifier]:
basename = atari_env_id.split("/")[-1].split("-")[0]
basename = basename.replace("NoFrameskip", "")
return SCORE_REGIONS.get(basename)


def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env:
"""Fixed-length, optionally masked-score variant of a given Atari environment."""
env = AutoResetWrapper(gym.make(atari_env_id))

if masked:
score_region = _get_score_region(atari_env_id)
if score_region is None:
raise ValueError(
"Requested environment does not yet support masking. "
"See https://github.com/HumanCompatibleAI/seals/issues/61.",
)
env = MaskScoreWrapper(env, score_region)

return env


def _not_ram_or_det(env_id: str) -> bool:
Expand Down Expand Up @@ -37,20 +76,31 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
)


def _seals_name(gym_spec: gym.envs.registration.EnvSpec) -> str:
def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str:
"""Makes a Gym ID for an Atari environment in the seals namespace."""
slash_separated = gym_spec.id.split("/")
return "seals/" + slash_separated[-1]
name = "seals/" + slash_separated[-1]

if not masked:
last_hyphen_idx = name.rfind("-v")
name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:]
return name


def register_atari_envs(
gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec],
) -> None:
"""Register wrapped gym Atari environments."""
for gym_spec in gym_atari_env_specs:
"""Register masked and unmasked wrapped gym Atari environments."""

def register_gym(masked):
gym.register(
id=_seals_name(gym_spec),
entry_point="seals.atari:fixed_length_atari",
id=_seals_name(gym_spec, masked=masked),
entry_point="seals.atari:make_atari_env",
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id, masked=masked),
)

for gym_spec in gym_atari_env_specs:
register_gym(masked=False)
if _get_score_region(gym_spec.id) is not None:
register_gym(masked=True)
64 changes: 63 additions & 1 deletion src/seals/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Miscellaneous utilities."""

from typing import Optional, Tuple
from dataclasses import dataclass
from typing import List, Optional, Sequence, Tuple, Union

import gym
import numpy as np
Expand All @@ -23,6 +24,67 @@ def step(self, action):
return obs, rew, False, info


@dataclass
class BoxRegion:
"""A rectangular region dataclass used by MaskScoreWrapper."""

x: Tuple
y: Tuple


MaskedRegionSpecifier = List[BoxRegion]


class MaskScoreWrapper(gym.Wrapper):
"""Mask a list of box-shaped regions in the observation to hide reward info.
Intended for environments whose observations are raw pixels (like Atari
environments). Used to mask regions of the observation that include information
that could be used to infer the reward, like the game score or enemy ship count.
"""

def __init__(
self,
env: gym.Env,
score_regions: MaskedRegionSpecifier,
fill_value: Union[float, Sequence[float]] = 0,
):
"""Builds MaskScoreWrapper.
Args:
env: The environment to wrap.
score_regions: A list of box-shaped regions to mask, each denoted by
a dictionary `{"x": (x0, x1), "y": (y0, y1)}`, where `x0 < x1`
and `y0 < y1`.
fill_value: The fill_value for the masked region. By default is black.
Can support RGB colors by being a sequence of values [r, g, b].
Raises:
ValueError: If a score region does not conform to the spec.
"""
super().__init__(env)
self.fill_value = np.array(fill_value, env.observation_space.dtype)

self.mask = np.ones(env.observation_space.shape, dtype=bool)
for r in score_regions:
if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]:
raise ValueError('Invalid region: "x" and "y" must be increasing.')
self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = 0

def _mask_obs(self, obs):
return np.where(self.mask, obs, self.fill_value)

def step(self, action):
"""Returns (obs, rew, done, info) with masked obs."""
obs, rew, done, info = self.env.step(action)
return self._mask_obs(obs), rew, done, info

def reset(self, **kwargs):
"""Returns masked reset observation."""
obs = self.env.reset(**kwargs)
return self._mask_obs(obs)


class ObsCastWrapper(gym.Wrapper):
"""Cast observations to specified dtype.
Expand Down
69 changes: 57 additions & 12 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import seals # noqa: F401 required for env registration
from seals.atari import _seals_name
from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env
from seals.testing import envs

ENV_NAMES: List[str] = [
Expand All @@ -25,9 +25,15 @@
"seals/InitShiftTest-v0",
]

ATARI_ENVS: List[str] = [
_seals_name(gym_spec) for gym_spec in seals.GYM_ATARI_ENV_SPECS
UNMASKED_ATARI_ENVS: List[str] = [
_seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS
]
MASKED_ATARI_ENVS: List[str] = [
_seals_name(gym_spec, masked=True)
for gym_spec in seals.GYM_ATARI_ENV_SPECS
if _get_score_region(gym_spec.id) is not None
]
ATARI_ENVS = UNMASKED_ATARI_ENVS + MASKED_ATARI_ENVS

ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS))
ATARI_NO_FRAMESKIP_ENVS: List[str] = list(
Expand All @@ -46,14 +52,53 @@ def test_some_atari_envs():


def test_atari_space_invaders():
"""Tests if there's an Atari environment called space invaders."""
space_invader_environments = list(
"""Tests for masked and unmasked Atari space invaders environments."""
masked_space_invader_environments = list(
filter(
lambda name: "SpaceInvaders" in name,
lambda name: "SpaceInvaders" in name and "Unmasked" not in name,
ATARI_ENVS,
),
)
assert len(space_invader_environments) > 0
assert len(masked_space_invader_environments) > 0

unmasked_space_invader_environments = list(
filter(
lambda name: "SpaceInvaders" in name and "Unmasked" in name,
ATARI_ENVS,
),
)
assert len(unmasked_space_invader_environments) > 0


def test_atari_unmasked_env_naming():
"""Tests that all unmasked Atari envs have the appropriate name qualifier."""
noncompliant_envs = list(
filter(
lambda name: _get_score_region(name) is None and "Unmasked" not in name,
ATARI_ENVS,
),
)
assert len(noncompliant_envs) == 0


def test_make_unsupported_masked_atari_env_throws_error():
"""Tests that making an unsupported masked Atari env throws an error."""
match_str = (
"Requested environment does not yet support masking. "
"See https://github.com/HumanCompatibleAI/seals/issues/61."
)
with pytest.raises(ValueError, match=match_str):
make_atari_env("ALE/Bowling-v5", masked=True)


def test_atari_masks_satisfy_spec():
"""Tests that all Atari masks satisfy the spec."""
masks_satisfy_spec = [
mask.x[0] < mask.x[1] and mask.y[0] < mask.y[1]
for env_regions in SCORE_REGIONS.values()
for mask in env_regions
]
assert all(masks_satisfy_spec)


@pytest.mark.parametrize("env_name", ENV_NAMES)
Expand All @@ -70,11 +115,11 @@ def test_seed(self, env: gym.Env, env_name: str):
if env_name in ATARI_ENVS:
# these environments take a while for their non-determinism to show.
slow_random_envs = [
"seals/Bowling-v5",
"seals/Frogger-v5",
"seals/KingKong-v5",
"seals/Koolaid-v5",
"seals/NameThisGame-v5",
"seals/Bowling-Unmasked-v5",
"seals/Frogger-Unmasked-v5",
"seals/KingKong-Unmasked-v5",
"seals/Koolaid-Unmasked-v5",
"seals/NameThisGame-Unmasked-v5",
]
rollout_len = 100 if env_name not in slow_random_envs else 400
num_seeds = 2 if env_name in ATARI_NO_FRAMESKIP_ENVS else 10
Expand Down
14 changes: 13 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,21 @@

import collections

import gym
import numpy as np
import pytest

from seals import util
from seals import GYM_ATARI_ENV_SPECS, util


def test_mask_score_wrapper_enforces_spec():
"""Test that MaskScoreWrapper enforces the spec."""
atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id)
desired_error_message = 'Invalid region: "x" and "y" must be increasing.'
with pytest.raises(ValueError, match=desired_error_message):
util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(0, 1), y=(1, 0))])
with pytest.raises(ValueError, match=desired_error_message):
util.MaskScoreWrapper(atari_env, [util.BoxRegion(x=(1, 0), y=(0, 1))])


def test_sample_distribution():
Expand Down

0 comments on commit dc7a695

Please sign in to comment.