From 440d5a79a3ba4c8a7afc0a1bc7cd7d853d4ddea6 Mon Sep 17 00:00:00 2001 From: Juan Rocamonde Date: Sat, 1 Jul 2023 18:22:07 +0200 Subject: [PATCH 01/61] Initial commit --- README.md | 26 +-- pyproject.toml | 8 +- setup.py | 8 +- src/seals/__init__.py | 6 +- src/seals/atari.py | 17 +- src/seals/base_envs.py | 251 ++++++++++++-------------- src/seals/classic_control.py | 9 +- src/seals/diagnostics/__init__.py | 2 +- src/seals/diagnostics/early_term.py | 2 +- src/seals/diagnostics/init_shift.py | 4 +- src/seals/diagnostics/largest_sum.py | 12 +- src/seals/diagnostics/noisy_obs.py | 20 +- src/seals/diagnostics/parabola.py | 11 +- src/seals/diagnostics/proc_goal.py | 16 +- src/seals/diagnostics/random_trans.py | 18 +- src/seals/diagnostics/sort.py | 10 +- src/seals/mujoco.py | 2 +- src/seals/testing/envs.py | 47 ++--- src/seals/util.py | 82 ++++++--- tests/test_base_env.py | 7 +- tests/test_envs.py | 5 +- tests/test_mujoco_rl.py | 2 +- tests/test_util.py | 11 +- tests/test_wrappers.py | 26 +-- 24 files changed, 305 insertions(+), 297 deletions(-) diff --git a/README.md b/README.md index eacef50..e4078d8 100644 --- a/README.md +++ b/README.md @@ -7,39 +7,39 @@ **Status**: early beta. -*seals*, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for +_seals_, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for evaluating specification learning algorithms, such as reward or imitation learning. The environments are compatible with [Gym](https://github.com/openai/gym), but are designed to test algorithms that learn from user data, without requiring a procedurally specified reward function. -There are two types of environments in *seals*: +There are two types of environments in _seals_: - - **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. - - **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous - control tasks and Atari games to be suitable for specification learning benchmarks. In particular, - we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score). +- **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. +- **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous + control tasks and Atari games to be suitable for specification learning benchmarks. In particular, + we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score). + +_seals_ is under active development and we intend to add more categories of tasks soon. -*seals* is under active development and we intend to add more categories of tasks soon. - You may also be interested in our sister project [imitation](https://github.com/humancompatibleai/imitation/), providing implementations of a variety of imitation and reward learning algorithms. -Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about *seals*. +Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about _seals_. # Quickstart To install the latest release from PyPI, run: - + ```bash pip install seals ``` -All *seals* environments are available in the Gym registry. Simply import it and then use as you +All _seals_ environments are available in the Gym registry. Simply import it and then use as you would with your usual RL or specification learning algroithm: ```python -import gym +import gymnasium as gym import seals env = gym.make('seals/CartPole-v0') @@ -86,7 +86,7 @@ for type checking. ## Workflow Trivial changes (e.g. typo fixes) may be made directly by maintainers. Any non-trivial changes -must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous +must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous integration tests (CircleCI linting, type checking, unit tests and CodeCov) to be merged. It is often helpful to open an issue before proposing a PR, to allow for discussion of the design diff --git a/pyproject.toml b/pyproject.toml index f9bcbe8..bdde550 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,8 @@ build-backend = "setuptools.build_meta" target-version = ["py38"] [[tool.mypy.overrides]] -module = [ - "gym.*", - "setuptools_scm.*", -] +module = ["gym.*", "setuptools_scm.*"] ignore_missing_imports = true + +[tool.ruff] +select = ["E", "F"] diff --git a/setup.py b/setup.py index 84ba62a..40af813 100644 --- a/setup.py +++ b/setup.py @@ -115,10 +115,6 @@ def get_readme() -> str: "pytest-xdist", "pytype", "stable-baselines3>=0.9.0", - # TODO(adam): remove pyglet pin once Gym upgraded to >0.21 - # Workaround for https://github.com/openai/gym/issues/2986 - # Discussed in https://github.com/HumanCompatibleAI/imitation/pull/603 - "pyglet==1.5.27", "setuptools_scm~=7.0.5", *ATARI_REQUIRE, ] @@ -140,7 +136,7 @@ def get_readme() -> str: packages=find_packages("src"), package_dir={"": "src"}, package_data={"seals": ["py.typed"]}, - install_requires=["gym", "numpy"], + install_requires=["gymnasium", "numpy"], tests_require=TESTS_REQUIRE, extras_require={ # recommended packages for development @@ -149,7 +145,7 @@ def get_readme() -> str: "test": TESTS_REQUIRE, # We'd like to specify `gym[mujoco]`, but this is a no-op when Gym is already # installed. See https://github.com/pypa/pip/issues/4957 for issue. - "mujoco": ["mujoco_py>=1.50, <2.0", "imageio"], + "mujoco": ["mujoco", "imageio"], "atari": ATARI_REQUIRE, }, url="https://github.com/HumanCompatibleAI/benchmark-environments", diff --git a/src/seals/__init__.py b/src/seals/__init__.py index 7f017e4..fc15b21 100644 --- a/src/seals/__init__.py +++ b/src/seals/__init__.py @@ -2,10 +2,10 @@ from importlib import metadata -import gym +import gymnasium as gym -from seals import atari, util import seals.diagnostics # noqa: F401 +from seals import atari, util try: __version__ = metadata.version("seals") @@ -38,5 +38,5 @@ # Atari -GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.envs.registry.all())) +GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.registry.values())) atari.register_atari_envs(GYM_ATARI_ENV_SPECS) diff --git a/src/seals/atari.py b/src/seals/atari.py index a2d896c..9da2cce 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -2,7 +2,8 @@ from typing import Dict, Iterable, Optional -import gym +import gymnasium as gym +from gymnasium.envs.registration import EnvSpec from seals.util import ( AutoResetWrapper, @@ -37,7 +38,7 @@ def _get_score_region(atari_env_id: str) -> Optional[MaskedRegionSpecifier]: def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env: """Fixed-length, optionally masked-score variant of a given Atari environment.""" - env = AutoResetWrapper(gym.make(atari_env_id)) + env: gym.Env = AutoResetWrapper(gym.make(atari_env_id)) if masked: score_region = _get_score_region(atari_env_id) @@ -59,15 +60,15 @@ def _not_ram_or_det(env_id: str) -> bool: after_slash = slash_separated[-1] hyphen_separated = after_slash.split("-") assert len(hyphen_separated) > 1 - not_ram = not ("ram" in hyphen_separated[1]) - not_deterministic = not ("Deterministic" in env_id) + not_ram = "ram" not in hyphen_separated[1] + not_deterministic = "Deterministic" not in env_id return not_ram and not_deterministic -def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: +def _supported_atari_env(gym_spec: EnvSpec) -> bool: """Checks if a gym Atari environment is one of the ones we will support.""" is_atari = gym_spec.entry_point == "gym.envs.atari:AtariEnv" - v5_and_plain = gym_spec.id.endswith("-v5") and not ("NoFrameskip" in gym_spec.id) + v5_and_plain = gym_spec.id.endswith("-v5") and "NoFrameskip" not in gym_spec.id v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id return ( is_atari @@ -76,7 +77,7 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: ) -def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str: +def _seals_name(gym_spec: EnvSpec, masked: bool) -> str: """Makes a Gym ID for an Atari environment in the seals namespace.""" slash_separated = gym_spec.id.split("/") name = "seals/" + slash_separated[-1] @@ -88,7 +89,7 @@ def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str: def register_atari_envs( - gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec], + gym_atari_env_specs: Iterable[EnvSpec], ) -> None: """Register masked and unmasked wrapped gym Atari environments.""" diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index c7b29c6..b690f14 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -1,20 +1,23 @@ """Base environment classes.""" import abc -from typing import Generic, Optional, Sequence, Tuple, TypeVar +from typing import Any, Generic, Optional, Tuple, TypeVar -import gym -from gym import spaces +import gymnasium as gym import numpy as np +import numpy.typing as npt +from gymnasium import spaces from seals import util -State = TypeVar("State") -Observation = TypeVar("Observation") -Action = TypeVar("Action") +StateType = TypeVar("StateType") +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") -class ResettablePOMDP(gym.Env, abc.ABC, Generic[State, Observation, Action]): +class ResettablePOMDP( + gym.Env[ObsType, ActType], abc.ABC, Generic[StateType, ObsType, ActType] +): """ABC for POMDPs that are resettable. Specifically, these environments provide oracle access to sample from @@ -23,69 +26,37 @@ class ResettablePOMDP(gym.Env, abc.ABC, Generic[State, Observation, Action]): meet these criteria. """ - _state_space: gym.Space - _observation_space: gym.Space - _action_space: gym.Space - _cur_state: Optional[State] + state_space: spaces.Space[StateType] + + _cur_state: Optional[StateType] _n_actions_taken: Optional[int] - def __init__( - self, - *, - state_space: gym.Space, - observation_space: gym.Space, - action_space: gym.Space, - ): - """Build resettable (PO)MDP. - - Args: - state_space: gym.Space containing possible states. - observation_space: gym.Space containing possible observations. - action_space: gym.Space containing possible actions. - """ - self._state_space = state_space - self._observation_space = observation_space - self._action_space = action_space + def __init__(self): + """Build resettable (PO)MDP.""" self._cur_state = None self._n_actions_taken = None - self.seed() @abc.abstractmethod - def initial_state(self) -> State: + def initial_state(self) -> StateType: """Samples from the initial state distribution.""" @abc.abstractmethod - def transition(self, state: State, action: Action) -> State: + def transition(self, state: StateType, action: ActType) -> StateType: """Samples from transition distribution.""" @abc.abstractmethod - def reward(self, state: State, action: Action, new_state: State) -> float: + def reward(self, state: StateType, action: ActType, new_state: StateType) -> float: """Computes reward for a given transition.""" @abc.abstractmethod - def terminal(self, state: State, step: int) -> bool: + def terminal(self, state: StateType, step: int) -> bool: """Is the state terminal?""" @abc.abstractmethod - def obs_from_state(self, state: State) -> Observation: + def obs_from_state(self, state: StateType) -> ObsType: """Sample observation for given state.""" - @property - def state_space(self) -> gym.Space: - """State space. Often same as observation_space, but differs in POMDPs.""" - return self._state_space - - @property - def observation_space(self) -> gym.Space: - """Observation space. Return type of reset() and component of step().""" - return self._observation_space - - @property - def action_space(self) -> gym.Space: - """Action space. Parameter type of step().""" - return self._action_space - @property def n_actions_taken(self) -> int: """Number of steps taken so far.""" @@ -93,34 +64,36 @@ def n_actions_taken(self) -> int: return self._n_actions_taken @property - def state(self) -> State: + def state(self) -> StateType: """Current state.""" assert self._cur_state is not None return self._cur_state @state.setter - def state(self, state: State): + def state(self, state: StateType): """Set current state.""" if state not in self.state_space: raise ValueError(f"{state} not in {self.state_space}") self._cur_state = state - def seed(self, seed=None) -> Sequence[int]: - """Set random seed.""" - if seed is None: - # Gym API wants list of seeds to be returned for some reason, so - # generate a seed explicitly in this case - seed = np.random.randint(0, 1 << 31) - self.rand_state = np.random.RandomState(seed) - return [seed] - - def reset(self) -> Observation: + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: # type: ignore """Reset episode and return initial observation.""" + if options is not None: + raise ValueError("Options not supported.") + + super().reset(seed=seed) self.state = self.initial_state() self._n_actions_taken = 0 - return self.obs_from_state(self.state) + obs = self.obs_from_state(self.state) + info: dict[str, Any] = dict() + return obs, info - def step(self, action: Action) -> Tuple[Observation, float, bool, dict]: + def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """Transition state using given action.""" if self._cur_state is None or self._n_actions_taken is None: raise ValueError("Need to call reset() before first step()") @@ -133,16 +106,30 @@ def step(self, action: Action) -> Tuple[Observation, float, bool, dict]: assert obs in self.observation_space reward = self.reward(old_state, action, self.state) self._n_actions_taken += 1 - done = self.terminal(self.state, self.n_actions_taken) + terminated = self.terminal(self.state, self.n_actions_taken) + truncated = False infos = {"old_state": old_state, "new_state": self._cur_state} - return obs, reward, done, infos + return obs, reward, terminated, truncated, infos + + @property + def rand_state(self) -> np.random.Generator: + """Random state.""" + rand_state = self._np_random + if rand_state is None: + raise ValueError("Need to call reset() before accessing rand_state") + return rand_state -class ExposePOMDPStateWrapper(gym.Wrapper, Generic[State, Observation, Action]): +class ExposePOMDPStateWrapper( + gym.Wrapper[StateType, ActType, ObsType, ActType], + Generic[StateType, ObsType, ActType], +): """A wrapper that exposes the current state of the POMDP as the observation.""" - def __init__(self, env: ResettablePOMDP[State, Observation, Action]) -> None: + env: ResettablePOMDP[StateType, ObsType, ActType] + + def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None: """Build wrapper. Args: @@ -151,51 +138,50 @@ def __init__(self, env: ResettablePOMDP[State, Observation, Action]) -> None: super().__init__(env) self._observation_space = env.state_space - def reset(self) -> State: + def reset( + self, seed: int | None = None, options: dict[str, Any] | None = None + ) -> Tuple[StateType, dict[str, Any]]: """Reset environment and return initial state.""" - self.env.reset() - return self.env.state + _, info = self.env.reset(seed=seed, options=options) + return self.env.state, info - def step(self, action) -> Tuple[State, float, bool, dict]: + def step(self, action) -> Tuple[StateType, float, bool, bool, dict]: """Transition state using given action.""" - obs, reward, done, info = self.env.step(action) - return self.env.state, reward, done, info + _, reward, terminated, truncated, info = self.env.step(action) + return self.env.state, reward, terminated, truncated, info class ResettableMDP( - ResettablePOMDP[State, State, Action], + ResettablePOMDP[StateType, StateType, ActType], abc.ABC, - Generic[State, Action], + Generic[StateType, ActType], ): """ABC for MDPs that are resettable.""" - def __init__( - self, - *, - state_space: gym.Space, - action_space: gym.Space, - ): - """Build resettable MDP. + @property + def observation_space(self) -> spaces.Space[StateType]: + """Observation space.""" + return self.state_space - Args: - state_space: gym.Space containing possible states. - action_space: gym.Space containing possible actions. - """ - super().__init__( - state_space=state_space, - observation_space=state_space, - action_space=action_space, - ) + @observation_space.setter + def observation_space(self, space: spaces.Space[StateType]): + """Set observation space.""" + self.state_space = space - def obs_from_state(self, state: State) -> State: + def obs_from_state(self, state: StateType) -> StateType: """Identity since observation == state in an MDP.""" return state +DiscreteSpaceInt = np.int64 + + # TODO(juan) this does not implement the .render() method, # so in theory it should not be instantiated directly. # Not sure why this is not raising an error? -class BaseTabularModelPOMDP(ResettablePOMDP[int, Observation, int]): +class BaseTabularModelPOMDP( + ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt], Generic[ObsType] +): """Base class for tabular environments with known dynamics. This is the general class that also allows subclassing for creating @@ -236,6 +222,8 @@ def __init__( ValueError: `transition_matrix`, `reward_matrix` or `initial_state_dist` have shapes different to specified above. """ + super().__init__() + # The following matrices should conform to the shapes below: # transition matrix: n_states x n_actions x n_states @@ -278,43 +266,42 @@ def __init__( self.horizon = horizon self.initial_state_dist = initial_state_dist - super().__init__( - state_space=self._construct_state_space(), - action_space=self._construct_action_space(), - observation_space=self._construct_observation_space(), - ) - - def _construct_state_space(self) -> gym.Space: - return spaces.Discrete(self.state_dim) - - def _construct_action_space(self) -> gym.Space: - return spaces.Discrete(self.action_dim) - - @abc.abstractmethod - def _construct_observation_space(self) -> gym.Space: - pass # pragma: no cover + self.state_space = spaces.Discrete(self.state_dim) + self.action_space = spaces.Discrete(self.action_dim) - def initial_state(self) -> int: + def initial_state(self) -> DiscreteSpaceInt: """Samples from the initial state distribution.""" - return util.sample_distribution( - self.initial_state_dist, - random=self.rand_state, + return DiscreteSpaceInt( + util.sample_distribution( + self.initial_state_dist, + random=self.rand_state, + ) ) - def transition(self, state: int, action: int) -> int: + def transition( + self, state: DiscreteSpaceInt, action: DiscreteSpaceInt + ) -> DiscreteSpaceInt: """Samples from transition distribution.""" - return util.sample_distribution( - self.transition_matrix[state, action], - random=self.rand_state, + return DiscreteSpaceInt( + util.sample_distribution( + self.transition_matrix[state, action], + random=self.rand_state, + ) ) - def reward(self, state: int, action: int, new_state: int) -> float: + def reward( + self, + state: DiscreteSpaceInt, + action: DiscreteSpaceInt, + new_state: DiscreteSpaceInt, + ) -> float: """Computes reward for a given transition.""" inputs = (state, action, new_state)[: len(self.reward_matrix.shape)] return self.reward_matrix[inputs] - def terminal(self, state: int, n_actions_taken: int) -> bool: + def terminal(self, state: DiscreteSpaceInt, n_actions_taken: int) -> bool: """Checks if state is terminal.""" + del state return self.horizon is not None and n_actions_taken >= self.horizon @property @@ -323,7 +310,7 @@ def feature_matrix(self): # Construct lazily to save memory in algorithms that don't need features. if self._feature_matrix is None: n_states = self.state_space.n - self._feature_matrix = np.eye(n_states) + self._feature_matrix = np.eye(int(n_states)) return self._feature_matrix @property @@ -337,7 +324,12 @@ def action_dim(self) -> int: return self.transition_matrix.shape[1] -class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]): +ObsEntryType = TypeVar( + "ObsEntryType", bound=np.floating[Any] | np.integer[Any], covariant=True +) + + +class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray], Generic[ObsEntryType]): """Tabular model POMDP. This class is specifically for environments where observation != state, @@ -349,13 +341,14 @@ class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]): a vector with self.obs_dim entries. """ - observation_matrix: np.ndarray + observation_matrix: npt.NDArray[ObsEntryType] + observation_space: spaces.Box def __init__( self, *, transition_matrix: np.ndarray, - observation_matrix: np.ndarray, + observation_matrix: npt.NDArray[ObsEntryType], reward_matrix: np.ndarray, horizon: Optional[int] = None, initial_state_dist: Optional[np.ndarray] = None, @@ -377,7 +370,6 @@ def __init__( f"observation_matrix.shape[0]: {observation_matrix.shape[0]}", ) - def _construct_observation_space(self) -> gym.Space: min_val: float max_val: float try: @@ -386,14 +378,14 @@ def _construct_observation_space(self) -> gym.Space: except ValueError: min_val = -np.inf max_val = np.inf - return spaces.Box( + self.observation_space = spaces.Box( low=min_val, high=max_val, shape=(self.obs_dim,), dtype=self.obs_dtype, ) - def obs_from_state(self, state: int) -> np.ndarray: + def obs_from_state(self, state: DiscreteSpaceInt) -> npt.NDArray[ObsEntryType]: """Computes observation from state.""" # Copy so it can't be mutated in-place (updates will be reflected in # self.observation_matrix!) @@ -407,12 +399,12 @@ def obs_dim(self) -> int: return self.observation_matrix.shape[1] @property - def obs_dtype(self) -> int: + def obs_dtype(self) -> np.dtype[ObsEntryType]: """Data type of observation vectors (e.g. np.float32).""" return self.observation_matrix.dtype -class TabularModelMDP(BaseTabularModelPOMDP[int]): +class TabularModelMDP(BaseTabularModelPOMDP[DiscreteSpaceInt]): """Tabular model MDP. A tabular model MDP is a tabular MDP where the transition and reward @@ -444,9 +436,6 @@ def __init__( initial_state_dist=initial_state_dist, ) - def obs_from_state(self, state: int) -> int: + def obs_from_state(self, state: DiscreteSpaceInt) -> DiscreteSpaceInt: """Identity since observation == state in an MDP.""" return state - - def _construct_observation_space(self) -> gym.Space: - return self._construct_state_space() diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 9854095..7b1d3c9 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -2,14 +2,15 @@ import warnings -from gym import spaces -import gym.envs.classic_control import numpy as np +from gymnasium import spaces +from gymnasium.envs import classic_control + from seals import util -class FixedHorizonCartPole(gym.envs.classic_control.CartPoleEnv): +class FixedHorizonCartPole(classic_control.CartPoleEnv): """Fixed-length variant of CartPole-v1. Reward is 1.0 whenever the CartPole is an "ok" state (i.e. the pole is upright @@ -20,7 +21,7 @@ class FixedHorizonCartPole(gym.envs.classic_control.CartPoleEnv): """ def __init__(self): - """Builds FixedHorizonCartPole, modifying observation_space from Gym parent.""" + """Builds FixedHorizonCartPole, modifying observation_space from gym parent.""" super().__init__() high = [ diff --git a/src/seals/diagnostics/__init__.py b/src/seals/diagnostics/__init__.py index 74be80f..b3c9b28 100644 --- a/src/seals/diagnostics/__init__.py +++ b/src/seals/diagnostics/__init__.py @@ -1,6 +1,6 @@ """Simple diagnostic environments.""" -import gym +import gymnasium as gym gym.register( id="seals/Branching-v0", diff --git a/src/seals/diagnostics/early_term.py b/src/seals/diagnostics/early_term.py index 8c672bb..87287bd 100644 --- a/src/seals/diagnostics/early_term.py +++ b/src/seals/diagnostics/early_term.py @@ -53,7 +53,7 @@ def __init__(self, is_reward_positive: bool = True): reward_matrix=reward_matrix, ) - def terminal(self, state: int, n_actions_taken: int) -> bool: + def terminal(self, state: base_envs.DiscreteSpaceInt, n_actions_taken: int) -> bool: """Returns True if (and only if) in state 2.""" return bool(state == 2) diff --git a/src/seals/diagnostics/init_shift.py b/src/seals/diagnostics/init_shift.py index 086f561..1e47477 100644 --- a/src/seals/diagnostics/init_shift.py +++ b/src/seals/diagnostics/init_shift.py @@ -27,7 +27,7 @@ class InitShiftEnv(base_envs.TabularModelMDP): disambiguate this case. """ - def __init__(self, initial_state: int): + def __init__(self, initial_state: base_envs.DiscreteSpaceInt): """Constructs environment. Args: @@ -63,7 +63,7 @@ def __init__(self, initial_state: int): reward_matrix=reward_matrix, ) - def initial_state(self) -> int: + def initial_state(self) -> base_envs.DiscreteSpaceInt: """Returns initial state defined in constructor.""" return self._initial_state diff --git a/src/seals/diagnostics/largest_sum.py b/src/seals/diagnostics/largest_sum.py index d55021f..688588c 100644 --- a/src/seals/diagnostics/largest_sum.py +++ b/src/seals/diagnostics/largest_sum.py @@ -1,7 +1,7 @@ """Environment testing scalability to high-dimensional tasks.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs @@ -23,12 +23,10 @@ def __init__(self, length: int = 50): Args: length: dimensionality of state space vector. """ + super().__init__() self._length = length - state_space = spaces.Box(low=0.0, high=1.0, shape=(length,)) - super().__init__( - state_space=state_space, - action_space=spaces.Discrete(2), - ) + self.state_space = spaces.Box(low=0.0, high=1.0, shape=(length,)) + self.action_space = spaces.Discrete(2) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: """Always returns True, since this task should have a 1-timestep horizon.""" @@ -36,7 +34,7 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Returns vector sampled uniformly in [0, 1]**L.""" - init_state = self.rand_state.rand(self._length) + init_state = self.rand_state.random((self._length,)) return init_state.astype(self.observation_space.dtype) def reward(self, state: np.ndarray, act: int, next_state: np.ndarray) -> float: diff --git a/src/seals/diagnostics/noisy_obs.py b/src/seals/diagnostics/noisy_obs.py index 8588416..a02f9a5 100644 --- a/src/seals/diagnostics/noisy_obs.py +++ b/src/seals/diagnostics/noisy_obs.py @@ -1,7 +1,7 @@ """Environment testing for robustness to noise.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs, util @@ -23,6 +23,8 @@ def __init__(self, *, size: int = 5, noise_length: int = 20): size: width and height of gridworld. noise_length: dimension of noise vector in observation. """ + super().__init__() + self._size = size self._noise_length = noise_length self._goal = np.array([self._size // 2, self._size // 2]) @@ -34,14 +36,12 @@ def __init__(self, *, size: int = 5, noise_length: int = 20): ([size - 1, size - 1], np.full(self._noise_length, np.inf)), # type: ignore ) - super().__init__( - state_space=spaces.MultiDiscrete([size, size]), - action_space=spaces.Discrete(5), - observation_space=spaces.Box( - low=obs_box_low, - high=obs_box_high, - dtype=np.float32, - ), + self.state_space = spaces.MultiDiscrete([size, size]) + self.action_space = spaces.Discrete(5) + self.observation_space = spaces.Box( + low=obs_box_low, + high=obs_box_high, + dtype=np.float32, ) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: @@ -52,7 +52,7 @@ def initial_state(self) -> np.ndarray: """Returns one of the grid's corners.""" n = self._size corners = np.array([[0, 0], [n - 1, 0], [0, n - 1], [n - 1, n - 1]]) - return corners[self.rand_state.randint(4)] + return corners[self.rand_state.integers(4)] def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float: """Returns +1.0 reward if state is the goal and 0.0 otherwise.""" diff --git a/src/seals/diagnostics/parabola.py b/src/seals/diagnostics/parabola.py index 07c9b8b..053eff1 100644 --- a/src/seals/diagnostics/parabola.py +++ b/src/seals/diagnostics/parabola.py @@ -1,7 +1,7 @@ """Environment testing for generalization in continuous spaces.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs @@ -24,16 +24,15 @@ def __init__(self, x_step: float = 0.05, bounds: float = 5): bounds: limits coordinates, useful for keeping rewards in a small bounded range. """ + super().__init__() self._x_step = x_step self._bounds = bounds state_high = np.array([bounds, bounds, 1.0, 1.0, 1.0]) state_low = (-1) * state_high - super().__init__( - state_space=spaces.Box(low=state_low, high=state_high), - action_space=spaces.Box(low=(-2) * bounds, high=2 * bounds, shape=()), - ) + self.state_space = spaces.Box(low=state_low, high=state_high) + self.action_space = spaces.Box(low=(-2) * bounds, high=2 * bounds, shape=()) def terminal(self, state: int, n_actions_taken: int) -> bool: """Always returns False.""" @@ -41,7 +40,7 @@ def terminal(self, state: int, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Get state by sampling a random parabola.""" - a, b, c = -1 + 2 * self.rand_state.rand(3) + a, b, c = -1 + 2 * self.rand_state.random((3,)) x, y = 0, c return np.array([x, y, a, b, c], dtype=self.state_space.dtype) diff --git a/src/seals/diagnostics/proc_goal.py b/src/seals/diagnostics/proc_goal.py index e534b72..706a939 100644 --- a/src/seals/diagnostics/proc_goal.py +++ b/src/seals/diagnostics/proc_goal.py @@ -1,7 +1,7 @@ """Large gridworld with random agent and goal position.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs, util @@ -28,13 +28,11 @@ def __init__(self, bounds: int = 100, distance: int = 10): generalization harder. distance: initial distance between agent and goal. """ + super().__init__() self._bounds = bounds self._distance = distance - - super().__init__( - state_space=spaces.Box(low=-np.inf, high=np.inf, shape=(4,)), - action_space=spaces.Discrete(5), - ) + self.state_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,)) + self.action_space = spaces.Discrete(5) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: """Always returns False.""" @@ -42,11 +40,11 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Samples random agent position and random goal.""" - pos = self.rand_state.randint(low=-self._bounds, high=self._bounds, size=(2,)) + pos = self.rand_state.integers(low=-self._bounds, high=self._bounds, size=(2,)) - x_dist = self.rand_state.randint(self._distance) + x_dist = self.rand_state.integers(self._distance) y_dist = self._distance - x_dist - random_signs = 2 * self.rand_state.randint(2, size=2) - 1 + random_signs = 2 * self.rand_state.integers(2, size=2) - 1 goal = pos + random_signs * (x_dist, y_dist) return np.concatenate([pos, goal]).astype(self.observation_space.dtype) diff --git a/src/seals/diagnostics/random_trans.py b/src/seals/diagnostics/random_trans.py index 09eb6b4..13dae26 100644 --- a/src/seals/diagnostics/random_trans.py +++ b/src/seals/diagnostics/random_trans.py @@ -46,7 +46,7 @@ def __init__( """ # this generator is ONLY for constructing the MDP, not for controlling # random outcomes during rollouts - rand_gen = np.random.RandomState(generator_seed) + rand_gen = np.random.default_rng(generator_seed) if random_obs: if obs_dim is None: @@ -87,7 +87,7 @@ def make_random_trans_mat( n_states, n_actions, max_branch_factor, - rand_state: Optional[np.random.RandomState] = None, + rand_state: Optional[np.random.Generator] = None, ) -> np.ndarray: """Make a 'random' transition matrix. @@ -110,14 +110,14 @@ def make_random_trans_mat( of transitioning to `next_s` after taking action `a` in state `s`. """ if rand_state is None: - rand_state = np.random.RandomState() + rand_state = np.random.default_rng() assert rand_state is not None out_mat = np.zeros((n_states, n_actions, n_states), dtype="float32") for start_state in range(n_states): for action in range(n_actions): # uniformly sample a number of successors in [1,max_branch_factor] # for this action - successors = rand_state.randint(1, max_branch_factor + 1) + successors = rand_state.integers(1, max_branch_factor + 1) next_states = rand_state.choice( n_states, size=(successors,), @@ -133,7 +133,7 @@ def make_random_trans_mat( def make_random_state_dist( n_avail: int, n_states: int, - rand_state: Optional[np.random.RandomState] = None, + rand_state: Optional[np.random.Generator] = None, ) -> np.ndarray: """Make a random initial state distribution over n_states. @@ -152,7 +152,7 @@ def make_random_state_dist( ValueError: If `n_avail` is not in the range `(0, n_states]`. """ # noqa: DAR402 if rand_state is None: - rand_state = np.random.RandomState() + rand_state = np.random.default_rng() assert rand_state is not None assert 0 < n_avail <= n_states init_dist = np.zeros((n_states,)) @@ -168,7 +168,7 @@ def make_obs_mat( n_states: int, is_random: bool, obs_dim: Optional[int] = None, - rand_state: Optional[np.random.RandomState] = None, + rand_state: Optional[np.random.Generator] = None, ) -> np.ndarray: """Makes an observation matrix with a single observation for each state. @@ -179,7 +179,7 @@ def make_obs_mat( If `False`, are unique one-hot vectors for each state. obs_dim (int or NoneType): Must be `None` if `is_random == False`. Otherwise, this must be set to the size of the random vectors. - rand_state (np.random.RandomState): Random number generator. + rand_state (np.random.Generator): Random number generator. Returns: A matrix of shape `(n_states, obs_dim if is_random else n_states)`. @@ -188,7 +188,7 @@ def make_obs_mat( ValueError: If ``is_random == False`` and ``obs_dim is not None``. """ if rand_state is None: - rand_state = np.random.RandomState() + rand_state = np.random.default_rng() assert rand_state is not None if is_random: if obs_dim is None: diff --git a/src/seals/diagnostics/sort.py b/src/seals/diagnostics/sort.py index d1c4313..6b1a88e 100644 --- a/src/seals/diagnostics/sort.py +++ b/src/seals/diagnostics/sort.py @@ -1,7 +1,7 @@ """Environment to sort a list using swap actions.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs @@ -21,10 +21,9 @@ def __init__(self, length: int = 4): """ self._length = length - super().__init__( - state_space=spaces.Box(low=0, high=1.0, shape=(length,)), - action_space=spaces.MultiDiscrete([length, length]), - ) + super().__init__() + self.state_space = spaces.Box(low=0, high=1.0, shape=(length,)) + self.action_space = spaces.MultiDiscrete([length, length]) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: """Always returns False.""" @@ -42,6 +41,7 @@ def reward( new_state: np.ndarray, ) -> float: """Rewards fully sorted lists, and new correct positions.""" + del action # This is not meant to be a potential shaping in the formal sense, # as it changes the trajectory returns (since we do not return # a fixed-potential state at termination). diff --git a/src/seals/mujoco.py b/src/seals/mujoco.py index 5cd2c5f..0e539ab 100644 --- a/src/seals/mujoco.py +++ b/src/seals/mujoco.py @@ -2,7 +2,7 @@ import functools -from gym.envs.mujoco import ( +from gymnasium.envs.mujoco import ( ant_v3, half_cheetah_v3, hopper_v3, diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index a927a50..58bb7e9 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -17,12 +17,12 @@ Tuple, ) -import gym +import gymnasium as gym import numpy as np -Step = Tuple[Any, Optional[float], bool, Mapping[str, Any]] +Step = Tuple[Any, Optional[float], bool, bool, Mapping[str, Any]] Rollout = Sequence[Step] -"""A sequence of 4-tuples (obs, rew, done, info) as returned by `get_rollout`.""" +"""A sequence of 4-tuples (obs, rew, terminated, truncated, info) as returned by `get_rollout`.""" def make_env_fixture( @@ -95,9 +95,9 @@ def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout: actions: the actions to perform. Returns: - A sequence of 4-tuples (obs, rew, done, info). + A sequence of 4-tuples (obs, rew, terminated, truncated, info). """ - ret: List[Step] = [(env.reset(), None, False, {})] + ret: List[Step] = [(env.reset(), None, False, False, {})] for act in actions: ret.append(env.step(act)) return ret @@ -110,11 +110,12 @@ def assert_equal_rollout(rollout_a: Rollout, rollout_b: Rollout) -> None: AssertionError if they are not equal. """ for step_a, step_b in zip(rollout_a, rollout_b): - ob_a, rew_a, done_a, info_a = step_a - ob_b, rew_b, done_b, info_b = step_b + ob_a, rew_a, terminated_a, truncated_a, info_a = step_a + ob_b, rew_b, terminated_b, truncated_b, info_b = step_b np.testing.assert_equal(ob_a, ob_b) assert rew_a == rew_b - assert done_a == done_b + assert terminated_a == terminated_b + assert truncated_a == truncated_b np.testing.assert_equal(info_a, info_b) @@ -155,13 +156,13 @@ def test_seed( env.action_space.seed(0) actions = [env.action_space.sample() for _ in range(rollout_len)] # With the same seed, should always get the same result - seeds = env.seed(42) + seeds = env.reset(seed=42) # output of env.seed should be a list, but atari environments return a tuple. assert isinstance(seeds, (list, tuple)) assert len(seeds) > 0 rollout_a = get_rollout(env, actions) - env.seed(42) + env.reset(seed=42) rollout_b = get_rollout(env, actions) assert_equal_rollout(rollout_a, rollout_b) @@ -171,9 +172,9 @@ def test_seed( # seeds should produce the same starting state. def different_seeds_same_rollout(seed1, seed2): new_actions = [env.action_space.sample() for _ in range(rollout_len)] - env.seed(seed1) + env.reset(seed=seed1) new_rollout_1 = get_rollout(env, new_actions) - env.seed(seed2) + env.reset(seed=seed2) new_rollout_2 = get_rollout(env, new_actions) return has_same_observations(new_rollout_1, new_rollout_2) @@ -192,15 +193,16 @@ def _check_obs(obs: np.ndarray, obs_space: gym.Space) -> None: assert obs in obs_space -def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> bool: +def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> Tuple[bool, bool]: """Sample from env and check return value is of valid type.""" act = env.action_space.sample() - obs, rew, done, info = env.step(act) + obs, rew, terminated, truncated, info = env.step(act) _check_obs(obs, obs_space) assert isinstance(rew, float) - assert isinstance(done, bool) + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) assert isinstance(info, dict) - return done + return terminated, truncated def _is_mujoco_env(env: gym.Env) -> bool: @@ -228,16 +230,17 @@ def test_rollout_schema( AssertionError if test fails. """ obs_space = env.observation_space - obs = env.reset() + obs, _ = env.reset() _check_obs(obs, obs_space) + terminated = False for _ in range(max_steps): - done = _sample_and_check(env, obs_space) - if done: + terminated, _ = _sample_and_check(env, obs_space) + if terminated: break if check_episode_ends: - assert done, "did not get to end of episode" + assert terminated, "did not get to end of episode" for _ in range(steps_after_done): _sample_and_check(env, obs_space) @@ -352,5 +355,5 @@ def step(self, action): t, self.timestep = self.timestep, self.timestep + 1 obs = np.array(t, dtype=self.observation_space.dtype) rew = t * 10.0 - done = t == self.episode_length - return obs, rew, done, {} + terminated = t == self.episode_length + return obs, rew, terminated, False, {} diff --git a/src/seals/util.py b/src/seals/util.py index 60c0323..a7ffb1e 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -1,13 +1,17 @@ """Miscellaneous utilities.""" from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, Generic, List, Optional, Sequence, SupportsFloat, Tuple, Union -import gym +import gymnasium as gym import numpy as np +import numpy.typing as npt +from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType -class AutoResetWrapper(gym.Wrapper): +class AutoResetWrapper( + gym.Wrapper, Generic[WrapperObsType, WrapperActType, ObsType, ActType] +): """Hides done=True and auto-resets at the end of each episode. Depending on the flag 'discard_terminal_observation', either discards the terminal @@ -37,7 +41,9 @@ def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0): self.reset_reward = reset_reward self.previous_done = False # Whether the previous step returned done=True. - def step(self, action): + def step( + self, action: WrapperActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """When done=True, returns done=False, then reset depending on flag. Depending on whether we are discarding the terminal observation, @@ -50,7 +56,9 @@ def step(self, action): else: return self._step_pad(action) - def _step_pad(self, action): + def _step_pad( + self, action: WrapperActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """When done=True, return done=False instead and return the terminal obs. The agent will then usually be asked to perform an action based on @@ -67,26 +75,31 @@ def _step_pad(self, action): """ if self.previous_done: self.previous_done = False + reset_obs, reset_info_dict = self.env.reset() + info = {"reset_info_dict": reset_info_dict} # This transition will only reset the environment, the action is ignored. - return self.env.reset(), self.reset_reward, False, {} + return reset_obs, self.reset_reward, False, False, info - obs, rew, done, info = self.env.step(action) - if done: + obs, rew, terminated, truncated, info = self.env.step(action) + if terminated: self.previous_done = True - return obs, rew, False, info + return obs, rew, False, truncated, info - def _step_discard(self, action): + def _step_discard( + self, action: WrapperActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """When done=True, returns done=False instead and automatically resets. When an automatic reset happens, the observation from reset is returned, and the overridden observation is stored in `info["terminal_observation"]`. """ - obs, rew, done, info = self.env.step(action) - if done: + obs, rew, terminated, truncated, info = self.env.step(action) + if terminated: info["terminal_observation"] = obs - obs = self.env.reset() - return obs, rew, False, info + obs, reset_info_dict = self.env.reset() + info["reset_info_dict"] = reset_info_dict + return obs, rew, False, truncated, info @dataclass @@ -100,7 +113,10 @@ class BoxRegion: MaskedRegionSpecifier = List[BoxRegion] -class MaskScoreWrapper(gym.Wrapper): +class MaskScoreWrapper( + gym.Wrapper[npt.NDArray, ActType, npt.NDArray, ActType], + Generic[ActType], +): """Mask a list of box-shaped regions in the observation to hide reward info. Intended for environments whose observations are raw pixels (like Atari @@ -130,19 +146,22 @@ def __init__( super().__init__(env) self.fill_value = np.array(fill_value, env.observation_space.dtype) + assert env.observation_space.shape is not None self.mask = np.ones(env.observation_space.shape, dtype=bool) for r in score_regions: if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]: raise ValueError('Invalid region: "x" and "y" must be increasing.') self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = 0 - def _mask_obs(self, obs): + def _mask_obs(self, obs) -> npt.NDArray: return np.where(self.mask, obs, self.fill_value) - def step(self, action): - """Returns (obs, rew, done, info) with masked obs.""" - obs, rew, done, info = self.env.step(action) - return self._mask_obs(obs), rew, done, info + def step( + self, action: ActType + ) -> tuple[npt.NDArray, SupportsFloat, bool, bool, dict[str, Any]]: + """Returns (obs, rew, terminated, truncated, info) with masked obs.""" + obs, rew, terminated, truncated, info = self.env.step(action) + return self._mask_obs(obs), rew, terminated, truncated, info def reset(self, **kwargs): """Returns masked reset observation.""" @@ -174,9 +193,9 @@ def reset(self): return super().reset().astype(self.dtype) def step(self, action): - """Returns (obs, rew, done, info) with obs cast to self.dtype.""" - obs, rew, done, info = super().step(action) - return obs.astype(self.dtype), rew, done, info + """Returns (obs, rew, terminated, truncated, info) with obs cast to self.dtype.""" + obs, rew, terminated, truncated, info = super().step(action) + return obs.astype(self.dtype), rew, terminated, truncated, info class AbsorbAfterDoneWrapper(gym.Wrapper): @@ -228,8 +247,10 @@ def step(self, action): `rew` depend on initialization arguments. `info` is always an empty dictionary. """ if not self.at_absorb_state: - inner_obs, inner_rew, done, inner_info = self.env.step(action) - if done: + inner_obs, inner_rew, terminated, truncated, inner_info = self.env.step( + action + ) + if terminated: # Initialize the artificial absorb state, which we will repeatedly use # starting on the next call to `step()`. self.at_absorb_state = True @@ -245,26 +266,27 @@ def step(self, action): obs = self.absorb_obs_this_episode rew = self.absorb_reward info = {} + truncated = False - return obs, rew, False, info + return obs, rew, False, truncated, info def make_env_no_wrappers(env_name: str, **kwargs) -> gym.Env: - """Gym sometimes wraps envs in TimeLimit before returning from gym.make(). + """Gym sometimes wraps envs in TimeLimit before returning from gymnasium.make(). This helper method builds directly from spec to avoid this wrapper. """ - return gym.envs.registry.env_specs[env_name].make(**kwargs) + return gym.spec(env_name).make(**kwargs) def get_gym_max_episode_steps(env_name: str) -> Optional[int]: """Get the `max_episode_steps` attribute associated with a gym Spec.""" - return gym.envs.registry.env_specs[env_name].max_episode_steps + return gym.spec(env_name).max_episode_steps def sample_distribution( p: np.ndarray, - random: np.random.RandomState, + random: np.random.Generator, ) -> int: """Samples an integer with probabilities given by p.""" return random.choice(np.arange(len(p)), p=p) diff --git a/tests/test_base_env.py b/tests/test_base_env.py index 59a3e32..c91fe22 100644 --- a/tests/test_base_env.py +++ b/tests/test_base_env.py @@ -4,10 +4,9 @@ so the tests in this file focus on features unique to classes in `base_envs`. """ -import gym +import gymnasium as gym import numpy as np import pytest - from seals import base_envs from seals.testing import envs @@ -19,9 +18,9 @@ def __init__(self): """Build environment.""" nS = 3 nA = 2 - transition_matrix = np.random.rand(nS, nA, nS) + transition_matrix = np.random.random((nS, nA, nS)) transition_matrix /= transition_matrix.sum(axis=2)[:, :, None] - reward_matrix = np.random.rand(nS) + reward_matrix = np.random.random((nS,)) super().__init__( transition_matrix=transition_matrix, reward_matrix=reward_matrix, diff --git a/tests/test_envs.py b/tests/test_envs.py index b2e86d4..4f59762 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -2,11 +2,10 @@ from typing import List -import gym -from gym.envs import registration +import gymnasium as gym import pytest - import seals # noqa: F401 required for env registration +from gymnasium.envs import registration from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env from seals.testing import envs diff --git a/tests/test_mujoco_rl.py b/tests/test_mujoco_rl.py index b0e362c..55f66b6 100644 --- a/tests/test_mujoco_rl.py +++ b/tests/test_mujoco_rl.py @@ -2,7 +2,7 @@ from typing import Tuple -import gym +import gymnasium as gym import pytest import stable_baselines3 from stable_baselines3.common import evaluation diff --git a/tests/test_util.py b/tests/test_util.py index 5829ebb..9ed8ceb 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,10 +2,9 @@ import collections -import gym +import gymnasium as gym import numpy as np import pytest - from seals import GYM_ATARI_ENV_SPECS, util @@ -22,11 +21,11 @@ def test_mask_score_wrapper_enforces_spec(): def test_sample_distribution(): """Test util.sample_distribution.""" distr_size = 5 - distr = np.random.rand(distr_size) + distr = np.random.random((distr_size,)) distr /= distr.sum() n_samples = 1000 - rng = np.random.RandomState() + rng = np.random.default_rng() sample_count = collections.Counter( util.sample_distribution(distr, rng) for _ in range(n_samples) ) @@ -39,8 +38,8 @@ def test_sample_distribution(): # Same seed gives same samples assert all( - util.sample_distribution(distr, random=np.random.RandomState(seed)) - == util.sample_distribution(distr, random=np.random.RandomState(seed)) + util.sample_distribution(distr, random=np.random.default_rng(seed)) + == util.sample_distribution(distr, random=np.random.default_rng(seed)) for seed in range(20) ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 06fac22..a9c1906 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -2,7 +2,6 @@ import numpy as np import pytest - from seals import util from seals.testing import envs @@ -31,10 +30,11 @@ def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2) next_episode_end = episode_length for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, info = env.step(act) + obs, rew, terminated, truncated, info = env.step(act) # AutoResetWrapper overrides all done signals. - assert done is False + assert terminated is False + assert truncated is False if t == next_episode_end: # Unlike the AutoResetWrapper that discards terminal observations, @@ -80,11 +80,12 @@ def test_auto_reset_wrapper_discard(episode_length=3, n_steps=100, n_manual_rese for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, info = env.step(act) + obs, rew, terminated, truncated, info = env.step(act) expected_obs = t % episode_length assert obs == expected_obs - assert done is False + assert terminated is False + assert truncated is False if expected_obs == 0: # End of episode assert info.get("terminal_observation", None) == episode_length @@ -113,8 +114,9 @@ def test_absorb_repeat_custom_state( env.reset() for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, _ = env.step(act) - assert done is False + obs, rew, terminated, truncated, _ = env.step(act) + assert terminated is False + assert truncated is False if t > episode_length: expected_obs = absorb_obs expected_rew = absorb_reward @@ -134,8 +136,9 @@ def test_absorb_repeat_final_state(episode_length=6, n_steps=100, n_manual_reset env.reset() for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, _ = env.step(act) - assert done is False + obs, rew, terminated, truncated, _ = env.step(act) + assert terminated is False + assert truncated is False if t > episode_length: expected_obs = episode_length expected_rew = -1 @@ -161,8 +164,9 @@ def test_obs_cast(dtype: np.dtype, episode_length: int = 5): assert obs == 0 for t in range(1, episode_length + 1): act = env.action_space.sample() - obs, rew, done, _ = env.step(act) - assert done == (t == episode_length) + obs, rew, terminated, truncated, _ = env.step(act) + assert terminated == (t == episode_length) + assert truncated is False assert obs.dtype == dtype assert obs == t assert rew == t * 10.0 From 3ed6d2b86af3d7bc66aa41da078e9f83fac1670e Mon Sep 17 00:00:00 2001 From: Edoardo Pona Date: Thu, 27 Jul 2023 14:47:19 +0100 Subject: [PATCH 02/61] py38 compatible type hints --- src/seals/base_envs.py | 16 ++++++++-------- src/seals/util.py | 10 +++++----- tests/test_envs.py | 2 +- 3 files changed, 14 insertions(+), 14 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index b690f14..1f218d9 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -1,7 +1,7 @@ """Base environment classes.""" import abc -from typing import Any, Generic, Optional, Tuple, TypeVar +from typing import Any, Generic, Optional, Tuple, TypeVar, Dict, Union import gymnasium as gym import numpy as np @@ -79,9 +79,9 @@ def state(self, state: StateType): def reset( self, *, - seed: int | None = None, - options: dict[str, Any] | None = None, - ) -> tuple[ObsType, dict[str, Any]]: # type: ignore + seed: Union[int, None] = None, + options: Union[Dict[str, Any], None]= None, + ) -> Tuple[ObsType, Dict[str, Any]]: # type: ignore """Reset episode and return initial observation.""" if options is not None: raise ValueError("Options not supported.") @@ -139,8 +139,8 @@ def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None: self._observation_space = env.state_space def reset( - self, seed: int | None = None, options: dict[str, Any] | None = None - ) -> Tuple[StateType, dict[str, Any]]: + self, seed: Union[int, None] = None, options: Union[Dict[str, Any], None]= None + ) -> Tuple[StateType, Dict[str, Any]]: """Reset environment and return initial state.""" _, info = self.env.reset(seed=seed, options=options) return self.env.state, info @@ -325,7 +325,7 @@ def action_dim(self) -> int: ObsEntryType = TypeVar( - "ObsEntryType", bound=np.floating[Any] | np.integer[Any], covariant=True + "ObsEntryType", bound=Union[np.floating, np.integer], covariant=True ) @@ -399,7 +399,7 @@ def obs_dim(self) -> int: return self.observation_matrix.shape[1] @property - def obs_dtype(self) -> np.dtype[ObsEntryType]: + def obs_dtype(self) -> np.dtype: """Data type of observation vectors (e.g. np.float32).""" return self.observation_matrix.dtype diff --git a/src/seals/util.py b/src/seals/util.py index a7ffb1e..7561bf3 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -1,7 +1,7 @@ """Miscellaneous utilities.""" from dataclasses import dataclass -from typing import Any, Generic, List, Optional, Sequence, SupportsFloat, Tuple, Union +from typing import Any, Generic, List, Optional, Sequence, SupportsFloat, Tuple, Union, Dict import gymnasium as gym import numpy as np @@ -43,7 +43,7 @@ def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0): def step( self, action: WrapperActType - ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When done=True, returns done=False, then reset depending on flag. Depending on whether we are discarding the terminal observation, @@ -58,7 +58,7 @@ def step( def _step_pad( self, action: WrapperActType - ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When done=True, return done=False instead and return the terminal obs. The agent will then usually be asked to perform an action based on @@ -87,7 +87,7 @@ def _step_pad( def _step_discard( self, action: WrapperActType - ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: + ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When done=True, returns done=False instead and automatically resets. When an automatic reset happens, the observation from reset is returned, @@ -158,7 +158,7 @@ def _mask_obs(self, obs) -> npt.NDArray: def step( self, action: ActType - ) -> tuple[npt.NDArray, SupportsFloat, bool, bool, dict[str, Any]]: + ) -> Tuple[npt.NDArray, SupportsFloat, bool, bool, Dict[str, Any]]: """Returns (obs, rew, terminated, truncated, info) with masked obs.""" obs, rew, terminated, truncated, info = self.env.step(action) return self._mask_obs(obs), rew, terminated, truncated, info diff --git a/tests/test_envs.py b/tests/test_envs.py index 4f59762..275f823 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -11,7 +11,7 @@ ENV_NAMES: List[str] = [ env_spec.id - for env_spec in registration.registry.all() + for env_spec in registration.registry.values() if env_spec.id.startswith("seals/") ] From 35feb652341259956eab5c8fe78a58ed1f6956a9 Mon Sep 17 00:00:00 2001 From: Edoardo Pona Date: Thu, 27 Jul 2023 17:56:27 +0100 Subject: [PATCH 03/61] gymnasium compatible reset --- src/seals/classic_control.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 7b1d3c9..eb07e20 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -15,9 +15,10 @@ class FixedHorizonCartPole(classic_control.CartPoleEnv): Reward is 1.0 whenever the CartPole is an "ok" state (i.e. the pole is upright and the cart is on the screen). Otherwise reward is 0.0. - - Done is always False. (Though note that by default, this environment is wrapped - in `TimeLimit` with max steps 500.) + + Terminated is always False. + By default, this environment is wrapped in 'TimeLimit' with max steps 500 + Truncation is handled by that. """ def __init__(self): @@ -33,9 +34,10 @@ def __init__(self): high = np.array(high) self.observation_space = spaces.Box(-high, high, dtype=np.float32) - def reset(self): + def reset(self, seed=None, options={}): """Reset for FixedHorizonCartPole.""" - return super().reset().astype(np.float32) + observation, info = super().reset(seed=seed, options=options) + return observation.astype(np.float32), info def step(self, action): """Step function for FixedHorizonCartPole.""" @@ -56,7 +58,7 @@ def step(self, action): ) rew = 1.0 if state_ok else 0.0 - return np.array(self.state, dtype=np.float32), rew, False, {} + return np.array(self.state, dtype=np.float32), rew, False, False, {} def mountain_car(): From 4696597628bfc77d7257f7424b6cd4ac0dbdf028 Mon Sep 17 00:00:00 2001 From: Edoardo Pona Date: Fri, 28 Jul 2023 12:51:03 +0100 Subject: [PATCH 04/61] gymnasium compatibility changes --- src/seals/diagnostics/random_trans.py | 5 ++++- src/seals/util.py | 5 +++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/src/seals/diagnostics/random_trans.py b/src/seals/diagnostics/random_trans.py index 13dae26..ec8f067 100644 --- a/src/seals/diagnostics/random_trans.py +++ b/src/seals/diagnostics/random_trans.py @@ -72,7 +72,10 @@ def __init__( n_states=n_states, rand_state=rand_gen, ) - self.reward_weights = rand_gen.randn(observation_matrix.shape[-1]) + + self.reward_weights = rand_gen.normal( + 0, 1, size=(observation_matrix.shape[-1],) + ) reward_matrix = observation_matrix @ self.reward_weights super().__init__( transition_matrix=transition_matrix, diff --git a/src/seals/util.py b/src/seals/util.py index 7561bf3..c958684 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -188,9 +188,10 @@ def __init__(self, env: gym.Env, dtype: np.dtype): super().__init__(env) self.dtype = dtype - def reset(self): + def reset(self, seed=None): """Returns reset observation, cast to self.dtype.""" - return super().reset().astype(self.dtype) + obs, info = super().reset(seed=seed) + return obs.astype(self.dtype), info def step(self, action): """Returns (obs, rew, terminated, truncated, info) with obs cast to self.dtype.""" From bb20d549a15ba3df3666c61af60fe815e5e8c0fb Mon Sep 17 00:00:00 2001 From: Edoardo Pona Date: Fri, 28 Jul 2023 18:21:43 +0100 Subject: [PATCH 05/61] gymnasium compatible reset and random --- src/seals/diagnostics/noisy_obs.py | 2 +- src/seals/testing/envs.py | 9 +++++---- tests/test_wrappers.py | 6 +++--- 3 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/seals/diagnostics/noisy_obs.py b/src/seals/diagnostics/noisy_obs.py index a02f9a5..e594c70 100644 --- a/src/seals/diagnostics/noisy_obs.py +++ b/src/seals/diagnostics/noisy_obs.py @@ -69,5 +69,5 @@ def transition(self, state: np.ndarray, action: int) -> np.ndarray: def obs_from_state(self, state: np.ndarray) -> np.ndarray: """Returns (x, y) concatenated with Gaussian noise.""" - noise_vector = self.rand_state.randn(self._noise_length) + noise_vector = self.rand_state.normal(size=self._noise_length) return np.concatenate([state, noise_vector]).astype(np.float32) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 58bb7e9..7a1e838 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -95,9 +95,10 @@ def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout: actions: the actions to perform. Returns: - A sequence of 4-tuples (obs, rew, terminated, truncated, info). + A sequence of 5-tuples (obs, rew, terminated, truncated, info). """ - ret: List[Step] = [(env.reset(), None, False, False, {})] + obs, info = env.reset() + ret: List[Step] = [(obs, None, False, False, {})] for act in actions: ret.append(env.step(act)) return ret @@ -338,10 +339,10 @@ def __init__(self, episode_length: int = 5): self.episode_length = episode_length self.timestep = None - def reset(self): + def reset(self, seed=None, options={}): """Reset method for CountingEnv.""" t, self.timestep = 0, 1 - return np.array(t, dtype=self.observation_space.dtype) + return np.array(t, dtype=self.observation_space.dtype), {} def step(self, action): """Step method for CountingEnv.""" diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index a9c1906..63e50c3 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -22,7 +22,7 @@ def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2) ) for _ in range(n_manual_reset): - obs = env.reset() + obs, info = env.reset() assert obs == 0 # We count the number of episodes, so we can sanity check the padding. @@ -75,7 +75,7 @@ def test_auto_reset_wrapper_discard(episode_length=3, n_steps=100, n_manual_rese ) for _ in range(n_manual_reset): - obs = env.reset() + obs, info = env.reset() assert obs == 0 for t in range(1, n_steps + 1): @@ -159,7 +159,7 @@ def test_obs_cast(dtype: np.dtype, episode_length: int = 5): env = envs.CountingEnv(episode_length=episode_length) env = util.ObsCastWrapper(env, dtype) - obs = env.reset() + obs, info = env.reset() assert obs.dtype == dtype assert obs == 0 for t in range(1, episode_length + 1): From b2f842120087da888b616d70f8308d4fe18a865d Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 2 Aug 2023 14:58:58 +0200 Subject: [PATCH 06/61] Make type annotations python 3.8 compatible. --- src/seals/base_envs.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 1f218d9..db8561b 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -79,9 +79,9 @@ def state(self, state: StateType): def reset( self, *, - seed: Union[int, None] = None, - options: Union[Dict[str, Any], None]= None, - ) -> Tuple[ObsType, Dict[str, Any]]: # type: ignore + seed: Optional[int] = None, + options: Optional[Dict[str, Any]] = None, + ) -> Tuple[ObsType, Dict[str, Any]]: """Reset episode and return initial observation.""" if options is not None: raise ValueError("Options not supported.") @@ -90,7 +90,7 @@ def reset( self.state = self.initial_state() self._n_actions_taken = 0 obs = self.obs_from_state(self.state) - info: dict[str, Any] = dict() + info: Dict[str, Any] = dict() return obs, info def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: @@ -139,7 +139,7 @@ def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None: self._observation_space = env.state_space def reset( - self, seed: Union[int, None] = None, options: Union[Dict[str, Any], None]= None + self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None ) -> Tuple[StateType, Dict[str, Any]]: """Reset environment and return initial state.""" _, info = self.env.reset(seed=seed, options=options) @@ -399,7 +399,7 @@ def obs_dim(self) -> int: return self.observation_matrix.shape[1] @property - def obs_dtype(self) -> np.dtype: + def obs_dtype(self) -> np.dtype: """Data type of observation vectors (e.g. np.float32).""" return self.observation_matrix.dtype From 030ae677a0f8def41d004686cf46abe5d4efed1a Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 2 Aug 2023 14:59:29 +0200 Subject: [PATCH 07/61] Fix some grammar issues. --- src/seals/base_envs.py | 4 ++-- src/seals/classic_control.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index db8561b..82e53af 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -71,7 +71,7 @@ def state(self) -> StateType: @state.setter def state(self, state: StateType): - """Set current state.""" + """Set the current state.""" if state not in self.state_space: raise ValueError(f"{state} not in {self.state_space}") self._cur_state = state @@ -82,7 +82,7 @@ def reset( seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, ) -> Tuple[ObsType, Dict[str, Any]]: - """Reset episode and return initial observation.""" + """Reset the episode and return initial observation.""" if options is not None: raise ValueError("Options not supported.") diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index eb07e20..f61c4ce 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -13,12 +13,12 @@ class FixedHorizonCartPole(classic_control.CartPoleEnv): """Fixed-length variant of CartPole-v1. - Reward is 1.0 whenever the CartPole is an "ok" state (i.e. the pole is upright - and the cart is on the screen). Otherwise reward is 0.0. - - Terminated is always False. - By default, this environment is wrapped in 'TimeLimit' with max steps 500 - Truncation is handled by that. + Reward is 1.0 whenever the CartPole is an "ok" state (i.e., the pole is upright + and the cart is on the screen). Otherwise, reward is 0.0. + + Terminated is always False. + By default, this environment is wrapped in 'TimeLimit' with max steps 500, + Truncation is handled by that. """ def __init__(self): From 2f4af9d28843d33fba352164e43fc09af3159881 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Wed, 2 Aug 2023 15:00:16 +0200 Subject: [PATCH 08/61] Raise RuntimeErrors and ValueErrors in the proper places. --- src/seals/base_envs.py | 6 +++--- src/seals/util.py | 3 ++- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 82e53af..448e2fb 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -84,7 +84,7 @@ def reset( ) -> Tuple[ObsType, Dict[str, Any]]: """Reset the episode and return initial observation.""" if options is not None: - raise ValueError("Options not supported.") + raise NotImplementedError("Options not supported.") super().reset(seed=seed) self.state = self.initial_state() @@ -96,7 +96,7 @@ def reset( def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """Transition state using given action.""" if self._cur_state is None or self._n_actions_taken is None: - raise ValueError("Need to call reset() before first step()") + raise RuntimeError("Need to call reset() before first step()") if action not in self.action_space: raise ValueError(f"{action} not in {self.action_space}") @@ -117,7 +117,7 @@ def rand_state(self) -> np.random.Generator: """Random state.""" rand_state = self._np_random if rand_state is None: - raise ValueError("Need to call reset() before accessing rand_state") + raise RuntimeError("Need to call reset() before accessing rand_state") return rand_state diff --git a/src/seals/util.py b/src/seals/util.py index c958684..de460af 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -146,7 +146,8 @@ def __init__( super().__init__(env) self.fill_value = np.array(fill_value, env.observation_space.dtype) - assert env.observation_space.shape is not None + if env.observation_space.shape is None: + raise ValueError("Observation space must have a shape.") self.mask = np.ones(env.observation_space.shape, dtype=bool) for r in score_regions: if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]: From 93292c51fd4dea80f10b9ac7ab0b433f9c448cb6 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 3 Aug 2023 12:52:55 +0200 Subject: [PATCH 09/61] Undoing unrelated formatting fixes in the readme. --- README.md | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index e4078d8..8b02209 100644 --- a/README.md +++ b/README.md @@ -7,35 +7,35 @@ **Status**: early beta. -_seals_, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for +*seals*, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for evaluating specification learning algorithms, such as reward or imitation learning. The environments are compatible with [Gym](https://github.com/openai/gym), but are designed to test algorithms that learn from user data, without requiring a procedurally specified reward function. -There are two types of environments in _seals_: +There are two types of environments in *seals*: -- **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. -- **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous - control tasks and Atari games to be suitable for specification learning benchmarks. In particular, - we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score). - -_seals_ is under active development and we intend to add more categories of tasks soon. + - **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. + - **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous + control tasks and Atari games to be suitable for specification learning benchmarks. In particular, + we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score). +*seals* is under active development and we intend to add more categories of tasks soon. + You may also be interested in our sister project [imitation](https://github.com/humancompatibleai/imitation/), providing implementations of a variety of imitation and reward learning algorithms. -Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about _seals_. +Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about *seals*. # Quickstart To install the latest release from PyPI, run: - + ```bash pip install seals ``` -All _seals_ environments are available in the Gym registry. Simply import it and then use as you +All *seals* environments are available in the Gym registry. Simply import it and then use as you would with your usual RL or specification learning algroithm: ```python @@ -86,7 +86,7 @@ for type checking. ## Workflow Trivial changes (e.g. typo fixes) may be made directly by maintainers. Any non-trivial changes -must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous +must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous integration tests (CircleCI linting, type checking, unit tests and CodeCov) to be merged. It is often helpful to open an issue before proposing a PR, to allow for discussion of the design From a3520a670afb406082908b814a8d9b682be6b1a8 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 3 Aug 2023 12:54:29 +0200 Subject: [PATCH 10/61] Remove unused ruff configuration. --- pyproject.toml | 3 --- 1 file changed, 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index bdde550..d1ddaf0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,6 +8,3 @@ target-version = ["py38"] [[tool.mypy.overrides]] module = ["gym.*", "setuptools_scm.*"] ignore_missing_imports = true - -[tool.ruff] -select = ["E", "F"] From 364d0e8e87f613361a52705393e4a4207997cf36 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Thu, 3 Aug 2023 12:58:12 +0200 Subject: [PATCH 11/61] Add Adams wording suggestions. --- src/seals/classic_control.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index f61c4ce..4ad1c7c 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -18,7 +18,7 @@ class FixedHorizonCartPole(classic_control.CartPoleEnv): Terminated is always False. By default, this environment is wrapped in 'TimeLimit' with max steps 500, - Truncation is handled by that. + which sets `truncated` to true after that many steps. """ def __init__(self): From baacaca6283fcef516f5412a046c6a4b292a26f9 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 8 Aug 2023 14:39:36 +0200 Subject: [PATCH 12/61] switch to alpha-version of the circle-ci image (reverst this before merge) --- .circleci/config.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 20ec919..49c49b0 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -8,7 +8,7 @@ orbs: defaults: &defaults docker: - - image: humancompatibleai/seals:base + - image: humancompatibleai/seals:base-alpha auth: username: $DOCKERHUB_USERNAME password: $DOCKERHUB_PASSWORD From c5cbe30695bd6610c5bdce07f353a638c1a4b703 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 11 Aug 2023 11:59:15 +0200 Subject: [PATCH 13/61] Update Xdummy-entrypoint.py to python3 --- ci/Xdummy-entrypoint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/Xdummy-entrypoint.py b/ci/Xdummy-entrypoint.py index f1abd5c..1a45388 100755 --- a/ci/Xdummy-entrypoint.py +++ b/ci/Xdummy-entrypoint.py @@ -1,4 +1,4 @@ -#!/usr/bin/python +#!/usr/bin/python3 # Adapted from https://github.com/openai/mujoco-py/blob/master/vendor/Xdummy-entrypoint # Copyright OpenAI; MIT License From b684e88f2c1ca7842e872094326a4aa18f25e6cb Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 11 Aug 2023 12:25:34 +0200 Subject: [PATCH 14/61] Update Dockerfile to Ubuntu 20.04 and add ssh. --- Dockerfile | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Dockerfile b/Dockerfile index a190ca8..01cb4aa 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,6 +1,6 @@ # base stage contains just binary dependencies. # This is used in the CI build. -FROM nvidia/cuda:10.0-runtime-ubuntu18.04 AS base +FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu20.04 AS base ARG DEBIAN_FRONTEND=noninteractive RUN apt-get update -q \ @@ -9,6 +9,7 @@ RUN apt-get update -q \ curl \ ffmpeg \ git \ + ssh \ libgl1-mesa-dev \ libgl1-mesa-glx \ libglew-dev \ From a1ab629992c0b87ef711264f5b3bb33e44a65fd2 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 11 Aug 2023 14:40:45 +0200 Subject: [PATCH 15/61] Dont mention gym in inline comment but gymnasium. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 40af813..75c2d21 100644 --- a/setup.py +++ b/setup.py @@ -143,7 +143,7 @@ def get_readme() -> str: "dev": ["ipdb", "jupyter", *TESTS_REQUIRE, *DOCS_REQUIRE], "docs": DOCS_REQUIRE, "test": TESTS_REQUIRE, - # We'd like to specify `gym[mujoco]`, but this is a no-op when Gym is already + # We'd like to specify `gymnasium[mujoco]`, but this is a no-op when Gym is already # installed. See https://github.com/pypa/pip/issues/4957 for issue. "mujoco": ["mujoco", "imageio"], "atari": ATARI_REQUIRE, From 0e56b4966b8c8f945f0c10ebf1c7ef41fc5c0306 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 11 Aug 2023 16:33:36 +0200 Subject: [PATCH 16/61] Absorb terminated AND truncated steps. --- src/seals/util.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/seals/util.py b/src/seals/util.py index de460af..9b2f4f1 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -249,28 +249,26 @@ def step(self, action): `rew` depend on initialization arguments. `info` is always an empty dictionary. """ if not self.at_absorb_state: - inner_obs, inner_rew, terminated, truncated, inner_info = self.env.step( + obs, rew, terminated, truncated, info = self.env.step( action ) - if terminated: + if terminated or truncated: # Initialize the artificial absorb state, which we will repeatedly use # starting on the next call to `step()`. self.at_absorb_state = True if self.absorb_obs_default is None: - self.absorb_obs_this_episode = inner_obs + self.absorb_obs_this_episode = obs else: self.absorb_obs_this_episode = self.absorb_obs_default - obs, rew, info = inner_obs, inner_rew, inner_info else: assert self.absorb_obs_this_episode is not None assert self.absorb_reward is not None obs = self.absorb_obs_this_episode rew = self.absorb_reward info = {} - truncated = False - return obs, rew, False, truncated, info + return obs, rew, False, False, info def make_env_no_wrappers(env_name: str, **kwargs) -> gym.Env: From 034307f33fc3af270d4877566deba5f450edb82c Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 11 Aug 2023 16:56:53 +0200 Subject: [PATCH 17/61] Treat done == (terminated or truncated) and stop mentioning done in the documentation. --- src/seals/classic_control.py | 3 ++- src/seals/testing/envs.py | 14 +++++++------- src/seals/util.py | 36 ++++++++++++++++++++---------------- tests/test_base_env.py | 2 +- tests/test_wrappers.py | 2 +- 5 files changed, 31 insertions(+), 26 deletions(-) diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 4ad1c7c..0fddd71 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -42,7 +42,8 @@ def reset(self, seed=None, options={}): def step(self, action): """Step function for FixedHorizonCartPole.""" with warnings.catch_warnings(): - # Filter out CartPoleEnv warning for calling step() beyond done=True. + # Filter out CartPoleEnv warning for calling step() beyond + # terminated=True or truncated=True warnings.filterwarnings("ignore", ".*You are calling.*") super().step(action) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 7a1e838..b0b44a9 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -212,7 +212,7 @@ def _is_mujoco_env(env: gym.Env) -> bool: def test_rollout_schema( env: gym.Env, - steps_after_done: int = 10, + steps_after_terminated: int = 10, max_steps: int = 10_000, check_episode_ends: bool = True, ) -> None: @@ -220,12 +220,12 @@ def test_rollout_schema( Args: env: The environment to test. - steps_after_done: The number of steps to take after `done` is True, the nominal - episode termination. This is an abuse of the Gym API, but we would like the - environments to handle this case gracefully. - max_steps: Test fails if we do not get `done` after this many timesteps. + steps_after_terminated: The number of steps to take after `terminated` is True, + the nominal episode termination. This is an abuse of the Gym API, + but we would like the environments to handle this case gracefully. + max_steps: Test fails if we do not get `terminated` after this many timesteps. check_episode_ends: Check that episode ends after `max_steps` steps, and that - steps taken after `done` is true are of the correct type. + steps taken after `terminated` is true are of the correct type. Raises: AssertionError if test fails. @@ -243,7 +243,7 @@ def test_rollout_schema( if check_episode_ends: assert terminated, "did not get to end of episode" - for _ in range(steps_after_done): + for _ in range(steps_after_terminated): _sample_and_check(env, obs_space) diff --git a/src/seals/util.py b/src/seals/util.py index 9b2f4f1..026b9f8 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -12,7 +12,8 @@ class AutoResetWrapper( gym.Wrapper, Generic[WrapperObsType, WrapperActType, ObsType, ActType] ): - """Hides done=True and auto-resets at the end of each episode. + """Hides terminated=True and truncated=True and auto-resets at the end of each + episode. Depending on the flag 'discard_terminal_observation', either discards the terminal observation or pads with an additional 'reset transition'. The former is the default @@ -44,7 +45,8 @@ def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0): def step( self, action: WrapperActType ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """When done=True, returns done=False, then reset depending on flag. + """When terminated or truncated, resets the environment and returns False + for terminated and truncated. Depending on whether we are discarding the terminal observation, either resets the environment and discards, @@ -59,7 +61,8 @@ def step( def _step_pad( self, action: WrapperActType ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """When done=True, return done=False instead and return the terminal obs. + """When terminated or truncated, return False for both instead and return the + terminal obs. The agent will then usually be asked to perform an action based on the terminal observation. In the next step, this final action will be ignored @@ -81,25 +84,24 @@ def _step_pad( return reset_obs, self.reset_reward, False, False, info obs, rew, terminated, truncated, info = self.env.step(action) - if terminated: - self.previous_done = True - return obs, rew, False, truncated, info + self.previous_done = terminated or truncated + return obs, rew, False, False, info def _step_discard( self, action: WrapperActType ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """When done=True, returns done=False instead and automatically resets. + """When terminated or truncated, return False for both and automatically reset. When an automatic reset happens, the observation from reset is returned, and the overridden observation is stored in `info["terminal_observation"]`. """ obs, rew, terminated, truncated, info = self.env.step(action) - if terminated: + if terminated or truncated: info["terminal_observation"] = obs obs, reset_info_dict = self.env.reset() info["reset_info_dict"] = reset_info_dict - return obs, rew, False, truncated, info + return obs, rew, False, False, info @dataclass @@ -203,8 +205,9 @@ def step(self, action): class AbsorbAfterDoneWrapper(gym.Wrapper): """Transition into absorbing state instead of episode termination. - When the environment being wrapped returns `done=True`, we return an absorbing - observation. This wrapper always returns `done=False`. + When the environment being wrapped returns `terminated=True` or `truncated=True`, + we return an absorbing observation. + This wrapper always returns `terminated=False` and `truncated=False`. A convenient way to add absorbing states to environments like MountainCar. """ @@ -238,15 +241,16 @@ def reset(self, *args, **kwargs): def step(self, action): """Advance the environment by one step. - This wrapped `step()` always returns done=False. + This wrapped `step()` always returns terminated=False and truncated=False. - After the first done is returned by the underlying Env, we enter an artificial - absorb state. + After the first time either terminated or truncated is returned by the + underlying Env, we enter an artificial absorb state. In this artificial absorb state, we stop calling `self.env.step(action)` (i.e. the `action` argument is entirely ignored) and - we return fixed values for obs, rew, done, and info. The values of `obs` and - `rew` depend on initialization arguments. `info` is always an empty dictionary. + we return fixed values for obs, rew, terminated, truncated, and info. + The values of `obs` and `rew` depend on initialization arguments. + `info` is always an empty dictionary. """ if not self.at_absorb_state: obs, rew, terminated, truncated, info = self.env.step( diff --git a/tests/test_base_env.py b/tests/test_base_env.py index c91fe22..62cad7b 100644 --- a/tests/test_base_env.py +++ b/tests/test_base_env.py @@ -102,7 +102,7 @@ def test_expose_pomdp_state_wrapper(): assert state in env.state_space action = env.action_space.sample() - next_state, reward, done, info = wrapped_env.step(action) + next_state, reward, terminated, truncated, info = wrapped_env.step(action) assert next_state == env.state assert next_state in env.state_space diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 63e50c3..9a70988 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -32,7 +32,7 @@ def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2) act = env.action_space.sample() obs, rew, terminated, truncated, info = env.step(act) - # AutoResetWrapper overrides all done signals. + # AutoResetWrapper overrides all terminated and truncated signals. assert terminated is False assert truncated is False From 57c63012c6b979841f1f26a5d8a9317485bc1f4e Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 11 Aug 2023 17:03:37 +0200 Subject: [PATCH 18/61] Remove outdated make_env_no_wrappers --- src/seals/classic_control.py | 4 +++- src/seals/util.py | 8 -------- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 0fddd71..4fd1249 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -3,6 +3,8 @@ import warnings import numpy as np + +import gymnasium as gym from gymnasium import spaces from gymnasium.envs import classic_control @@ -71,7 +73,7 @@ def mountain_car(): Done is always returned on timestep 200 only. """ - env = util.make_env_no_wrappers("MountainCar-v0") + env = gym.make("MountainCar-v0") env = util.ObsCastWrapper(env, dtype=np.float32) env = util.AbsorbAfterDoneWrapper(env) return env diff --git a/src/seals/util.py b/src/seals/util.py index 026b9f8..1a02522 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -275,14 +275,6 @@ def step(self, action): return obs, rew, False, False, info -def make_env_no_wrappers(env_name: str, **kwargs) -> gym.Env: - """Gym sometimes wraps envs in TimeLimit before returning from gymnasium.make(). - - This helper method builds directly from spec to avoid this wrapper. - """ - return gym.spec(env_name).make(**kwargs) - - def get_gym_max_episode_steps(env_name: str) -> Optional[int]: """Get the `max_episode_steps` attribute associated with a gym Spec.""" return gym.spec(env_name).max_episode_steps From 83fc29e7e212075bcab719ec1626f153756478da Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Fri, 11 Aug 2023 17:07:21 +0200 Subject: [PATCH 19/61] Use registry keys instead of extracting env_id from the spec. --- tests/test_envs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_envs.py b/tests/test_envs.py index 275f823..751430d 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -10,9 +10,8 @@ from seals.testing import envs ENV_NAMES: List[str] = [ - env_spec.id - for env_spec in registration.registry.values() - if env_spec.id.startswith("seals/") + env_id for env_id in registration.registry.keys() + if env_id.startswith("seals/") ] From 5eba4c24693d1585de45a2f3898fcdd8fbe43ca6 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 15:06:14 +0200 Subject: [PATCH 20/61] Ensure to seed environments upon the first reset. --- src/seals/testing/envs.py | 4 ++-- tests/test_base_env.py | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index b0b44a9..bfa322d 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -231,7 +231,7 @@ def test_rollout_schema( AssertionError if test fails. """ obs_space = env.observation_space - obs, _ = env.reset() + obs, _ = env.reset(seed=0) _check_obs(obs, obs_space) terminated = False @@ -289,7 +289,7 @@ def test_render(env: gym.Env, raises_fn) -> None: `env.metadata["render.modes"]` is empty; (c) `env.render(mode="rgb_array")` returns different values at the same time step. """ - env.reset() # make sure environment is in consistent state + env.reset(seed=0) # make sure environment is in consistent state render_modes = env.metadata["render.modes"] if not render_modes: diff --git a/tests/test_base_env.py b/tests/test_base_env.py index 62cad7b..c9b3936 100644 --- a/tests/test_base_env.py +++ b/tests/test_base_env.py @@ -35,7 +35,7 @@ def test_base_envs(): envs.test_premature_step(env, skip_fn=pytest.skip, raises_fn=pytest.raises) - env.reset() + env.reset(seed=0) assert env.n_actions_taken == 0 env.step(env.action_space.sample()) assert env.n_actions_taken == 1 @@ -86,7 +86,7 @@ def test_tabular_env_validation(): transition_matrix=np.zeros((3, 1, 3)), reward_matrix=np.zeros((3,)), ) - env.reset() + env.reset(seed=0) with pytest.raises(ValueError, match=r".*not in.*"): env.step(4) @@ -97,7 +97,7 @@ def test_expose_pomdp_state_wrapper(): wrapped_env = base_envs.ExposePOMDPStateWrapper(env) assert wrapped_env.observation_space == env.state_space - state = wrapped_env.reset() + state, _ = wrapped_env.reset(seed=0) assert state == env.state assert state in env.state_space From 599031a8251368634143bbb8a85bb2a884cd370e Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 15:06:33 +0200 Subject: [PATCH 21/61] Add missing shimmy dependency for atari. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 75c2d21..594c2b5 100644 --- a/setup.py +++ b/setup.py @@ -91,6 +91,7 @@ def get_readme() -> str: "ale-py==0.7.4", "pillow", "autorom[accept-rom-license]~=0.4.2", + "shimmy[atari] >=0.1.0,<1.0", ] TESTS_REQUIRE = [ # remove pin once https://github.com/nedbat/coveragepy/issues/881 fixed From 15f74941391f04487b0ca54aebbdb268b2ea5d88 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 15:06:57 +0200 Subject: [PATCH 22/61] Detect atari envs by looking for shimmy entrypoint instead of gym entrypoint. --- src/seals/atari.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/seals/atari.py b/src/seals/atari.py index 9da2cce..b4f99e7 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -67,7 +67,7 @@ def _not_ram_or_det(env_id: str) -> bool: def _supported_atari_env(gym_spec: EnvSpec) -> bool: """Checks if a gym Atari environment is one of the ones we will support.""" - is_atari = gym_spec.entry_point == "gym.envs.atari:AtariEnv" + is_atari = gym_spec.entry_point == "shimmy.atari_env:AtariEnv" v5_and_plain = gym_spec.id.endswith("-v5") and "NoFrameskip" not in gym_spec.id v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id return ( From 2342e7256a43785eb3050f1f127eecdd15cb6aa2 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 15:07:15 +0200 Subject: [PATCH 23/61] Add missing observation space to TabularModelMDP --- src/seals/base_envs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 448e2fb..7090358 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -435,6 +435,7 @@ def __init__( horizon=horizon, initial_state_dist=initial_state_dist, ) + self.observation_space = self.state_space def obs_from_state(self, state: DiscreteSpaceInt) -> DiscreteSpaceInt: """Identity since observation == state in an MDP.""" From dada630557600dbbf59c2598ee15dc0d9e66f82b Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 15:11:49 +0200 Subject: [PATCH 24/61] Look for render modes in new location of the environment metadata. --- src/seals/testing/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index bfa322d..338af48 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -291,7 +291,7 @@ def test_render(env: gym.Env, raises_fn) -> None: """ env.reset(seed=0) # make sure environment is in consistent state - render_modes = env.metadata["render.modes"] + render_modes = env.metadata["render_modes"] if not render_modes: # No modes supported -- render() should fail. with raises_fn(NotImplementedError): From 807d79b2f8688b7a33aa3aa48318137fc453d235 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 15:12:33 +0200 Subject: [PATCH 25/61] When testing the rollout schema, check for both termination and truncation. --- src/seals/testing/envs.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 338af48..d12d748 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -203,7 +203,7 @@ def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> Tuple[bool, bool]: assert isinstance(terminated, bool) assert isinstance(truncated, bool) assert isinstance(info, dict) - return terminated, truncated + return terminated or truncated def _is_mujoco_env(env: gym.Env) -> bool: @@ -234,14 +234,14 @@ def test_rollout_schema( obs, _ = env.reset(seed=0) _check_obs(obs, obs_space) - terminated = False + done = False for _ in range(max_steps): - terminated, _ = _sample_and_check(env, obs_space) - if terminated: + done = _sample_and_check(env, obs_space) + if done: break if check_episode_ends: - assert terminated, "did not get to end of episode" + assert done, "did not get to end of episode" for _ in range(steps_after_terminated): _sample_and_check(env, obs_space) From 6bb994ef15d4284871639d8edf5fce62a25f4950 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 17:33:08 +0200 Subject: [PATCH 26/61] Remove outdated asserts on the result of env.reset(). --- src/seals/testing/envs.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index d12d748..9e20f52 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -157,10 +157,7 @@ def test_seed( env.action_space.seed(0) actions = [env.action_space.sample() for _ in range(rollout_len)] # With the same seed, should always get the same result - seeds = env.reset(seed=42) - # output of env.seed should be a list, but atari environments return a tuple. - assert isinstance(seeds, (list, tuple)) - assert len(seeds) > 0 + env.reset(seed=42) rollout_a = get_rollout(env, actions) env.reset(seed=42) From 05dcf6721bf4e20c3a506bdf060b0a408795d06a Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Sun, 13 Aug 2023 17:34:10 +0200 Subject: [PATCH 27/61] Adapt reset() of MaskScroeWrapper to new gymnasium API --- src/seals/util.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/seals/util.py b/src/seals/util.py index 1a02522..1e45142 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -168,8 +168,8 @@ def step( def reset(self, **kwargs): """Returns masked reset observation.""" - obs = self.env.reset(**kwargs) - return self._mask_obs(obs) + obs, info = self.env.reset(**kwargs) + return self._mask_obs(obs), info class ObsCastWrapper(gym.Wrapper): From c7d8c1a6231c919ed42be0343dbb49a299e76b11 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 18:57:32 +0200 Subject: [PATCH 28/61] Switch to v4 versions of the MuJoCo environments. --- src/seals/__init__.py | 4 ++-- src/seals/mujoco.py | 24 ++++++++++++------------ tests/test_mujoco_rl.py | 4 ++-- 3 files changed, 16 insertions(+), 16 deletions(-) diff --git a/src/seals/__init__.py b/src/seals/__init__.py index fc15b21..2f1fa40 100644 --- a/src/seals/__init__.py +++ b/src/seals/__init__.py @@ -31,9 +31,9 @@ for env_base in ["Ant", "HalfCheetah", "Hopper", "Humanoid", "Swimmer", "Walker2d"]: gym.register( - id=f"seals/{env_base}-v0", + id=f"seals/{env_base}-v1", entry_point=f"seals.mujoco:{env_base}Env", - max_episode_steps=util.get_gym_max_episode_steps(f"{env_base}-v3"), + max_episode_steps=util.get_gym_max_episode_steps(f"{env_base}-v4"), ) # Atari diff --git a/src/seals/mujoco.py b/src/seals/mujoco.py index 0e539ab..3ffb5f3 100644 --- a/src/seals/mujoco.py +++ b/src/seals/mujoco.py @@ -3,12 +3,12 @@ import functools from gymnasium.envs.mujoco import ( - ant_v3, - half_cheetah_v3, - hopper_v3, - humanoid_v3, - swimmer_v3, - walker2d_v3, + ant_v4, + half_cheetah_v4, + hopper_v4, + humanoid_v4, + swimmer_v4, + walker2d_v4, ) @@ -27,33 +27,33 @@ def _no_early_termination(cls): @_include_position_in_observation @_no_early_termination -class AntEnv(ant_v3.AntEnv): +class AntEnv(ant_v4.AntEnv): """Ant with position observation and no early termination.""" @_include_position_in_observation -class HalfCheetahEnv(half_cheetah_v3.HalfCheetahEnv): +class HalfCheetahEnv(half_cheetah_v4.HalfCheetahEnv): """HalfCheetah with position observation. Naturally does not terminate early.""" @_include_position_in_observation @_no_early_termination -class HopperEnv(hopper_v3.HopperEnv): +class HopperEnv(hopper_v4.HopperEnv): """Hopper with position observation and no early termination.""" @_include_position_in_observation @_no_early_termination -class HumanoidEnv(humanoid_v3.HumanoidEnv): +class HumanoidEnv(humanoid_v4.HumanoidEnv): """Humanoid with position observation and no early termination.""" @_include_position_in_observation -class SwimmerEnv(swimmer_v3.SwimmerEnv): +class SwimmerEnv(swimmer_v4.SwimmerEnv): """Swimmer with position observation. Naturally does not terminate early.""" @_include_position_in_observation @_no_early_termination -class Walker2dEnv(walker2d_v3.Walker2dEnv): +class Walker2dEnv(walker2d_v4.Walker2dEnv): """Walker2d with position observation and no early termination.""" diff --git a/tests/test_mujoco_rl.py b/tests/test_mujoco_rl.py index 55f66b6..13a6304 100644 --- a/tests/test_mujoco_rl.py +++ b/tests/test_mujoco_rl.py @@ -35,9 +35,9 @@ def test_fixed_env_model_as_good_as_gym_env_model(env_base: str): # pragma: no """Compare original and modified MuJoCo v3 envs.""" train_timesteps = 200000 - gym_reward, _ = _eval_env(f"{env_base}-v3", total_timesteps=train_timesteps) + gym_reward, _ = _eval_env(f"{env_base}-v4", total_timesteps=train_timesteps) fixed_reward, _ = _eval_env( - f"seals/{env_base}-v0", + f"seals/{env_base}-v1", total_timesteps=train_timesteps, ) From e73e209752afa04bf4c45643d1e9101b4906c6ca Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:08:27 +0200 Subject: [PATCH 29/61] Simplify tests for render modes. --- src/seals/testing/envs.py | 47 ++------------------------------------- tests/test_envs.py | 31 +++++++++++++++++++++++--- 2 files changed, 30 insertions(+), 48 deletions(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 9e20f52..989f534 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -203,7 +203,7 @@ def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> Tuple[bool, bool]: return terminated or truncated -def _is_mujoco_env(env: gym.Env) -> bool: +def is_mujoco_env(env: gym.Env) -> bool: return hasattr(env, "sim") and hasattr(env, "model") @@ -258,7 +258,7 @@ def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None: Raises: AssertionError if test fails. """ - if _is_mujoco_env(env): # pragma: no cover + if is_mujoco_env(env): # pragma: no cover # We can't use isinstance since importing mujoco_py will fail on # machines without MuJoCo installed skip_fn("MuJoCo environments cannot perform this check.") @@ -268,49 +268,6 @@ def test_premature_step(env: gym.Env, skip_fn, raises_fn) -> None: env.step(act) -def test_render(env: gym.Env, raises_fn) -> None: - """Test that render() supports the modes declared. - - Example usage in pytest: - test_render(env, raises_fn=pytest.raises) - - Args: - env: The environment to test. - raises_fn: Context manager to check NotImplementedError is thrown when - environment metadata indicates modes are supported. - - Raises: - AssertionError: if test fails. This occurs if: - (a) `env.render(mode=mode)` fails for any mode declared supported - in `env.metadata["render.modes"]`; (b) env.render() *succeeds* when - `env.metadata["render.modes"]` is empty; (c) `env.render(mode="rgb_array")` - returns different values at the same time step. - """ - env.reset(seed=0) # make sure environment is in consistent state - - render_modes = env.metadata["render_modes"] - if not render_modes: - # No modes supported -- render() should fail. - with raises_fn(NotImplementedError): - env.render() - else: - for mode in render_modes: - env.render(mode=mode) - - # WARNING(adam): there seems to be a memory leak with Gym 0.17.3 - # & MuJoCoPy 1.50.1.68. `MujocoEnv.close()` does not call `finish()` - # on the viewer (commented out) so the resources are not released. - # For now this is OK, but may bite if we end up testing a lot of - # MuJoCo environments. - is_mujoco = _is_mujoco_env(env) - if "rgb_array" in render_modes and not is_mujoco: - # Render should not change without calling `step()`. - # MuJoCo rendering fails this check, ignore -- not much we can do. - resa = env.render(mode="rgb_array") - resb = env.render(mode="rgb_array") - assert np.allclose(resa, resb) - - class CountingEnv(gym.Env): """At timestep `t` of each episode, has `t == obs == reward / 10`. diff --git a/tests/test_envs.py b/tests/test_envs.py index 751430d..0925cee 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -3,11 +3,13 @@ from typing import List import gymnasium as gym +import numpy as np import pytest import seals # noqa: F401 required for env registration from gymnasium.envs import registration from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env from seals.testing import envs +from seals.testing.envs import is_mujoco_env ENV_NAMES: List[str] = [ env_id for env_id in registration.registry.keys() @@ -147,6 +149,29 @@ def test_rollout_schema(self, env: gym.Env, env_name: str): else: envs.test_rollout_schema(env) - def test_render(self, env: gym.Env): - """Tests `render()` supports modes specified in environment metadata.""" - envs.test_render(env, raises_fn=pytest.raises) + def test_render_modes(self, env_name: str): + """Tests that all render modes specifeid in the metadata work. + + Note: we only check that no exception is thrown. + There is no test to see if something reasonable is rendered. + """ + for mode in gym.make(env_name).metadata["render_modes"]: + # GIVEN + env = gym.make(env_name, render_mode=mode) + env.reset(seed=0) + + # WHEN + if mode == "rgb_array" and not is_mujoco_env(env): + # The render should not change without calling `step()`. + # MuJoCo rendering fails this check, ignore -- not much we can do. + r1 = env.render() + r2 = env.render() + assert np.allclose(r1, r2) + else: + env.render() + + # THEN + # no error raised + + # CLEANUP + env.close() From 1650590fea8fc147cb8795c14658510c14e7d139 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:09:30 +0200 Subject: [PATCH 30/61] Forward args and kwargs when constructing environments, so we can pass in the render mode. --- src/seals/atari.py | 4 ++-- src/seals/classic_control.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/seals/atari.py b/src/seals/atari.py index b4f99e7..3782c84 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -36,9 +36,9 @@ def _get_score_region(atari_env_id: str) -> Optional[MaskedRegionSpecifier]: return SCORE_REGIONS.get(basename) -def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env: +def make_atari_env(atari_env_id: str, masked: bool, *args, **kwargs) -> gym.Env: """Fixed-length, optionally masked-score variant of a given Atari environment.""" - env: gym.Env = AutoResetWrapper(gym.make(atari_env_id)) + env: gym.Env = AutoResetWrapper(gym.make(atari_env_id, *args, **kwargs)) if masked: score_region = _get_score_region(atari_env_id) diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 4fd1249..98ceb15 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -23,9 +23,9 @@ class FixedHorizonCartPole(classic_control.CartPoleEnv): which sets `truncated` to true after that many steps. """ - def __init__(self): + def __init__(self, *args, **kwargs): """Builds FixedHorizonCartPole, modifying observation_space from gym parent.""" - super().__init__() + super().__init__(*args, **kwargs) high = [ np.finfo(np.float32).max, # x axis @@ -64,7 +64,7 @@ def step(self, action): return np.array(self.state, dtype=np.float32), rew, False, False, {} -def mountain_car(): +def mountain_car(*args, **kwargs): """Fixed-length variant of MountainCar-v0. In the event of early episode completion (i.e., the car reaches the @@ -73,7 +73,7 @@ def mountain_car(): Done is always returned on timestep 200 only. """ - env = gym.make("MountainCar-v0") + env = gym.make("MountainCar-v0", *args, **kwargs) env = util.ObsCastWrapper(env, dtype=np.float32) env = util.AbsorbAfterDoneWrapper(env) return env From 602f0a9f0e3a24ef152bd712205d5418603215ed Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:10:02 +0200 Subject: [PATCH 31/61] Add `Casino-Unmasked-v5` to the list of slow envs with randomness. --- tests/test_envs.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_envs.py b/tests/test_envs.py index 0925cee..fa34e91 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -120,6 +120,7 @@ def test_seed(self, env: gym.Env, env_name: str): "seals/KingKong-Unmasked-v5", "seals/Koolaid-Unmasked-v5", "seals/NameThisGame-Unmasked-v5", + "seals/Casino-Unmasked-v5", ] rollout_len = 100 if env_name not in slow_random_envs else 400 num_seeds = 2 if env_name in ATARI_NO_FRAMESKIP_ENVS else 10 From b739176952455fc6dd5626a4fd5d55bd61756d19 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:10:14 +0200 Subject: [PATCH 32/61] Add some missing commas. --- tests/test_wrappers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 9a70988..4be6b96 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -12,8 +12,8 @@ def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2) AutoResetWrapper that pads trajectory with an extra transition containing the terminal observations. Also check that calls to .reset() do not interfere with automatic resets. - Due to the padding the number of steps counted inside the environment and the number - of steps performed outside the environment, i.e. the number of actions performed, + Due to the padding, the number of steps counted inside the environment and the number + of steps performed outside the environment, i.e., the number of actions performed, will differ. This test checks that this difference is consistent. """ env = util.AutoResetWrapper( From 2dce7ae29586da611f5f6c4476489314f5df5fc0 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:19:04 +0200 Subject: [PATCH 33/61] Add pygame to setup.py --- setup.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/setup.py b/setup.py index 594c2b5..3160ca1 100644 --- a/setup.py +++ b/setup.py @@ -117,6 +117,10 @@ def get_readme() -> str: "pytype", "stable-baselines3>=0.9.0", "setuptools_scm~=7.0.5", + # We'd like to specify `gymnasium[classic-control]`, but this is a no-op when + # gymnasium is already installed. See https://github.com/pypa/pip/issues/4957 for + # issue. + "pygame>=2.1.3", *ATARI_REQUIRE, ] DOCS_REQUIRE = [ @@ -144,8 +148,8 @@ def get_readme() -> str: "dev": ["ipdb", "jupyter", *TESTS_REQUIRE, *DOCS_REQUIRE], "docs": DOCS_REQUIRE, "test": TESTS_REQUIRE, - # We'd like to specify `gymnasium[mujoco]`, but this is a no-op when Gym is already - # installed. See https://github.com/pypa/pip/issues/4957 for issue. + # We'd like to specify `gymnasium[mujoco]`, but this is a no-op when gymnasium + # is already installed. See https://github.com/pypa/pip/issues/4957 for issue. "mujoco": ["mujoco", "imageio"], "atari": ATARI_REQUIRE, }, From 72da2f181dfefe50b4bfec2abee68d692b8f27fa Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:29:51 +0200 Subject: [PATCH 34/61] Update ale-py version. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 3160ca1..a31f9f1 100644 --- a/setup.py +++ b/setup.py @@ -88,7 +88,7 @@ def get_readme() -> str: ATARI_REQUIRE = [ "opencv-python", - "ale-py==0.7.4", + "ale-py~=0.8.1", "pillow", "autorom[accept-rom-license]~=0.4.2", "shimmy[atari] >=0.1.0,<1.0", From 062e31d19fd0c1f475695941e685bd2bfc67c23e Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:42:00 +0200 Subject: [PATCH 35/61] Make test_sample_distribution deterministic by introducing a seed. --- tests/test_util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_util.py b/tests/test_util.py index 9ed8ceb..465af4b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -20,6 +20,7 @@ def test_mask_score_wrapper_enforces_spec(): def test_sample_distribution(): """Test util.sample_distribution.""" + np.random.seed(0) distr_size = 5 distr = np.random.random((distr_size,)) distr /= distr.sum() From 4ccc8844c6176889a1a1f54828b79a65ccdd5d4e Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:46:15 +0200 Subject: [PATCH 36/61] Fixing isort issues. --- src/seals/__init__.py | 2 +- src/seals/base_envs.py | 4 ++-- src/seals/classic_control.py | 4 +--- src/seals/diagnostics/largest_sum.py | 2 +- src/seals/diagnostics/noisy_obs.py | 2 +- src/seals/diagnostics/parabola.py | 2 +- src/seals/diagnostics/proc_goal.py | 2 +- src/seals/diagnostics/sort.py | 2 +- src/seals/util.py | 14 ++++++++++++-- tests/test_base_env.py | 1 + tests/test_envs.py | 3 ++- tests/test_util.py | 1 + tests/test_wrappers.py | 1 + 13 files changed, 26 insertions(+), 14 deletions(-) diff --git a/src/seals/__init__.py b/src/seals/__init__.py index 2f1fa40..5dd0f50 100644 --- a/src/seals/__init__.py +++ b/src/seals/__init__.py @@ -4,8 +4,8 @@ import gymnasium as gym -import seals.diagnostics # noqa: F401 from seals import atari, util +import seals.diagnostics # noqa: F401 try: __version__ = metadata.version("seals") diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 7090358..6d0003f 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -1,12 +1,12 @@ """Base environment classes.""" import abc -from typing import Any, Generic, Optional, Tuple, TypeVar, Dict, Union +from typing import Any, Dict, Generic, Optional, Tuple, TypeVar, Union import gymnasium as gym +from gymnasium import spaces import numpy as np import numpy.typing as npt -from gymnasium import spaces from seals import util diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 98ceb15..dc1ed8f 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -2,12 +2,10 @@ import warnings -import numpy as np - import gymnasium as gym from gymnasium import spaces from gymnasium.envs import classic_control - +import numpy as np from seals import util diff --git a/src/seals/diagnostics/largest_sum.py b/src/seals/diagnostics/largest_sum.py index 688588c..35a03f7 100644 --- a/src/seals/diagnostics/largest_sum.py +++ b/src/seals/diagnostics/largest_sum.py @@ -1,7 +1,7 @@ """Environment testing scalability to high-dimensional tasks.""" -import numpy as np from gymnasium import spaces +import numpy as np from seals import base_envs diff --git a/src/seals/diagnostics/noisy_obs.py b/src/seals/diagnostics/noisy_obs.py index e594c70..5fd8730 100644 --- a/src/seals/diagnostics/noisy_obs.py +++ b/src/seals/diagnostics/noisy_obs.py @@ -1,7 +1,7 @@ """Environment testing for robustness to noise.""" -import numpy as np from gymnasium import spaces +import numpy as np from seals import base_envs, util diff --git a/src/seals/diagnostics/parabola.py b/src/seals/diagnostics/parabola.py index 053eff1..47483ed 100644 --- a/src/seals/diagnostics/parabola.py +++ b/src/seals/diagnostics/parabola.py @@ -1,7 +1,7 @@ """Environment testing for generalization in continuous spaces.""" -import numpy as np from gymnasium import spaces +import numpy as np from seals import base_envs diff --git a/src/seals/diagnostics/proc_goal.py b/src/seals/diagnostics/proc_goal.py index 706a939..d212f44 100644 --- a/src/seals/diagnostics/proc_goal.py +++ b/src/seals/diagnostics/proc_goal.py @@ -1,7 +1,7 @@ """Large gridworld with random agent and goal position.""" -import numpy as np from gymnasium import spaces +import numpy as np from seals import base_envs, util diff --git a/src/seals/diagnostics/sort.py b/src/seals/diagnostics/sort.py index 6b1a88e..0ae8621 100644 --- a/src/seals/diagnostics/sort.py +++ b/src/seals/diagnostics/sort.py @@ -1,7 +1,7 @@ """Environment to sort a list using swap actions.""" -import numpy as np from gymnasium import spaces +import numpy as np from seals import base_envs diff --git a/src/seals/util.py b/src/seals/util.py index 1e45142..8afaba3 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -1,12 +1,22 @@ """Miscellaneous utilities.""" from dataclasses import dataclass -from typing import Any, Generic, List, Optional, Sequence, SupportsFloat, Tuple, Union, Dict +from typing import ( + Any, + Dict, + Generic, + List, + Optional, + Sequence, + SupportsFloat, + Tuple, + Union, +) import gymnasium as gym +from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType import numpy as np import numpy.typing as npt -from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType class AutoResetWrapper( diff --git a/tests/test_base_env.py b/tests/test_base_env.py index c9b3936..9b0ef05 100644 --- a/tests/test_base_env.py +++ b/tests/test_base_env.py @@ -7,6 +7,7 @@ import gymnasium as gym import numpy as np import pytest + from seals import base_envs from seals.testing import envs diff --git a/tests/test_envs.py b/tests/test_envs.py index fa34e91..4c6b3f9 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -3,10 +3,11 @@ from typing import List import gymnasium as gym +from gymnasium.envs import registration import numpy as np import pytest + import seals # noqa: F401 required for env registration -from gymnasium.envs import registration from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env from seals.testing import envs from seals.testing.envs import is_mujoco_env diff --git a/tests/test_util.py b/tests/test_util.py index 465af4b..1a569a0 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -5,6 +5,7 @@ import gymnasium as gym import numpy as np import pytest + from seals import GYM_ATARI_ENV_SPECS, util diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 4be6b96..ac1c583 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -2,6 +2,7 @@ import numpy as np import pytest + from seals import util from seals.testing import envs From 4c2379c86134ced1be1d39cc588aeb9ac1d490df Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:48:57 +0200 Subject: [PATCH 37/61] Add missing trailing commas. --- src/seals/base_envs.py | 14 +++++++------- src/seals/diagnostics/random_trans.py | 2 +- src/seals/util.py | 14 ++++++-------- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 6d0003f..ff299c9 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -16,7 +16,7 @@ class ResettablePOMDP( - gym.Env[ObsType, ActType], abc.ABC, Generic[StateType, ObsType, ActType] + gym.Env[ObsType, ActType], abc.ABC, Generic[StateType, ObsType, ActType], ): """ABC for POMDPs that are resettable. @@ -139,7 +139,7 @@ def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None: self._observation_space = env.state_space def reset( - self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None + self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, ) -> Tuple[StateType, Dict[str, Any]]: """Reset environment and return initial state.""" _, info = self.env.reset(seed=seed, options=options) @@ -180,7 +180,7 @@ def obs_from_state(self, state: StateType) -> StateType: # so in theory it should not be instantiated directly. # Not sure why this is not raising an error? class BaseTabularModelPOMDP( - ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt], Generic[ObsType] + ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt], Generic[ObsType], ): """Base class for tabular environments with known dynamics. @@ -275,18 +275,18 @@ def initial_state(self) -> DiscreteSpaceInt: util.sample_distribution( self.initial_state_dist, random=self.rand_state, - ) + ), ) def transition( - self, state: DiscreteSpaceInt, action: DiscreteSpaceInt + self, state: DiscreteSpaceInt, action: DiscreteSpaceInt, ) -> DiscreteSpaceInt: """Samples from transition distribution.""" return DiscreteSpaceInt( util.sample_distribution( self.transition_matrix[state, action], random=self.rand_state, - ) + ), ) def reward( @@ -325,7 +325,7 @@ def action_dim(self) -> int: ObsEntryType = TypeVar( - "ObsEntryType", bound=Union[np.floating, np.integer], covariant=True + "ObsEntryType", bound=Union[np.floating, np.integer], covariant=True, ) diff --git a/src/seals/diagnostics/random_trans.py b/src/seals/diagnostics/random_trans.py index ec8f067..37444df 100644 --- a/src/seals/diagnostics/random_trans.py +++ b/src/seals/diagnostics/random_trans.py @@ -74,7 +74,7 @@ def __init__( ) self.reward_weights = rand_gen.normal( - 0, 1, size=(observation_matrix.shape[-1],) + 0, 1, size=(observation_matrix.shape[-1],), ) reward_matrix = observation_matrix @ self.reward_weights super().__init__( diff --git a/src/seals/util.py b/src/seals/util.py index 8afaba3..a1f2d9e 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -20,7 +20,7 @@ class AutoResetWrapper( - gym.Wrapper, Generic[WrapperObsType, WrapperActType, ObsType, ActType] + gym.Wrapper, Generic[WrapperObsType, WrapperActType, ObsType, ActType], ): """Hides terminated=True and truncated=True and auto-resets at the end of each episode. @@ -53,7 +53,7 @@ def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0): self.previous_done = False # Whether the previous step returned done=True. def step( - self, action: WrapperActType + self, action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When terminated or truncated, resets the environment and returns False for terminated and truncated. @@ -69,7 +69,7 @@ def step( return self._step_pad(action) def _step_pad( - self, action: WrapperActType + self, action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When terminated or truncated, return False for both instead and return the terminal obs. @@ -98,7 +98,7 @@ def _step_pad( return obs, rew, False, False, info def _step_discard( - self, action: WrapperActType + self, action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When terminated or truncated, return False for both and automatically reset. @@ -170,7 +170,7 @@ def _mask_obs(self, obs) -> npt.NDArray: return np.where(self.mask, obs, self.fill_value) def step( - self, action: ActType + self, action: ActType, ) -> Tuple[npt.NDArray, SupportsFloat, bool, bool, Dict[str, Any]]: """Returns (obs, rew, terminated, truncated, info) with masked obs.""" obs, rew, terminated, truncated, info = self.env.step(action) @@ -263,9 +263,7 @@ def step(self, action): `info` is always an empty dictionary. """ if not self.at_absorb_state: - obs, rew, terminated, truncated, info = self.env.step( - action - ) + obs, rew, terminated, truncated, info = self.env.step(action) if terminated or truncated: # Initialize the artificial absorb state, which we will repeatedly use # starting on the next call to `step()`. From 96b27e30c8abf6fe019d3bf90922e7f181b048af Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:53:37 +0200 Subject: [PATCH 38/61] Minor formatting fixes. --- src/seals/base_envs.py | 1 - src/seals/diagnostics/random_trans.py | 2 +- src/seals/testing/envs.py | 4 +++- src/seals/util.py | 15 ++++++++------- tests/test_wrappers.py | 6 +++--- 5 files changed, 15 insertions(+), 13 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index ff299c9..79d0252 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -33,7 +33,6 @@ class ResettablePOMDP( def __init__(self): """Build resettable (PO)MDP.""" - self._cur_state = None self._n_actions_taken = None diff --git a/src/seals/diagnostics/random_trans.py b/src/seals/diagnostics/random_trans.py index 37444df..d539fb5 100644 --- a/src/seals/diagnostics/random_trans.py +++ b/src/seals/diagnostics/random_trans.py @@ -72,7 +72,7 @@ def __init__( n_states=n_states, rand_state=rand_gen, ) - + self.reward_weights = rand_gen.normal( 0, 1, size=(observation_matrix.shape[-1],), ) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 989f534..8f21245 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -22,7 +22,8 @@ Step = Tuple[Any, Optional[float], bool, bool, Mapping[str, Any]] Rollout = Sequence[Step] -"""A sequence of 4-tuples (obs, rew, terminated, truncated, info) as returned by `get_rollout`.""" +"""A sequence of 5-tuples (obs, rew, terminated, truncated, info) as returned by +`get_rollout`.""" def make_env_fixture( @@ -204,6 +205,7 @@ def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> Tuple[bool, bool]: def is_mujoco_env(env: gym.Env) -> bool: + """True if `env` is a MuJoCo environment.""" return hasattr(env, "sim") and hasattr(env, "model") diff --git a/src/seals/util.py b/src/seals/util.py index a1f2d9e..c371617 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -22,8 +22,7 @@ class AutoResetWrapper( gym.Wrapper, Generic[WrapperObsType, WrapperActType, ObsType, ActType], ): - """Hides terminated=True and truncated=True and auto-resets at the end of each - episode. + """Hides terminated truncated and auto-resets at the end of each episode. Depending on the flag 'discard_terminal_observation', either discards the terminal observation or pads with an additional 'reset transition'. The former is the default @@ -55,8 +54,9 @@ def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0): def step( self, action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """When terminated or truncated, resets the environment and returns False - for terminated and truncated. + """When terminated or truncated, resets the environment. + + Always returns False for terminated and truncated. Depending on whether we are discarding the terminal observation, either resets the environment and discards, @@ -71,8 +71,9 @@ def step( def _step_pad( self, action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: - """When terminated or truncated, return False for both instead and return the - terminal obs. + """When terminated or truncated, resets the environment. + + Always returns False for terminated and truncated. The agent will then usually be asked to perform an action based on the terminal observation. In the next step, this final action will be ignored @@ -207,7 +208,7 @@ def reset(self, seed=None): return obs.astype(self.dtype), info def step(self, action): - """Returns (obs, rew, terminated, truncated, info) with obs cast to self.dtype.""" + """Returns (obs, rew, terminated, truncated, info) with obs cast to dtype.""" obs, rew, terminated, truncated, info = super().step(action) return obs.astype(self.dtype), rew, terminated, truncated, info diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index ac1c583..f7f5b26 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -13,9 +13,9 @@ def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2) AutoResetWrapper that pads trajectory with an extra transition containing the terminal observations. Also check that calls to .reset() do not interfere with automatic resets. - Due to the padding, the number of steps counted inside the environment and the number - of steps performed outside the environment, i.e., the number of actions performed, - will differ. This test checks that this difference is consistent. + Due to the padding, the number of steps counted inside the environment and the + number of steps performed outside the environment, i.e., the number of actions + performed, will differ. This test checks that this difference is consistent. """ env = util.AutoResetWrapper( envs.CountingEnv(episode_length=episode_length), From 85bbce220e79814c7d6ba95f590b2bcfd7c580b0 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:55:06 +0200 Subject: [PATCH 39/61] Fix trailing whitespace. --- src/seals/testing/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 8f21245..3f20879 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -22,7 +22,7 @@ Step = Tuple[Any, Optional[float], bool, bool, Mapping[str, Any]] Rollout = Sequence[Step] -"""A sequence of 5-tuples (obs, rew, terminated, truncated, info) as returned by +"""A sequence of 5-tuples (obs, rew, terminated, truncated, info) as returned by `get_rollout`.""" From ab4f18263636783613453ab9e4f53c4ee6e7f12c Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 19:59:48 +0200 Subject: [PATCH 40/61] Black fixes. --- src/seals/base_envs.py | 19 ++++++++++++++----- src/seals/diagnostics/random_trans.py | 4 +++- src/seals/util.py | 15 ++++++++++----- tests/test_envs.py | 3 +-- 4 files changed, 28 insertions(+), 13 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 79d0252..03a83f0 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -16,7 +16,9 @@ class ResettablePOMDP( - gym.Env[ObsType, ActType], abc.ABC, Generic[StateType, ObsType, ActType], + gym.Env[ObsType, ActType], + abc.ABC, + Generic[StateType, ObsType, ActType], ): """ABC for POMDPs that are resettable. @@ -138,7 +140,9 @@ def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None: self._observation_space = env.state_space def reset( - self, seed: Optional[int] = None, options: Optional[Dict[str, Any]] = None, + self, + seed: Optional[int] = None, + options: Optional[Dict[str, Any]] = None, ) -> Tuple[StateType, Dict[str, Any]]: """Reset environment and return initial state.""" _, info = self.env.reset(seed=seed, options=options) @@ -179,7 +183,8 @@ def obs_from_state(self, state: StateType) -> StateType: # so in theory it should not be instantiated directly. # Not sure why this is not raising an error? class BaseTabularModelPOMDP( - ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt], Generic[ObsType], + ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt], + Generic[ObsType], ): """Base class for tabular environments with known dynamics. @@ -278,7 +283,9 @@ def initial_state(self) -> DiscreteSpaceInt: ) def transition( - self, state: DiscreteSpaceInt, action: DiscreteSpaceInt, + self, + state: DiscreteSpaceInt, + action: DiscreteSpaceInt, ) -> DiscreteSpaceInt: """Samples from transition distribution.""" return DiscreteSpaceInt( @@ -324,7 +331,9 @@ def action_dim(self) -> int: ObsEntryType = TypeVar( - "ObsEntryType", bound=Union[np.floating, np.integer], covariant=True, + "ObsEntryType", + bound=Union[np.floating, np.integer], + covariant=True, ) diff --git a/src/seals/diagnostics/random_trans.py b/src/seals/diagnostics/random_trans.py index d539fb5..45f717b 100644 --- a/src/seals/diagnostics/random_trans.py +++ b/src/seals/diagnostics/random_trans.py @@ -74,7 +74,9 @@ def __init__( ) self.reward_weights = rand_gen.normal( - 0, 1, size=(observation_matrix.shape[-1],), + 0, + 1, + size=(observation_matrix.shape[-1],), ) reward_matrix = observation_matrix @ self.reward_weights super().__init__( diff --git a/src/seals/util.py b/src/seals/util.py index c371617..2843e41 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -20,7 +20,8 @@ class AutoResetWrapper( - gym.Wrapper, Generic[WrapperObsType, WrapperActType, ObsType, ActType], + gym.Wrapper, + Generic[WrapperObsType, WrapperActType, ObsType, ActType], ): """Hides terminated truncated and auto-resets at the end of each episode. @@ -52,7 +53,8 @@ def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0): self.previous_done = False # Whether the previous step returned done=True. def step( - self, action: WrapperActType, + self, + action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When terminated or truncated, resets the environment. @@ -69,7 +71,8 @@ def step( return self._step_pad(action) def _step_pad( - self, action: WrapperActType, + self, + action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When terminated or truncated, resets the environment. @@ -99,7 +102,8 @@ def _step_pad( return obs, rew, False, False, info def _step_discard( - self, action: WrapperActType, + self, + action: WrapperActType, ) -> Tuple[ObsType, SupportsFloat, bool, bool, Dict[str, Any]]: """When terminated or truncated, return False for both and automatically reset. @@ -171,7 +175,8 @@ def _mask_obs(self, obs) -> npt.NDArray: return np.where(self.mask, obs, self.fill_value) def step( - self, action: ActType, + self, + action: ActType, ) -> Tuple[npt.NDArray, SupportsFloat, bool, bool, Dict[str, Any]]: """Returns (obs, rew, terminated, truncated, info) with masked obs.""" obs, rew, terminated, truncated, info = self.env.step(action) diff --git a/tests/test_envs.py b/tests/test_envs.py index 4c6b3f9..6a449ad 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -13,8 +13,7 @@ from seals.testing.envs import is_mujoco_env ENV_NAMES: List[str] = [ - env_id for env_id in registration.registry.keys() - if env_id.startswith("seals/") + env_id for env_id in registration.registry.keys() if env_id.startswith("seals/") ] From 73666a18622b33d5418757b0820c08eec64aed0e Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 20:05:44 +0200 Subject: [PATCH 41/61] Explicitly seed dummy environment. --- tests/test_base_env.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_base_env.py b/tests/test_base_env.py index 9b0ef05..8d336c0 100644 --- a/tests/test_base_env.py +++ b/tests/test_base_env.py @@ -17,6 +17,7 @@ class NewEnv(base_envs.TabularModelMDP): def __init__(self): """Build environment.""" + np.random.seed(0) nS = 3 nA = 2 transition_matrix = np.random.random((nS, nA, nS)) From 04e4442c9355376f8097db21a54eb99e9ba9a6aa Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 20:26:58 +0200 Subject: [PATCH 42/61] Remove unnecessary cast to int. --- src/seals/base_envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 03a83f0..50f87d1 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -316,7 +316,7 @@ def feature_matrix(self): # Construct lazily to save memory in algorithms that don't need features. if self._feature_matrix is None: n_states = self.state_space.n - self._feature_matrix = np.eye(int(n_states)) + self._feature_matrix = np.eye(n_states) return self._feature_matrix @property From fd278cab8dcde5ff580de74ac878e40087cc1bf5 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 14 Aug 2023 22:47:09 +0200 Subject: [PATCH 43/61] Fix some typing issues. --- src/seals/base_envs.py | 8 +++----- src/seals/testing/envs.py | 8 ++++---- src/seals/util.py | 4 ++-- tests/test_envs.py | 8 +++++--- tests/test_wrappers.py | 6 ++++-- 5 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 50f87d1..8872456 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -5,14 +5,13 @@ import gymnasium as gym from gymnasium import spaces +from gymnasium.core import ActType, ObsType import numpy as np import numpy.typing as npt from seals import util StateType = TypeVar("StateType") -ObsType = TypeVar("ObsType") -ActType = TypeVar("ActType") class ResettablePOMDP( @@ -162,7 +161,7 @@ class ResettableMDP( """ABC for MDPs that are resettable.""" @property - def observation_space(self) -> spaces.Space[StateType]: + def observation_space(self): """Observation space.""" return self.state_space @@ -333,7 +332,6 @@ def action_dim(self) -> int: ObsEntryType = TypeVar( "ObsEntryType", bound=Union[np.floating, np.integer], - covariant=True, ) @@ -390,7 +388,7 @@ def __init__( low=min_val, high=max_val, shape=(self.obs_dim,), - dtype=self.obs_dtype, + dtype=self.obs_dtype, # type: ignore ) def obs_from_state(self, state: DiscreteSpaceInt) -> npt.NDArray[ObsEntryType]: diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 3f20879..47085d1 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -12,15 +12,15 @@ Iterator, List, Mapping, - Optional, Sequence, + SupportsFloat, Tuple, ) import gymnasium as gym import numpy as np -Step = Tuple[Any, Optional[float], bool, bool, Mapping[str, Any]] +Step = Tuple[Any, SupportsFloat, bool, bool, Mapping[str, Any]] Rollout = Sequence[Step] """A sequence of 5-tuples (obs, rew, terminated, truncated, info) as returned by `get_rollout`.""" @@ -99,7 +99,7 @@ def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout: A sequence of 5-tuples (obs, rew, terminated, truncated, info). """ obs, info = env.reset() - ret: List[Step] = [(obs, None, False, False, {})] + ret: List[Step] = [(obs, 0, False, False, {})] for act in actions: ret.append(env.step(act)) return ret @@ -192,7 +192,7 @@ def _check_obs(obs: np.ndarray, obs_space: gym.Space) -> None: assert obs in obs_space -def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> Tuple[bool, bool]: +def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> bool: """Sample from env and check return value is of valid type.""" act = env.action_space.sample() obs, rew, terminated, truncated, info = env.step(act) diff --git a/src/seals/util.py b/src/seals/util.py index 2843e41..31f82f8 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -21,7 +21,7 @@ class AutoResetWrapper( gym.Wrapper, - Generic[WrapperObsType, WrapperActType, ObsType, ActType], + Generic[WrapperObsType, WrapperActType, ObsType, ActType], # type: ignore ): """Hides terminated truncated and auto-resets at the end of each episode. @@ -132,7 +132,7 @@ class BoxRegion: class MaskScoreWrapper( gym.Wrapper[npt.NDArray, ActType, npt.NDArray, ActType], - Generic[ActType], + Generic[ActType], # type: ignore ): """Mask a list of box-shaped regions in the observation to hide reward info. diff --git a/tests/test_envs.py b/tests/test_envs.py index 6a449ad..2cbce74 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -1,6 +1,6 @@ """Smoke tests for all environments.""" -from typing import List +from typing import List, Union import gymnasium as gym from gymnasium.envs import registration @@ -165,8 +165,10 @@ def test_render_modes(self, env_name: str): if mode == "rgb_array" and not is_mujoco_env(env): # The render should not change without calling `step()`. # MuJoCo rendering fails this check, ignore -- not much we can do. - r1 = env.render() - r2 = env.render() + r1: Union[np.ndarray, List[np.ndarray], None] = env.render() + r2: Union[np.ndarray, List[np.ndarray], None] = env.render() + assert r1 is not None + assert r2 is not None assert np.allclose(r1, r2) else: env.render() diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index f7f5b26..a1c2d1d 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -157,8 +157,10 @@ def test_obs_cast(dtype: np.dtype, episode_length: int = 5): Test uses CountingEnv with small integers, which can be represented in all the specified dtypes without any loss of precision. """ - env = envs.CountingEnv(episode_length=episode_length) - env = util.ObsCastWrapper(env, dtype) + env = util.ObsCastWrapper( + envs.CountingEnv(episode_length=episode_length), + dtype, + ) obs, info = env.reset() assert obs.dtype == dtype From 2702330413d55b7b3e9332c08b96f638074e3f13 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 10:18:17 +0200 Subject: [PATCH 44/61] Fix more typing issues. --- src/seals/base_envs.py | 5 ++++- src/seals/util.py | 13 ++++++++++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 8872456..631e283 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -5,13 +5,16 @@ import gymnasium as gym from gymnasium import spaces -from gymnasium.core import ActType, ObsType import numpy as np import numpy.typing as npt from seals import util +# Note: we redefine the type vars from gymnasium.core here, because pytype does not +# recognize them as valid type vars if we import them from gymnasium.core. StateType = TypeVar("StateType") +ActType = TypeVar("ActType") +ObsType = TypeVar("ObsType") class ResettablePOMDP( diff --git a/src/seals/util.py b/src/seals/util.py index 31f82f8..9535958 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -10,18 +10,25 @@ Sequence, SupportsFloat, Tuple, + TypeVar, Union, ) import gymnasium as gym -from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType import numpy as np import numpy.typing as npt +# Note: we redefine the type vars from gymnasium.core here, because pytype does not +# recognize them as valid type vars if we import them from gymnasium.core. +WrapperObsType = TypeVar("WrapperObsType") +WrapperActType = TypeVar("WrapperActType") +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") + class AutoResetWrapper( gym.Wrapper, - Generic[WrapperObsType, WrapperActType, ObsType, ActType], # type: ignore + Generic[WrapperObsType, WrapperActType, ObsType, ActType], ): """Hides terminated truncated and auto-resets at the end of each episode. @@ -132,7 +139,7 @@ class BoxRegion: class MaskScoreWrapper( gym.Wrapper[npt.NDArray, ActType, npt.NDArray, ActType], - Generic[ActType], # type: ignore + Generic[ActType], ): """Mask a list of box-shaped regions in the observation to hide reward info. From a8310d8560fe7fe81578de171ed36e76195a528e Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:21:57 +0200 Subject: [PATCH 45/61] Simplify ObsCastWrapper by inheriting from gym.ObservationWrapper instead of gym.Wrapper. --- src/seals/util.py | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/src/seals/util.py b/src/seals/util.py index 9535958..ada9415 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -195,7 +195,7 @@ def reset(self, **kwargs): return self._mask_obs(obs), info -class ObsCastWrapper(gym.Wrapper): +class ObsCastWrapper(gym.ObservationWrapper): """Cast observations to specified dtype. Some external environments return observations of a different type than the @@ -214,15 +214,9 @@ def __init__(self, env: gym.Env, dtype: np.dtype): super().__init__(env) self.dtype = dtype - def reset(self, seed=None): - """Returns reset observation, cast to self.dtype.""" - obs, info = super().reset(seed=seed) - return obs.astype(self.dtype), info - - def step(self, action): - """Returns (obs, rew, terminated, truncated, info) with obs cast to dtype.""" - obs, rew, terminated, truncated, info = super().step(action) - return obs.astype(self.dtype), rew, terminated, truncated, info + def observation(self, obs): + """Returns observation cast to self.dtype.""" + return obs.astype(self.dtype) class AbsorbAfterDoneWrapper(gym.Wrapper): From e0d954f1ebc291c43e5e7d9040b21a1b9785c2f9 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:22:09 +0200 Subject: [PATCH 46/61] Small typos in docstrings. --- src/seals/testing/envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index 47085d1..ff61c6a 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -89,10 +89,10 @@ def matches_list(env_name: str, patterns: Iterable[str]) -> bool: def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout: - """Performs sequence of actions `actions` in `env`. + """Performs a sequence of actions `actions` in `env`. Args: - env: the environment to rollout in. + env: the environment to roll out in. actions: the actions to perform. Returns: From d1e4f1a583943633ae9787dd16742b19284b39df Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:22:22 +0200 Subject: [PATCH 47/61] Add reset info when generating rollouts. --- src/seals/testing/envs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py index ff61c6a..4d6dc0c 100644 --- a/src/seals/testing/envs.py +++ b/src/seals/testing/envs.py @@ -99,7 +99,7 @@ def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout: A sequence of 5-tuples (obs, rew, terminated, truncated, info). """ obs, info = env.reset() - ret: List[Step] = [(obs, 0, False, False, {})] + ret: List[Step] = [(obs, 0, False, False, info)] for act in actions: ret.append(env.step(act)) return ret From 36041d7d7b67d59bc4f0cb4e6be41b08e9391bb2 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:24:25 +0200 Subject: [PATCH 48/61] Remove unneeded default params to rand_gen.normal() --- src/seals/diagnostics/random_trans.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/seals/diagnostics/random_trans.py b/src/seals/diagnostics/random_trans.py index 45f717b..6071e88 100644 --- a/src/seals/diagnostics/random_trans.py +++ b/src/seals/diagnostics/random_trans.py @@ -73,11 +73,7 @@ def __init__( rand_state=rand_gen, ) - self.reward_weights = rand_gen.normal( - 0, - 1, - size=(observation_matrix.shape[-1],), - ) + self.reward_weights = rand_gen.normal(size=(observation_matrix.shape[-1],)) reward_matrix = observation_matrix @ self.reward_weights super().__init__( transition_matrix=transition_matrix, From 52bffb679a3fdf2969c0244e6cf5fe5381bacb94 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:35:29 +0200 Subject: [PATCH 49/61] Remove unneeded setter for the observation space property in a ResettableMDP. --- src/seals/base_envs.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 631e283..f5f9681 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -168,11 +168,6 @@ def observation_space(self): """Observation space.""" return self.state_space - @observation_space.setter - def observation_space(self, space: spaces.Space[StateType]): - """Set observation space.""" - self.state_space = space - def obs_from_state(self, state: StateType) -> StateType: """Identity since observation == state in an MDP.""" return state From 9ea13ba48892738ce4a89ecb2635aa51c9b43c83 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:49:50 +0200 Subject: [PATCH 50/61] Ignore coverage for edge cases of where the observation space has no shape. --- src/seals/util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/seals/util.py b/src/seals/util.py index ada9415..6299bdc 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -171,7 +171,7 @@ def __init__( self.fill_value = np.array(fill_value, env.observation_space.dtype) if env.observation_space.shape is None: - raise ValueError("Observation space must have a shape.") + raise ValueError("Observation space must have a shape.") # pragma: no cover self.mask = np.ones(env.observation_space.shape, dtype=bool) for r in score_regions: if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]: From d7cbaa3cbfdb740d93e4a655a3ac9c713c03c340 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:50:51 +0200 Subject: [PATCH 51/61] Add a test case that ensures that options in the reset to a ResettablePOMDP are rejected. --- tests/test_base_env.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_base_env.py b/tests/test_base_env.py index 8d336c0..9f089f3 100644 --- a/tests/test_base_env.py +++ b/tests/test_base_env.py @@ -52,6 +52,9 @@ def test_base_envs(): with pytest.raises(ValueError, match=r".*not in.*"): env.state = bad_state # type: ignore + with pytest.raises(NotImplementedError, match=r"Options not supported.*"): + env.reset(options={"option": "value"}) + def test_tabular_env_validation(): """Test input validation for base_envs.TabularModelEnv.""" From c990834a272fb748d21b29a36b334868199a79f8 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:51:52 +0200 Subject: [PATCH 52/61] Remove rand_state property of ResettablePOMDP and use the canonical np_random of the superclass instead. --- src/seals/base_envs.py | 11 ++--------- src/seals/diagnostics/largest_sum.py | 2 +- src/seals/diagnostics/noisy_obs.py | 4 ++-- src/seals/diagnostics/parabola.py | 2 +- src/seals/diagnostics/proc_goal.py | 6 +++--- src/seals/diagnostics/sort.py | 2 +- 6 files changed, 10 insertions(+), 17 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index f5f9681..3bb17eb 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -115,13 +115,6 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: infos = {"old_state": old_state, "new_state": self._cur_state} return obs, reward, terminated, truncated, infos - @property - def rand_state(self) -> np.random.Generator: - """Random state.""" - rand_state = self._np_random - if rand_state is None: - raise RuntimeError("Need to call reset() before accessing rand_state") - return rand_state class ExposePOMDPStateWrapper( @@ -275,7 +268,7 @@ def initial_state(self) -> DiscreteSpaceInt: return DiscreteSpaceInt( util.sample_distribution( self.initial_state_dist, - random=self.rand_state, + random=self.np_random, ), ) @@ -288,7 +281,7 @@ def transition( return DiscreteSpaceInt( util.sample_distribution( self.transition_matrix[state, action], - random=self.rand_state, + random=self.np_random, ), ) diff --git a/src/seals/diagnostics/largest_sum.py b/src/seals/diagnostics/largest_sum.py index 35a03f7..7a08dad 100644 --- a/src/seals/diagnostics/largest_sum.py +++ b/src/seals/diagnostics/largest_sum.py @@ -34,7 +34,7 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Returns vector sampled uniformly in [0, 1]**L.""" - init_state = self.rand_state.random((self._length,)) + init_state = self.np_random.random((self._length,)) return init_state.astype(self.observation_space.dtype) def reward(self, state: np.ndarray, act: int, next_state: np.ndarray) -> float: diff --git a/src/seals/diagnostics/noisy_obs.py b/src/seals/diagnostics/noisy_obs.py index 5fd8730..92b3429 100644 --- a/src/seals/diagnostics/noisy_obs.py +++ b/src/seals/diagnostics/noisy_obs.py @@ -52,7 +52,7 @@ def initial_state(self) -> np.ndarray: """Returns one of the grid's corners.""" n = self._size corners = np.array([[0, 0], [n - 1, 0], [0, n - 1], [n - 1, n - 1]]) - return corners[self.rand_state.integers(4)] + return corners[self.np_random.integers(4)] def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float: """Returns +1.0 reward if state is the goal and 0.0 otherwise.""" @@ -69,5 +69,5 @@ def transition(self, state: np.ndarray, action: int) -> np.ndarray: def obs_from_state(self, state: np.ndarray) -> np.ndarray: """Returns (x, y) concatenated with Gaussian noise.""" - noise_vector = self.rand_state.normal(size=self._noise_length) + noise_vector = self.np_random.normal(size=self._noise_length) return np.concatenate([state, noise_vector]).astype(np.float32) diff --git a/src/seals/diagnostics/parabola.py b/src/seals/diagnostics/parabola.py index 47483ed..7b0c677 100644 --- a/src/seals/diagnostics/parabola.py +++ b/src/seals/diagnostics/parabola.py @@ -40,7 +40,7 @@ def terminal(self, state: int, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Get state by sampling a random parabola.""" - a, b, c = -1 + 2 * self.rand_state.random((3,)) + a, b, c = -1 + 2 * self.np_random.random((3,)) x, y = 0, c return np.array([x, y, a, b, c], dtype=self.state_space.dtype) diff --git a/src/seals/diagnostics/proc_goal.py b/src/seals/diagnostics/proc_goal.py index d212f44..5136db8 100644 --- a/src/seals/diagnostics/proc_goal.py +++ b/src/seals/diagnostics/proc_goal.py @@ -40,11 +40,11 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Samples random agent position and random goal.""" - pos = self.rand_state.integers(low=-self._bounds, high=self._bounds, size=(2,)) + pos = self.np_random.integers(low=-self._bounds, high=self._bounds, size=(2,)) - x_dist = self.rand_state.integers(self._distance) + x_dist = self.np_random.integers(self._distance) y_dist = self._distance - x_dist - random_signs = 2 * self.rand_state.integers(2, size=2) - 1 + random_signs = 2 * self.np_random.integers(2, size=2) - 1 goal = pos + random_signs * (x_dist, y_dist) return np.concatenate([pos, goal]).astype(self.observation_space.dtype) diff --git a/src/seals/diagnostics/sort.py b/src/seals/diagnostics/sort.py index 0ae8621..6f564fc 100644 --- a/src/seals/diagnostics/sort.py +++ b/src/seals/diagnostics/sort.py @@ -31,7 +31,7 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self): """Sample random vector uniformly in [0, 1]**L.""" - sample = self.rand_state.random(size=self._length) + sample = self.np_random.random(size=self._length) return sample.astype(self.state_space.dtype) def reward( From 63009afd22b22fdc01f59723eefc253492a35a91 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 14:51:52 +0200 Subject: [PATCH 53/61] Remove newline in base_envs.py --- src/seals/base_envs.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 3bb17eb..57638ac 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -116,7 +116,6 @@ def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: return obs, reward, terminated, truncated, infos - class ExposePOMDPStateWrapper( gym.Wrapper[StateType, ActType, ObsType, ActType], Generic[StateType, ObsType, ActType], @@ -268,7 +267,7 @@ def initial_state(self) -> DiscreteSpaceInt: return DiscreteSpaceInt( util.sample_distribution( self.initial_state_dist, - random=self.np_random, + random=self.rand_state, ), ) @@ -281,7 +280,7 @@ def transition( return DiscreteSpaceInt( util.sample_distribution( self.transition_matrix[state, action], - random=self.np_random, + random=self.rand_state, ), ) From 6ee0aa9ba7c8344db3723c40ff4c025887c1e3cf Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 15:03:16 +0200 Subject: [PATCH 54/61] Fix type annotations of FixedHorizonCartPole.reset() --- src/seals/classic_control.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index dc1ed8f..bf856f9 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -1,5 +1,6 @@ """Adaptation of classic Gym environments for specification learning algorithms.""" +from typing import Any, Dict, Optional import warnings import gymnasium as gym @@ -34,7 +35,11 @@ def __init__(self, *args, **kwargs): high = np.array(high) self.observation_space = spaces.Box(-high, high, dtype=np.float32) - def reset(self, seed=None, options={}): + def reset( + self, + seed: Optional[int] = None, + options: Optional[Dict[str, Any]] = None, + ): """Reset for FixedHorizonCartPole.""" observation, info = super().reset(seed=seed, options=options) return observation.astype(np.float32), info From 63a76381b0a3c5d47dd99615cf77b84fcda2e2b7 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 15 Aug 2023 15:18:30 +0200 Subject: [PATCH 55/61] Remove leftover usages of rand_state. --- src/seals/base_envs.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index 57638ac..a0e42a2 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -267,7 +267,7 @@ def initial_state(self) -> DiscreteSpaceInt: return DiscreteSpaceInt( util.sample_distribution( self.initial_state_dist, - random=self.rand_state, + random=self.np_random, ), ) @@ -280,7 +280,7 @@ def transition( return DiscreteSpaceInt( util.sample_distribution( self.transition_matrix[state, action], - random=self.rand_state, + random=self.np_random, ), ) From 9d73770e8cab5fe1b5713415b586fff91dacc32b Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 28 Aug 2023 14:32:28 +0200 Subject: [PATCH 56/61] Fix quicks in dependencies that are no longer needed. --- setup.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/setup.py b/setup.py index a31f9f1..a2a3139 100644 --- a/setup.py +++ b/setup.py @@ -94,9 +94,8 @@ def get_readme() -> str: "shimmy[atari] >=0.1.0,<1.0", ] TESTS_REQUIRE = [ - # remove pin once https://github.com/nedbat/coveragepy/issues/881 fixed "black", - "coverage==4.5.4", + "coverage~=4.5.4", "codecov", "codespell", "darglint>=1.5.6", @@ -117,10 +116,7 @@ def get_readme() -> str: "pytype", "stable-baselines3>=0.9.0", "setuptools_scm~=7.0.5", - # We'd like to specify `gymnasium[classic-control]`, but this is a no-op when - # gymnasium is already installed. See https://github.com/pypa/pip/issues/4957 for - # issue. - "pygame>=2.1.3", + "gymnasium[classic-control]", *ATARI_REQUIRE, ] DOCS_REQUIRE = [ @@ -148,9 +144,7 @@ def get_readme() -> str: "dev": ["ipdb", "jupyter", *TESTS_REQUIRE, *DOCS_REQUIRE], "docs": DOCS_REQUIRE, "test": TESTS_REQUIRE, - # We'd like to specify `gymnasium[mujoco]`, but this is a no-op when gymnasium - # is already installed. See https://github.com/pypa/pip/issues/4957 for issue. - "mujoco": ["mujoco", "imageio"], + "mujoco": ["gymnasium[mujoco]"], "atari": ATARI_REQUIRE, }, url="https://github.com/HumanCompatibleAI/benchmark-environments", From e137c24c6c6ce39fa2b9c54749518627abef070d Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 28 Aug 2023 14:33:52 +0200 Subject: [PATCH 57/61] Store unused info in _ --- tests/test_wrappers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index a1c2d1d..136db64 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -162,7 +162,7 @@ def test_obs_cast(dtype: np.dtype, episode_length: int = 5): dtype, ) - obs, info = env.reset() + obs, _ = env.reset() assert obs.dtype == dtype assert obs == 0 for t in range(1, episode_length + 1): From b730bf353e3a7fe9ee1d8703a4a49414a81abef1 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 28 Aug 2023 14:43:10 +0200 Subject: [PATCH 58/61] Make test_sample_distribution by seeding the used rng instead of setting the global seed. --- tests/test_util.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/test_util.py b/tests/test_util.py index 1a569a0..b1ca318 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -21,13 +21,12 @@ def test_mask_score_wrapper_enforces_spec(): def test_sample_distribution(): """Test util.sample_distribution.""" - np.random.seed(0) distr_size = 5 distr = np.random.random((distr_size,)) distr /= distr.sum() n_samples = 1000 - rng = np.random.default_rng() + rng = np.random.default_rng(0) sample_count = collections.Counter( util.sample_distribution(distr, rng) for _ in range(n_samples) ) From 21f33dd33a2f328fc4a2c3d48c2671c5cf585361 Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Mon, 28 Aug 2023 16:14:51 +0200 Subject: [PATCH 59/61] Add missing test dependency. --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index a2a3139..f2cee15 100644 --- a/setup.py +++ b/setup.py @@ -116,7 +116,7 @@ def get_readme() -> str: "pytype", "stable-baselines3>=0.9.0", "setuptools_scm~=7.0.5", - "gymnasium[classic-control]", + "gymnasium[classic-control,mujoco]", *ATARI_REQUIRE, ] DOCS_REQUIRE = [ From 7353fcb678c9c6f062e3fe5700009da33ea2f5fb Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 29 Aug 2023 17:08:01 +0200 Subject: [PATCH 60/61] Ensure we have the newest pip version to make the dependency resolution work. --- ci/build_venv.sh | 1 + 1 file changed, 1 insertion(+) diff --git a/ci/build_venv.sh b/ci/build_venv.sh index 672ecea..841ed47 100755 --- a/ci/build_venv.sh +++ b/ci/build_venv.sh @@ -9,4 +9,5 @@ fi virtualenv -p python3.8 ${venv} source ${venv}/bin/activate +pip install --upgrade pip # Ensure we have the newest pip pip install .[cpu,docs,mujoco,test] From 1cd453092638f3dc3b05baa25a372e6d080a5dcc Mon Sep 17 00:00:00 2001 From: Maximilian Ernestus Date: Tue, 29 Aug 2023 17:13:35 +0200 Subject: [PATCH 61/61] Make the dependencies cache also dependent on ci/build_venv.sh --- .circleci/config.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.circleci/config.yml b/.circleci/config.yml index 49c49b0..b998576 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -51,7 +51,7 @@ commands: # released that you want to upgrade to, without mandating the newer version in setup.py. - restore_cache: keys: - - v2-dependencies-{{ checksum "setup.py" }} + - v2-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_venv.sh" }} # Create virtual environment and install dependencies using `ci/build_venv.sh`. # `mujoco_py` needs a MuJoCo key, so download that first. @@ -64,7 +64,7 @@ commands: - save_cache: paths: - /venv - key: v2-dependencies-{{ checksum "setup.py" }} + key: v2-dependencies-{{ checksum "setup.py" }}-{{ checksum "ci/build_venv.sh" }} # Install seals. # Note we install the source distribution, not in developer mode (`pip install -e`).