Skip to content

Commit

Permalink
Add atari environments (#57)
Browse files Browse the repository at this point in the history
* Add auto-resetting atari environments

* Improve documentation, reduce testing time for atari envs

* Add atari dependencies for testing

* Simplify boolean check

Co-authored-by: Juan Rocamonde <[email protected]>

* Demand result of env.seed be a list or tuple

* Add documentation for why atari is treated differently in tests

* clarify why we check that env.seed returns a list or tuple

* Parametrize test_seed

* Explain why Bowling and NameThisGame are different

* Tidy code, clarify README

* Pull atari envs from gym registry, rather than hard-coding

* Fix formatting of test_envs

* Fix spelling of MuJoCo

Co-authored-by: Adam Gleave <[email protected]>

* Fix Atari capitalization

Co-authored-by: Adam Gleave <[email protected]>

* Remove TODO

* Simplify checking that filter isn't empty

* Add documentation to helper methods

* Don't unnecessarily use list where tuple will do

Co-authored-by: Adam Gleave <[email protected]>

* Move atari registration logic to atari.py

* Add type annotations

* Fix capitalization of 'Atari'

Co-authored-by: Juan Rocamonde <[email protected]>
Co-authored-by: Adam Gleave <[email protected]>
  • Loading branch information
3 people authored Sep 6, 2022
1 parent f3a2404 commit 5dfffa5
Show file tree
Hide file tree
Showing 6 changed files with 161 additions and 18 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ There are two types of environments in *seals*:

- **Diagnostic Tasks** which test individual facets of algorithm performance in isolation.
- **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous
control tasks to be suitable for specification learning benchmarks. In particular, we remove
any side-channel sources of reward information.
control tasks and Atari games to be suitable for specification learning benchmarks. In particular,
we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score).

*seals* is under active development and we intend to add more categories of tasks soon.

Expand Down
9 changes: 9 additions & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ def get_readme() -> str:
return f.read()


# Packages needed to run the Atari environments. This list is reused both in
# TESTS_REQUIRE and as the optional "atari" extra passed to `setup()`.
ATARI_REQUIRE = [
"opencv-python",
"ale-py==0.7.4",
"pillow",
"autorom[accept-rom-license]~=0.4.2",
]
TESTS_REQUIRE = [
# remove pin once https://github.com/nedbat/coveragepy/issues/881 fixed
"black",
Expand All @@ -50,6 +56,7 @@ def get_readme() -> str:
"pytype",
"stable-baselines3>=0.9.0",
"pyglet>=1.4.0",
*ATARI_REQUIRE,
]
DOCS_REQUIRE = [
"sphinx",
Expand All @@ -58,6 +65,7 @@ def get_readme() -> str:
"sphinxcontrib-napoleon",
]


setup(
name="seals",
version=get_version(),
Expand All @@ -79,6 +87,7 @@ def get_readme() -> str:
# We'd like to specify `gym[mujoco]`, but this is a no-op when Gym is already
# installed. See https://github.com/pypa/pip/issues/4957 for issue.
"mujoco": ["mujoco_py>=1.50, <2.0", "imageio"],
"atari": ATARI_REQUIRE,
},
url="https://github.com/HumanCompatibleAI/benchmark-environments",
license="MIT",
Expand Down
7 changes: 6 additions & 1 deletion src/seals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import gym

from seals import util
from seals import atari, util
import seals.diagnostics # noqa: F401
from seals.version import VERSION as __version__ # noqa: F401

Expand All @@ -28,3 +28,8 @@
entry_point=f"seals.mujoco:{env_base}Env",
max_episode_steps=util.get_gym_max_episode_steps(f"{env_base}-v3"),
)

# Atari

# Pick out the supported Atari specs from everything currently in the Gym
# registry. The result is materialized as a module-level list (rather than left
# as a lazy filter) so that other modules can read it after import.
GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.envs.registry.all()))
# Register a "seals/"-namespaced counterpart for each selected spec.
atari.register_atari_envs(GYM_ATARI_ENV_SPECS)
56 changes: 56 additions & 0 deletions src/seals/atari.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
"""Adaptation of Atari environments for specification learning algorithms."""

from typing import Iterable

import gym

from seals.util import AutoResetWrapper, get_gym_max_episode_steps


def fixed_length_atari(atari_env_id: str) -> gym.Env:
    """Build the fixed-length variant of a Gym Atari environment.

    Args:
        atari_env_id: Gym ID of the Atari environment to wrap.

    Returns:
        The environment produced by `gym.make`, wrapped in `AutoResetWrapper`.
    """
    # NOTE(review): `AutoResetWrapper` lives in `seals.util`; presumably it resets
    # the underlying env whenever an episode ends -- confirm there.
    base_env = gym.make(atari_env_id)
    wrapped_env = AutoResetWrapper(base_env)
    return wrapped_env


