Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add score masking to seven atari environments #62

Merged
merged 10 commits into from
Nov 22, 2022
21 changes: 10 additions & 11 deletions src/seals/atari.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env:
if score_region is None:
raise ValueError(
"Requested environment does not yet support masking. "
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice, thanks for adding the informative error message :)

+ "See https://github.com/HumanCompatibleAI/seals/issues/61.",
"See https://github.com/HumanCompatibleAI/seals/issues/61.",
)
env = MaskScoreWrapper(env, score_region)

Expand Down Expand Up @@ -76,7 +76,7 @@ def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str:
name = "seals/" + slash_separated[-1]

if not masked:
last_hyphen_idx = name.rfind("-")
last_hyphen_idx = name.rfind("-v")
name = name[:last_hyphen_idx] + "-Unmasked" + name[last_hyphen_idx:]
return name

Expand All @@ -85,17 +85,16 @@ def register_atari_envs(
gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec],
) -> None:
"""Register masked and unmasked wrapped gym Atari environments."""
for gym_spec in gym_atari_env_specs:

def register_gym(masked):
gym.register(
id=_seals_name(gym_spec, masked=False),
id=_seals_name(gym_spec, masked=masked),
entry_point="seals.atari:make_atari_env",
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id, masked=False),
kwargs=dict(atari_env_id=gym_spec.id, masked=masked),
)

for gym_spec in gym_atari_env_specs:
register_gym(masked=False)
if _get_score_region(gym_spec.id) is not None:
AdamGleave marked this conversation as resolved.
Show resolved Hide resolved
gym.register(
id=_seals_name(gym_spec, masked=True),
entry_point="seals.atari:make_atari_env",
max_episode_steps=get_gym_max_episode_steps(gym_spec.id),
kwargs=dict(atari_env_id=gym_spec.id, masked=True),
)
register_gym(masked=True)
6 changes: 5 additions & 1 deletion src/seals/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,17 @@ def __init__(
and `y0 < y1`.
fill_value: The fill_value for the masked region. By default is black.
Can support RGB colors by being a sequence of values [r, g, b].

Raises:
ValueError: If a score region does not conform to the spec.
"""
super().__init__(env)
self.fill_value = np.array(fill_value, env.observation_space.dtype)
Rocamonde marked this conversation as resolved.
Show resolved Hide resolved

self.mask = np.ones(env.observation_space.shape, dtype=bool)
for r in score_regions:
assert r["x"][0] < r["x"][1] and r["y"][0] < r["y"][1]
if r["x"][0] >= r["x"][1] or r["y"][0] >= r["y"][1]:
raise ValueError('Invalid region: "x" and "y" must be increasing.')
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice input validation!

self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0

def _mask_obs(self, obs):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Thanks for adding this. Code looks cleaner now IMO.

Expand Down
18 changes: 15 additions & 3 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pytest

import seals # noqa: F401 required for env registration
from seals.atari import _get_score_region, _seals_name
from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name
from seals.testing import envs

ENV_NAMES: List[str] = [
Expand All @@ -25,13 +25,15 @@
"seals/InitShiftTest-v0",
]

ATARI_ENVS: List[str] = [
UNMASKED_ATARI_ENVS: List[str] = [
_seals_name(gym_spec, masked=False) for gym_spec in seals.GYM_ATARI_ENV_SPECS
stewy33 marked this conversation as resolved.
Show resolved Hide resolved
] + [
]
MASKED_ATARI_ENVS: List[str] = [
_seals_name(gym_spec, masked=True)
for gym_spec in seals.GYM_ATARI_ENV_SPECS
if _get_score_region(gym_spec.id) is not None
]
ATARI_ENVS = UNMASKED_ATARI_ENVS + MASKED_ATARI_ENVS

ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS))
ATARI_NO_FRAMESKIP_ENVS: List[str] = list(
Expand Down Expand Up @@ -77,6 +79,16 @@ def test_atari_unmasked_env_naming():
assert len(noncompliant_envs) == 0


def test_atari_masks_satisfy_spec():
"""Tests that all Atari masks satisfy the spec."""
masks_satisfy_spec = [
mask["x"][0] < mask["x"][1] and mask["y"][0] < mask["y"][1]
for env_regions in SCORE_REGIONS.values()
for mask in env_regions
]
assert all(masks_satisfy_spec)


@pytest.mark.parametrize("env_name", ENV_NAMES)
class TestEnvs:
"""Battery of simple tests for environments."""
Expand Down
13 changes: 12 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,20 @@

import collections

import gym
import numpy as np
import pytest

from seals import util
from seals import GYM_ATARI_ENV_SPECS, util


def test_mask_score_wrapper_enforces_spec():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be nice to add a test that actually checks MaskScoreWrapper masks the observations -- e.g. you could have a dummy environment that returns all-ones, a dummy mask config, and then just check that region (and only that region) is zero.

I won't insist on it though, the MaskScoreWrapper implementation is simple and readable already so is unlikely to have a bug, and you've already done a lot of work in this PR!

"""Test that MaskScoreWrapper enforces the spec."""
atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id)
with pytest.raises():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you specify the error that is raised and use the match option to match the error message?

util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))])
with pytest.raises():
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above (error type + error message match)

util.MaskScoreWrapper(atari_env, [dict(x=(1, 0), y=(0, 1))])


def test_sample_distribution():
Expand Down