-
Notifications
You must be signed in to change notification settings - Fork 6
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add score masking to seven atari environments #62
Conversation
@Rocamonde, do you mind reviewing this PR?
It's an interesting question. I'd guess number of lives is often decision relevant for the agent -- e.g. it wants to be more risk averse when only one life left. So I lean against masking it, though I agree it introduces a confounder. Having an option to mask it (that we could leave off by default) is probably the best thing to do. But, masking the score is a definitive improvement over not masking it so this shouldn't hold up the PR. |
Codecov Report
@@ Coverage Diff @@
## master #62 +/- ##
=========================================
Coverage 100.00% 100.00%
=========================================
Files 26 26
Lines 982 1047 +65
=========================================
+ Hits 982 1047 +65
📣 We’re building smart automated test selection to slash your CI/CD build times. Learn more |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks so much for taking the time to submit this PR! Overall I agree with the design choices and implementation, and seems like a pretty useful feature to add to seals. There are some major comments that I would like to see addressed:
- Masking
obs_from_state
instead of environment interaction methods - Whether we want to force all registered environments to be masked (even if users could still manually registered non-masked environments)
- Whether we should have wrappers always inherit from our base class for type safety.
By default, I would favor the solutions I left in the comments, but happy to hear your thoughts on them. Since these are fairly major design choices for the project, maybe @AdamGleave wants to weigh in on these too?
These suggestions seem to be based on the assumption that the Atari environments implement the I guess we could make the state be the non-masked observation, and mask it in the observation? That could work but I don't see the benefit. |
Hi, I just incorporated some of the suggested changes! I added the option of having unmasked atari environments. The naming convention is now |
src/seals/atari.py
Outdated
if score_region is None: | ||
raise ValueError( | ||
"Requested environment does not yet support masking. " | ||
+ "See https://github.com/HumanCompatibleAI/seals/issues/61.", |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The + is unnecessary (https://docs.python.org/3/reference/lexical_analysis.html#string-literal-concatenation). It actually introduces runtime overhead, even though this is largely irrelevant and it's more of a standard style choice.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, done.
src/seals/atari.py
Outdated
name = "seals/" + slash_separated[-1] | ||
|
||
if not masked: | ||
last_hyphen_idx = name.rfind("-") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are we confident all environments will have a {name}-v{num}
format? It's been the case everywhere that I've seen, but this would preclude us from registering environments without this format, and that's probably at least worth documenting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Well, in the _supported_atari_env
method, we already only support an Atari environment if it ends with "-v4" or "-v5". So I think this is ok for now.
src/seals/util.py
Outdated
@@ -51,17 +52,21 @@ def __init__( | |||
|
|||
self.mask = np.ones(env.observation_space.shape, dtype=bool) | |||
for r in score_regions: | |||
self.mask[r["x0"] : r["x1"], r["y0"] : r["y1"]] = 0 | |||
assert r["x"][0] < r["x"][1] and r["y"][0] < r["y"][1] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this is a public method (users could create their own wrapper beyond our internal usage for seals-defined environments) you should raise a ValueError instead (and add a corresponding test). Passing in the wrong values is very much a possibility.
assert
is to be used when something should always behave in a certain way by virtue of the purported logic of the program. This allows catching logical bugs and reassuring code checkers. This would be fine if only our own internally defined masks could ever be used. However, when something is part of the public API and therefore contingent on user input, we cannot really assert that a user won't pass the wrong value. See https://stackoverflow.com/questions/17530627/python-assertion-style#:~:text=The%20assert%20statement%20should%20only,user%20input%20or%20the%20environment. and https://wiki.python.org/moin/UsingAssertionsEffectively
What you can do, however, is have tests that assert that our internally defined masks verify this. That, plus a with raises
test on the MaskScoreWrapper API should be enough to thoroughly test this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the explanation!
assert r["x"][0] < r["x"][1] and r["y"][0] < r["y"][1] | ||
self.mask[r["x"][0] : r["x"][1], r["y"][0] : r["y"][1]] = 0 | ||
|
||
def _mask_obs(self, obs): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice! Thanks for adding this. Code looks cleaner now IMO.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, only a minor fix on the test cases. Please fix before merging. The tests say pending, you might want to re-trigger the pipeline.
tests/test_util.py
Outdated
def test_mask_score_wrapper_enforces_spec(): | ||
"""Test that MaskScoreWrapper enforces the spec.""" | ||
atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id) | ||
with pytest.raises(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you specify the error that is raised and use the match
option to match the error message?
tests/test_util.py
Outdated
atari_env = gym.make(GYM_ATARI_ENV_SPECS[0].id) | ||
with pytest.raises(): | ||
util.MaskScoreWrapper(atari_env, [dict(x=(0, 1), y=(1, 0))]) | ||
with pytest.raises(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same as above (error type + error message match)
@Rocamonde stewy33 doesn't have permission to trigger them (our tests don't run on fork), I have done this now with |
Hi, I'm new to contributing to open-source projects, so I'm wondering if
|
Thanks, just re-triggered this after the most recent changes. @AdamGleave what's gonna happen when we do this for multiple open PRs? (apparently the way it works is by pushing to the
Don't worry, I've just taken care of that. Normally it's automatic, but since your PR is from a fork, it's not.
IIRC the desirable option is the only one enabled when you press the merge button on the PR. AFAIK we always squash and merge (Adam correct me if this is wrong). Just click the button on the PR once the tests pass and GitHub should automatically do this and close the PR. |
Seems like you have failing tests due to two small mistakes.
noncompliant_envs = [
(_get_score_region(name) is None and "Unmasked" not in name)
for name in ATARI_ENVS
]
assert len(noncompliant_envs) == 0 what you actually get is a list of bools with all elements False. So you can either noncompliant_envs = [
name
if (_get_score_region(name) is None and "Unmasked" not in name)
for name in ATARI_ENVS
] so you actually apply the filter to the list comprehension (I think this is probably what you intended to do), or is_each_env_noncompliant = [
(_get_score_region(name) is None and "Unmasked" not in name)
for name in ATARI_ENVS
]
assert not any(is_each_env_noncompliant) which is also correct but less intuitive. I probably prefer the first TBH. Then for the second error, we manually hardcoded the names of some envs not to check because they take a while to show determinism. But since now we also have unmasked and masked versions, your unmasked envs are getting checked again and the tests are failing. # these environments take a while for their non-determinism to show.
slow_random_envs = [
"seals/Bowling-v5",
"seals/Frogger-v5",
"seals/KingKong-v5",
"seals/Koolaid-v5",
"seals/NameThisGame-v5",
] I would just manually add the other tests to this list for now. Once you're at it, check that we have no hardcoded lists of environments anywhere else that are not being updated. |
Could also change the test to check if the prefix of the env name matches and skip accordingly. |
Hi, I added another test to hopefully past code coverage. Could one of you re-trigger tests? If they pass, I think we're ready to merge. |
@AdamGleave codecov is being annoying, I think it thinks the CI failed and is not reporting coverage, but I checked on the website and all looks good. Can you override this / retrigger codecov? the PR LGTM. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Rocamonde codecov seems to have fixed itself unless you did something?
@stewy33 thanks for all your work on this PR. Got around to reviewing it. Pretty much looks good, a couple of minor suggestions. After that should be ready to merge :)
src/seals/atari.py
Outdated
|
||
SCORE_REGIONS: Dict[str, List[Dict[str, Tuple[int, int]]]] = { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You use List[Dict[str, Tuple[int, int]]]
in three places in your code -- consider defining it as a type? Like:
MaskedRegionSpecifier = List[Dict[str, Tuple[int, int]]]
I'd also consider using a named tuple instead of dict
to enforce that x
and y
are both present.
score_region = _get_score_region(atari_env_id) | ||
if score_region is None: | ||
raise ValueError( | ||
"Requested environment does not yet support masking. " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice, thanks for adding the informative error message :)
src/seals/util.py
Outdated
def __init__( | ||
self, | ||
env: gym.Env, | ||
score_regions: List[Dict[str, Tuple[int, int]]], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(If you did define a type alias this file would be the natural place to do it.)
src/seals/util.py
Outdated
class MaskScoreWrapper(gym.Wrapper): | ||
"""Mask a list of box-shaped regions in the observation to hide reward info. | ||
|
||
Intended for environments whose observations are raw pixels (like atari |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Intended for environments whose observations are raw pixels (like atari | |
Intended for environments whose observations are raw pixels (like Atari |
self.mask = np.ones(env.observation_space.shape, dtype=bool) | ||
for r in score_regions: | ||
if r["x"][0] >= r["x"][1] or r["y"][0] >= r["y"][1]: | ||
raise ValueError('Invalid region: "x" and "y" must be increasing.') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice input validation!
from seals import GYM_ATARI_ENV_SPECS, util | ||
|
||
|
||
def test_mask_score_wrapper_enforces_spec(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It might be nice to add a test that actually checks MaskScoreWrapper
masks the observations -- e.g. you could have a dummy environment that returns all-ones, a dummy mask config, and then just check that region (and only that region) is zero.
I won't insist on it though, the MaskScoreWrapper
implementation is simple and readable already so is unlikely to have a bug, and you've already done a lot of work in this PR!
Hi, could someone trigger the tests one last time before merge? Just incorporated the desired minor changes. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Thanks for your patience with this! Also I've changed the CI set up so tests should hopefully run OK on forked repos in futuer.
Fixes #61
Added score masking to the seven atari environments from the RLHF paper. I used a black background to cover the score, to cover enemy ship count for BeamRider, and cover the speedometer for Enduro.
Note that the number of lives is unmasked. This matches the original implementation in the RLHF paper. However, it seems that episode boundaries could be inferred from there. What do we think about this choice?