def _not_ram_or_det(env_id: str) -> bool:
"""Checks a gym Atari environment isn't deterministic or using RAM observations."""
slash_separated = env_id.split("/")
# environment name should look like "ALE/Amidar-v5" or "Amidar-ramNoFrameskip-v4"
assert len(slash_separated) in (1, 2)
after_slash = slash_separated[-1]
hyphen_separated = after_slash.split("-")
assert len(hyphen_separated) > 1
not_ram = not ("ram" in hyphen_separated[1])
not_deterministic = not ("Deterministic" in env_id)
return not_ram and not_deterministic


def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool:
    """Decide whether a Gym spec is an Atari environment that seals supports.

    A spec is supported when all of the following hold:

    - its entry point is "gym.envs.atari:AtariEnv";
    - `_not_ram_or_det` accepts its ID;
    - its ID either ends in "-v5" and lacks "NoFrameskip", or ends in "-v4" and
      contains "NoFrameskip".

    Args:
        gym_spec: spec taken from the Gym registry.

    Returns:
        True if the spec is supported, False otherwise.
    """
    is_atari = gym_spec.entry_point == "gym.envs.atari:AtariEnv"
    if not is_atari:
        return False

    env_id = gym_spec.id
    if not _not_ram_or_det(env_id):
        return False

    no_frameskip = "NoFrameskip" in env_id
    if env_id.endswith("-v5"):
        # v5 IDs are only accepted in their plain form.
        return not no_frameskip
    if env_id.endswith("-v4"):
        # v4 IDs are only accepted in their NoFrameskip form.
        return no_frameskip
    return False


def _seals_name(gym_spec: gym.envs.registration.EnvSpec) -> str:
    """Derive the seals-namespaced Gym ID for an Atari environment spec.

    Any existing namespace (everything up to and including the last "/") is
    dropped and replaced with "seals/".

    Args:
        gym_spec: spec of the original Gym Atari environment.

    Returns:
        The ID to register the seals variant under.
    """
    # rpartition yields the whole ID as the tail when there is no "/" at all.
    _, _, bare_name = gym_spec.id.rpartition("/")
    return "seals/" + bare_name


def register_atari_envs(
    gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec],
) -> None:
    """Register a seals variant of each given Gym Atari environment.

    Every spec is registered under the ID returned by `_seals_name`, is built via
    `fixed_length_atari`, and uses the episode step limit that
    `get_gym_max_episode_steps` reports for the original Gym ID.

    Args:
        gym_atari_env_specs: specs of the Gym Atari environments to wrap.
    """
    for spec in gym_atari_env_specs:
        original_id = spec.id
        step_limit = get_gym_max_episode_steps(original_id)
        gym.register(
            id=_seals_name(spec),
            entry_point="seals.atari:fixed_length_atari",
            max_episode_steps=step_limit,
            kwargs={"atari_env_id": original_id},
        )
32 changes: 22 additions & 10 deletions src/seals/testing/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ def has_same_observations(rollout_a: Rollout, rollout_b: Rollout) -> bool:
return True


def test_seed(env: gym.Env, env_name: str, deterministic_envs: Iterable[str]) -> None:
def test_seed(
env: gym.Env,
env_name: str,
deterministic_envs: Iterable[str],
rollout_len: int = 10,
num_seeds: int = 100,
) -> None:
"""Tests environment seeding.
If non-deterministic, different seeds should produce different transitions.
Expand All @@ -147,11 +153,11 @@ def test_seed(env: gym.Env, env_name: str, deterministic_envs: Iterable[str]) ->
AssertionError if test fails.
"""
env.action_space.seed(0)
actions = [env.action_space.sample() for _ in range(10)]

actions = [env.action_space.sample() for _ in range(rollout_len)]
# With the same seed, should always get the same result
seeds = env.seed(42)
assert isinstance(seeds, list)
# Output of env.seed should be a list, but Atari environments return a tuple.
assert isinstance(seeds, (list, tuple))
assert len(seeds) > 0
rollout_a = get_rollout(env, actions)

Expand All @@ -164,15 +170,17 @@ def test_seed(env: gym.Env, env_name: str, deterministic_envs: Iterable[str]) ->
# eventually get a different result. For deterministic environments, all
# seeds should produce the same starting state.
def different_seeds_same_rollout(seed1, seed2):
new_actions = [env.action_space.sample() for _ in range(10)]
new_actions = [env.action_space.sample() for _ in range(rollout_len)]
env.seed(seed1)
new_rollout_1 = get_rollout(env, new_actions)
env.seed(seed2)
new_rollout_2 = get_rollout(env, new_actions)
return has_same_observations(new_rollout_1, new_rollout_2)

