-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add score masking to seven atari environments #62
Changes from 1 commit
6956efd
a97839f
b2f7d96
31af01a
aef96e5
013d942
6cb74b4
15f8569
e60925e
f480835
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -46,13 +46,17 @@ def __init__( | |
and `y0 < y1`. | ||
fill_value: The fill_value for the masked region. By default is black. | ||
Can support RGB colors by being a sequence of values [r, g, b]. | ||
|
||
Raises: | ||
ValueError: If a score region does not conform to the spec. | ||
""" | ||
super().__init__(env) | ||
self.fill_value = np.array(fill_value, env.observation_space.dtype) | ||
Rocamonde marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
self.mask = np.ones(env.observation_space.shape, dtype=bool) | ||
for r in score_regions: | ||
assert r["x"][0] < r["x"][1] and r["y"][0] < r["y"][1] | ||
if r["x"][0] >= r["x"][1] or r["y"][0] >= r["y"][1]: | ||
raise ValueError('Invalid region: "x" and "y" must be increasing.') | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice input validation! |
||
self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0 | ||
|
||
def _mask_obs(self, obs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! Thanks for adding this. Code looks cleaner now IMO. |
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,20 @@ | |
|
||
import collections | ||
|
||
import gym | ||
import numpy as np | ||
import pytest | ||
|
||
from seals import util | ||
from seals import GYM_ATARI_ENV_SPECS, util | ||
|
||
|
||
def test_mask_score_wrapper_enforces_spec(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be nice to add a test that actually checks I won't insist on it though, the |
||
"""Test that MaskScoreWrapper enforces the spec.""" | ||
atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id) | ||
with pytest.raises(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you specify the error that is raised and use the |
||
util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))]) | ||
with pytest.raises(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above (error type + error message match) |
||
util.MaskScoreWrapper(atari_env, [dict(x=(1, 0), y=(0, 1))]) | ||
|
||
|
||
def test_sample_distribution(): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks for adding the informative error message :)