diff --git a/README.md b/README.md index 71a5b39..7af6766 100644 --- a/README.md +++ b/README.md @@ -17,8 +17,8 @@ There are two types of environments in *seals*: - **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. - **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous - control tasks to be suitable for specification learning benchmarks. In particular, we remove - any side-channel sources of reward information. + control tasks and Atari games to be suitable for specification learning benchmarks. In particular, + we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score). *seals* is under active development and we intend to add more categories of tasks soon. diff --git a/setup.py b/setup.py index a4754f7..7b32d85 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,12 @@ def get_readme() -> str: return f.read() +ATARI_REQUIRE = [ + "opencv-python", + "ale-py==0.7.4", + "pillow", + "autorom[accept-rom-license]~=0.4.2", +] TESTS_REQUIRE = [ # remove pin once https://github.com/nedbat/coveragepy/issues/881 fixed "black", @@ -50,6 +56,7 @@ def get_readme() -> str: "pytype", "stable-baselines3>=0.9.0", "pyglet>=1.4.0", + *ATARI_REQUIRE, ] DOCS_REQUIRE = [ "sphinx", @@ -58,6 +65,7 @@ def get_readme() -> str: "sphinxcontrib-napoleon", ] + setup( name="seals", version=get_version(), @@ -79,6 +87,7 @@ def get_readme() -> str: # We'd like to specify `gym[mujoco]`, but this is a no-op when Gym is already # installed. See https://github.com/pypa/pip/issues/4957 for issue. "mujoco": ["mujoco_py>=1.50, <2.0", "imageio"], + "atari": ATARI_REQUIRE, }, url="https://github.com/HumanCompatibleAI/benchmark-environments", license="MIT", diff --git a/src/seals/__init__.py b/src/seals/__init__.py index 7694c6a..c4ef46e 100644 --- a/src/seals/__init__.py +++ b/src/seals/__init__.py @@ -2,7 +2,7 @@ import gym -from seals import util +from seals import atari, util import seals.diagnostics # noqa: F401 from seals.version import VERSION as __version__ # noqa: F401 @@ -28,3 +28,8 @@ entry_point=f"seals.mujoco:{env_base}Env", max_episode_steps=util.get_gym_max_episode_steps(f"{env_base}-v3"), ) + +# Atari + +GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.envs.registry.all())) +atari.register_atari_envs(GYM_ATARI_ENV_SPECS) diff --git a/src/seals/atari.py b/src/seals/atari.py new file mode 100644 index 0000000..ff01024 --- /dev/null +++ b/src/seals/atari.py @@ -0,0 +1,56 @@ +"""Adaptation of Atari environments for specification learning algorithms.""" + +from typing import Iterable + +import gym + +from seals.util import AutoResetWrapper, get_gym_max_episode_steps + + +def fixed_length_atari(atari_env_id: str) -> gym.Env: + """Fixed-length variant of a given Atari environment.""" + return AutoResetWrapper(gym.make(atari_env_id)) + + +def _not_ram_or_det(env_id: str) -> bool: + """Checks a gym Atari environment isn't deterministic or using RAM observations.""" + slash_separated = env_id.split("/") + # environment name should look like "ALE/Amidar-v5" or "Amidar-ramNoFrameskip-v4" + assert len(slash_separated) in (1, 2) + after_slash = slash_separated[-1] + hyphen_separated = after_slash.split("-") + assert len(hyphen_separated) > 1 + not_ram = not ("ram" in hyphen_separated[1]) + not_deterministic = not ("Deterministic" in env_id) + return not_ram and not_deterministic + + +def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: + """Checks if a gym Atari environment is one of the ones we will support.""" + is_atari = gym_spec.entry_point == "gym.envs.atari:AtariEnv" + v5_and_plain = gym_spec.id.endswith("-v5") and not ("NoFrameskip" in gym_spec.id) + v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id + return ( + is_atari + and _not_ram_or_det(gym_spec.id) + and (v5_and_plain or v4_and_no_frameskip) + ) + + +def _seals_name(gym_spec: gym.envs.registration.EnvSpec) -> str: + """Makes a Gym ID for an Atari environment in the seals namespace.""" + slash_separated = gym_spec.id.split("/") + return "seals/" + slash_separated[-1] + + +def register_atari_envs( + gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec], +) -> None: + """Register wrapped gym Atari environments.""" + for gym_spec in gym_atari_env_specs: + gym.register( + id=_seals_name(gym_spec), + entry_point="seals.atari:fixed_length_atari", + max_episode_steps=get_gym_max_episode_steps(gym_spec.id), + kwargs=dict(atari_env_id=gym_spec.id), + ) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 387b4de..a927a50 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -137,7 +137,13 @@ def has_same_observations(rollout_a: Rollout, rollout_b: Rollout) -> bool: return True -def test_seed(env: gym.Env, env_name: str, deterministic_envs: Iterable[str]) -> None: +def test_seed( + env: gym.Env, + env_name: str, + deterministic_envs: Iterable[str], + rollout_len: int = 10, + num_seeds: int = 100, +) -> None: """Tests environment seeding. If non-deterministic, different seeds should produce different transitions. @@ -147,11 +153,11 @@ def test_seed(env: gym.Env, env_name: str, deterministic_envs: Iterable[str]) -> AssertionError if test fails. """ env.action_space.seed(0) - actions = [env.action_space.sample() for _ in range(10)] - + actions = [env.action_space.sample() for _ in range(rollout_len)] # With the same seed, should always get the same result seeds = env.seed(42) - assert isinstance(seeds, list) + # output of env.seed should be a list, but atari environments return a tuple. + assert isinstance(seeds, (list, tuple)) assert len(seeds) > 0 rollout_a = get_rollout(env, actions) @@ -164,7 +170,7 @@ def test_seed(env: gym.Env, env_name: str, deterministic_envs: Iterable[str]) -> # eventually get a different result. For deterministic environments, all # seeds should produce the same starting state. def different_seeds_same_rollout(seed1, seed2): - new_actions = [env.action_space.sample() for _ in range(10)] + new_actions = [env.action_space.sample() for _ in range(rollout_len)] env.seed(seed1) new_rollout_1 = get_rollout(env, new_actions) env.seed(seed2) @@ -172,7 +178,9 @@ def different_seeds_same_rollout(seed1, seed2): return has_same_observations(new_rollout_1, new_rollout_2) is_deterministic = matches_list(env_name, deterministic_envs) - same_obs = all(different_seeds_same_rollout(seed, seed + 1) for seed in range(100)) + same_obs = all( + different_seeds_same_rollout(seed, seed + 1) for seed in range(num_seeds) + ) assert same_obs == is_deterministic @@ -202,7 +210,8 @@ def _is_mujoco_env(env: gym.Env) -> bool: def test_rollout_schema( env: gym.Env, steps_after_done: int = 10, - max_steps: int = 10000, + max_steps: int = 10_000, + check_episode_ends: bool = True, ) -> None: """Check custom environments have correct types on `step` and `reset`. @@ -212,6 +221,8 @@ def test_rollout_schema( episode termination. This is an abuse of the Gym API, but we would like the environments to handle this case gracefully. max_steps: Test fails if we do not get `done` after this many timesteps. + check_episode_ends: Check that episode ends after `max_steps` steps, and that + steps taken after `done` is true are of the correct type. Raises: AssertionError if test fails. @@ -225,10 +236,11 @@ def test_rollout_schema( if done: break - assert done is True, "did not get to end of episode" + if check_episode_ends: + assert done, "did not get to end of episode" - for _ in range(steps_after_done): - _sample_and_check(env, obs_space) + for _ in range(steps_after_done): + _sample_and_check(env, obs_space) def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None: diff --git a/tests/test_envs.py b/tests/test_envs.py index 4e8c69b..c33e176 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -7,6 +7,7 @@ import pytest import seals # noqa: F401 required for env registration +from seals.atari import _seals_name from seals.testing import envs ENV_NAMES: List[str] = [ @@ -15,6 +16,7 @@ if env_spec.id.startswith("seals/") ] + DETERMINISTIC_ENVS: List[str] = [ "seals/EarlyTermPos-v0", "seals/EarlyTermNeg-v0", @@ -23,25 +25,84 @@ "seals/InitShiftTest-v0", ] +ATARI_ENVS: List[str] = [ + _seals_name(gym_spec) for gym_spec in seals.GYM_ATARI_ENV_SPECS +] + +ATARI_V5_ENVS: List[str] = list(filter(lambda name: name.endswith("-v5"), ATARI_ENVS)) +ATARI_NO_FRAMESKIP_ENVS: List[str] = list( + filter(lambda name: name.endswith("-v4"), ATARI_ENVS), +) + +DETERMINISTIC_ENVS += ATARI_NO_FRAMESKIP_ENVS + env = pytest.fixture(envs.make_env_fixture(skip_fn=pytest.skip)) +def test_some_atari_envs(): + """Tests if we succeeded in finding any Atari envs.""" + assert len(seals.GYM_ATARI_ENV_SPECS) > 0 + + +def test_atari_space_invaders(): + """Tests if there's an Atari environment called space invaders.""" + space_invader_environments = list( + filter( + lambda name: "SpaceInvaders" in name, + ATARI_ENVS, + ), + ) + assert len(space_invader_environments) > 0 + + @pytest.mark.parametrize("env_name", ENV_NAMES) class TestEnvs: """Battery of simple tests for environments.""" def test_seed(self, env: gym.Env, env_name: str): - """Tests environment seeding.""" - envs.test_seed(env, env_name, DETERMINISTIC_ENVS) + """Tests environment seeding. + + Deterministic Atari environments are run with fewer seeds to minimize the number + of resets done in this test suite, since Atari resets take a long time and there + are many Atari environments. + """ + if env_name in ATARI_ENVS: + # these environments take a while for their non-determinism to show. + slow_random_envs = [ + "seals/Bowling-v5", + "seals/Frogger-v5", + "seals/KingKong-v5", + "seals/Koolaid-v5", + "seals/NameThisGame-v5", + ] + rollout_len = 100 if env_name not in slow_random_envs else 400 + num_seeds = 2 if env_name in ATARI_NO_FRAMESKIP_ENVS else 10 + envs.test_seed( + env, + env_name, + DETERMINISTIC_ENVS, + rollout_len=rollout_len, + num_seeds=num_seeds, + ) + else: + envs.test_seed(env, env_name, DETERMINISTIC_ENVS) def test_premature_step(self, env: gym.Env): """Tests if step() before reset() raises error.""" envs.test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises) - def test_rollout_schema(self, env: gym.Env): - """Tests if environments have correct types on `step()` and `reset()`.""" - envs.test_rollout_schema(env) + def test_rollout_schema(self, env: gym.Env, env_name: str): + """Tests if environments have correct types on `step()` and `reset()`. + + Atari environments have a very long episode length (~100k observations), so in + the interest of time we do not run them to the end of their episodes or check + the return time of `env.step` after the end of the episode. + """ + if env_name in ATARI_ENVS: + envs.test_rollout_schema(env, max_steps=1_000, check_episode_ends=False) + else: + envs.test_rollout_schema(env) def test_render(self, env: gym.Env): """Tests `render()` supports modes specified in environment metadata."""