is_deterministic = matches_list(env_name, deterministic_envs)
same_obs = all(different_seeds_same_rollout(seed, seed + 1) for seed in range(100))
same_obs = all(
different_seeds_same_rollout(seed, seed + 1) for seed in range(num_seeds)
)
assert same_obs == is_deterministic


Expand Down Expand Up @@ -202,7 +210,8 @@ def _is_mujoco_env(env: gym.Env) -> bool:
def test_rollout_schema(
env: gym.Env,
steps_after_done: int = 10,
max_steps: int = 10000,
max_steps: int = 10_000,
check_episode_ends: bool = True,
) -> None:
"""Check custom environments have correct types on `step` and `reset`.
Expand All @@ -212,6 +221,8 @@ def test_rollout_schema(
episode termination. This is an abuse of the Gym API, but we would like the
environments to handle this case gracefully.
max_steps: Test fails if we do not get `done` after this many timesteps.
check_episode_ends: Check that the episode ends within `max_steps` steps, and that
steps taken after `done` is true are of the correct type.
Raises:
AssertionError if test fails.
Expand All @@ -225,10 +236,11 @@ def test_rollout_schema(
if done:
break

assert done is True, "did not get to end of episode"
if check_episode_ends:
assert done, "did not get to end of episode"

for _ in range(steps_after_done):
_sample_and_check(env, obs_space)
for _ in range(steps_after_done):
_sample_and_check(env, obs_space)


def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None:
Expand Down
71 changes: 66 additions & 5 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest

import seals # noqa: F401 required for env registration
from seals.atari import _seals_name
from seals.testing import envs

ENV_NAMES: List[str] = [
Expand All @@ -15,6 +16,7 @@
if env_spec.id.startswith("seals/")
]


DETERMINISTIC_ENVS: List[str] = [
"seals/EarlyTermPos-v0",
"seals/EarlyTermNeg-v0",
Expand All @@ -23,25 +25,84 @@
"seals/InitShiftTest-v0",
]

ATARI_ENVS: List[str] = [
_seals_name(gym_spec) for gym_spec in seals.GYM_ATARI_ENV_SPECS
]

ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS))
ATARI_NO_FRAMESKIP_ENVS: List[str] = list(
filter(lambda name: name.endswith("-v4"), ATARI_ENVS),
)

DETERMINISTIC_ENVS += ATARI_NO_FRAMESKIP_ENVS


env = pytest.fixture(envs.make_env_fixture(skip_fn=pytest.skip))


def test_some_atari_envs():
    """Checks that at least one supported Atari env spec was found in Gym."""
    num_specs = len(seals.GYM_ATARI_ENV_SPECS)
    assert num_specs > 0


def test_atari_space_invaders():
    """Checks that some Atari env name in the seals namespace is SpaceInvaders."""
    matching_names = [name for name in ATARI_ENVS if "SpaceInvaders" in name]
    assert len(matching_names) > 0


@pytest.mark.parametrize("env_name", ENV_NAMES)
class TestEnvs:
"""Battery of simple tests for environments."""

def test_seed(self, env: gym.Env, env_name: str):
"""Tests environment seeding."""
envs.test_seed(env, env_name, DETERMINISTIC_ENVS)
"""Tests environment seeding.
Deterministic Atari environments are run with fewer seeds to minimize the number
of resets done in this test suite, since Atari resets take a long time and there
are many Atari environments.
"""
if env_name in ATARI_ENVS:
# these environments take a while for their non-determinism to show.
slow_random_envs = [
"seals/Bowling-v5",
"seals/Frogger-v5",
"seals/KingKong-v5",
"seals/Koolaid-v5",
"seals/NameThisGame-v5",
]
rollout_len = 100 if env_name not in slow_random_envs else 400
num_seeds = 2 if env_name in ATARI_NO_FRAMESKIP_ENVS else 10
envs.test_seed(
env,
env_name,
DETERMINISTIC_ENVS,
rollout_len=rollout_len,
num_seeds=num_seeds,
)
else:
envs.test_seed(env, env_name, DETERMINISTIC_ENVS)

def test_premature_step(self, env: gym.Env):
"""Tests if step() before reset() raises error."""
envs.test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises)

def test_rollout_schema(self, env: gym.Env):
"""Tests if environments have correct types on `step()` and `reset()`."""
envs.test_rollout_schema(env)
def test_rollout_schema(self, env: gym.Env, env_name: str):
"""Tests if environments have correct types on `step()` and `reset()`.
Atari environments have a very long episode length (~100k observations), so in
the interest of time we do not run them to the end of their episodes or check
the return type of `env.step` after the end of the episode.
"""
if env_name in ATARI_ENVS:
envs.test_rollout_schema(env, max_steps=1_000, check_episode_ends=False)
else:
envs.test_rollout_schema(env)

def test_render(self, env: gym.Env):
"""Tests `render()` supports modes specified in environment metadata."""
Expand Down

0 comments on commit 5dfffa5

Please sign in to comment.