diff --git a/README.md b/README.md index eacef50..e4078d8 100644 --- a/README.md +++ b/README.md @@ -7,39 +7,39 @@ **Status**: early beta. -*seals*, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for +_seals_, the Suite of Environments for Algorithms that Learn Specifications, is a toolkit for evaluating specification learning algorithms, such as reward or imitation learning. The environments are compatible with [Gym](https://github.com/openai/gym), but are designed to test algorithms that learn from user data, without requiring a procedurally specified reward function. -There are two types of environments in *seals*: +There are two types of environments in _seals_: - - **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. - - **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous - control tasks and Atari games to be suitable for specification learning benchmarks. In particular, - we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score). +- **Diagnostic Tasks** which test individual facets of algorithm performance in isolation. +- **Renovated Environments**, adaptations of widely-used benchmarks such as MuJoCo continuous + control tasks and Atari games to be suitable for specification learning benchmarks. In particular, + we remove any side-channel sources of reward information from MuJoCo tasks, and give Atari games constant-length episodes (although most Atari environments have observations that include the score). + +_seals_ is under active development and we intend to add more categories of tasks soon. -*seals* is under active development and we intend to add more categories of tasks soon. - You may also be interested in our sister project [imitation](https://github.com/humancompatibleai/imitation/), providing implementations of a variety of imitation and reward learning algorithms. -Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about *seals*. +Check out our [documentation](https://seals.readthedocs.io/en/latest/) for more information about _seals_. # Quickstart To install the latest release from PyPI, run: - + ```bash pip install seals ``` -All *seals* environments are available in the Gym registry. Simply import it and then use as you +All _seals_ environments are available in the Gym registry. Simply import it and then use as you would with your usual RL or specification learning algroithm: ```python -import gym +import gymnasium as gym import seals env = gym.make('seals/CartPole-v0') @@ -86,7 +86,7 @@ for type checking. ## Workflow Trivial changes (e.g. typo fixes) may be made directly by maintainers. Any non-trivial changes -must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous +must be proposed in a PR and approved by at least one maintainer. PRs must pass the continuous integration tests (CircleCI linting, type checking, unit tests and CodeCov) to be merged. 
It is often helpful to open an issue before proposing a PR, to allow for discussion of the design diff --git a/pyproject.toml b/pyproject.toml index f9bcbe8..bdde550 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,8 +6,8 @@ build-backend = "setuptools.build_meta" target-version = ["py38"] [[tool.mypy.overrides]] -module = [ - "gym.*", - "setuptools_scm.*", -] +module = ["gym.*", "setuptools_scm.*"] ignore_missing_imports = true + +[tool.ruff] +select = ["E", "F"] diff --git a/setup.py b/setup.py index 84ba62a..40af813 100644 --- a/setup.py +++ b/setup.py @@ -115,10 +115,6 @@ def get_readme() -> str: "pytest-xdist", "pytype", "stable-baselines3>=0.9.0", - # TODO(adam): remove pyglet pin once Gym upgraded to >0.21 - # Workaround for https://github.com/openai/gym/issues/2986 - # Discussed in https://github.com/HumanCompatibleAI/imitation/pull/603 - "pyglet==1.5.27", "setuptools_scm~=7.0.5", *ATARI_REQUIRE, ] @@ -140,7 +136,7 @@ def get_readme() -> str: packages=find_packages("src"), package_dir={"": "src"}, package_data={"seals": ["py.typed"]}, - install_requires=["gym", "numpy"], + install_requires=["gymnasium", "numpy"], tests_require=TESTS_REQUIRE, extras_require={ # recommended packages for development @@ -149,7 +145,7 @@ def get_readme() -> str: "test": TESTS_REQUIRE, # We'd like to specify `gym[mujoco]`, but this is a no-op when Gym is already # installed. See https://github.com/pypa/pip/issues/4957 for issue. - "mujoco": ["mujoco_py>=1.50, <2.0", "imageio"], + "mujoco": ["mujoco", "imageio"], "atari": ATARI_REQUIRE, }, url="https://github.com/HumanCompatibleAI/benchmark-environments", diff --git a/src/seals/__init__.py b/src/seals/__init__.py index 7f017e4..fc15b21 100644 --- a/src/seals/__init__.py +++ b/src/seals/__init__.py @@ -2,10 +2,10 @@ from importlib import metadata -import gym +import gymnasium as gym -from seals import atari, util import seals.diagnostics # noqa: F401 +from seals import atari, util try: __version__ = metadata.version("seals") @@ -38,5 +38,5 @@ # Atari -GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.envs.registry.all())) +GYM_ATARI_ENV_SPECS = list(filter(atari._supported_atari_env, gym.registry.values())) atari.register_atari_envs(GYM_ATARI_ENV_SPECS) diff --git a/src/seals/atari.py b/src/seals/atari.py index a2d896c..9da2cce 100644 --- a/src/seals/atari.py +++ b/src/seals/atari.py @@ -2,7 +2,8 @@ from typing import Dict, Iterable, Optional -import gym +import gymnasium as gym +from gymnasium.envs.registration import EnvSpec from seals.util import ( AutoResetWrapper, @@ -37,7 +38,7 @@ def _get_score_region(atari_env_id: str) -> Optional[MaskedRegionSpecifier]: def make_atari_env(atari_env_id: str, masked: bool) -> gym.Env: """Fixed-length, optionally masked-score variant of a given Atari environment.""" - env = AutoResetWrapper(gym.make(atari_env_id)) + env: gym.Env = AutoResetWrapper(gym.make(atari_env_id)) if masked: score_region = _get_score_region(atari_env_id) @@ -59,15 +60,15 @@ def _not_ram_or_det(env_id: str) -> bool: after_slash = slash_separated[-1] hyphen_separated = after_slash.split("-") assert len(hyphen_separated) > 1 - not_ram = not ("ram" in hyphen_separated[1]) - not_deterministic = not ("Deterministic" in env_id) + not_ram = "ram" not in hyphen_separated[1] + not_deterministic = "Deterministic" not in env_id return not_ram and not_deterministic -def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: +def _supported_atari_env(gym_spec: EnvSpec) -> bool: """Checks if a gym Atari environment is 
one of the ones we will support.""" is_atari = gym_spec.entry_point == "gym.envs.atari:AtariEnv" - v5_and_plain = gym_spec.id.endswith("-v5") and not ("NoFrameskip" in gym_spec.id) + v5_and_plain = gym_spec.id.endswith("-v5") and "NoFrameskip" not in gym_spec.id v4_and_no_frameskip = gym_spec.id.endswith("-v4") and "NoFrameskip" in gym_spec.id return ( is_atari @@ -76,7 +77,7 @@ def _supported_atari_env(gym_spec: gym.envs.registration.EnvSpec) -> bool: ) -def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str: +def _seals_name(gym_spec: EnvSpec, masked: bool) -> str: """Makes a Gym ID for an Atari environment in the seals namespace.""" slash_separated = gym_spec.id.split("/") name = "seals/" + slash_separated[-1] @@ -88,7 +89,7 @@ def _seals_name(gym_spec: gym.envs.registration.EnvSpec, masked: bool) -> str: def register_atari_envs( - gym_atari_env_specs: Iterable[gym.envs.registration.EnvSpec], + gym_atari_env_specs: Iterable[EnvSpec], ) -> None: """Register masked and unmasked wrapped gym Atari environments.""" diff --git a/src/seals/base_envs.py b/src/seals/base_envs.py index c7b29c6..b690f14 100644 --- a/src/seals/base_envs.py +++ b/src/seals/base_envs.py @@ -1,20 +1,23 @@ """Base environment classes.""" import abc -from typing import Generic, Optional, Sequence, Tuple, TypeVar +from typing import Any, Generic, Optional, Tuple, TypeVar -import gym -from gym import spaces +import gymnasium as gym import numpy as np +import numpy.typing as npt +from gymnasium import spaces from seals import util -State = TypeVar("State") -Observation = TypeVar("Observation") -Action = TypeVar("Action") +StateType = TypeVar("StateType") +ObsType = TypeVar("ObsType") +ActType = TypeVar("ActType") -class ResettablePOMDP(gym.Env, abc.ABC, Generic[State, Observation, Action]): +class ResettablePOMDP( + gym.Env[ObsType, ActType], abc.ABC, Generic[StateType, ObsType, ActType] +): """ABC for POMDPs that are resettable. Specifically, these environments provide oracle access to sample from @@ -23,69 +26,37 @@ class ResettablePOMDP(gym.Env, abc.ABC, Generic[State, Observation, Action]): meet these criteria. """ - _state_space: gym.Space - _observation_space: gym.Space - _action_space: gym.Space - _cur_state: Optional[State] + state_space: spaces.Space[StateType] + + _cur_state: Optional[StateType] _n_actions_taken: Optional[int] - def __init__( - self, - *, - state_space: gym.Space, - observation_space: gym.Space, - action_space: gym.Space, - ): - """Build resettable (PO)MDP. - - Args: - state_space: gym.Space containing possible states. - observation_space: gym.Space containing possible observations. - action_space: gym.Space containing possible actions. 
- """ - self._state_space = state_space - self._observation_space = observation_space - self._action_space = action_space + def __init__(self): + """Build resettable (PO)MDP.""" self._cur_state = None self._n_actions_taken = None - self.seed() @abc.abstractmethod - def initial_state(self) -> State: + def initial_state(self) -> StateType: """Samples from the initial state distribution.""" @abc.abstractmethod - def transition(self, state: State, action: Action) -> State: + def transition(self, state: StateType, action: ActType) -> StateType: """Samples from transition distribution.""" @abc.abstractmethod - def reward(self, state: State, action: Action, new_state: State) -> float: + def reward(self, state: StateType, action: ActType, new_state: StateType) -> float: """Computes reward for a given transition.""" @abc.abstractmethod - def terminal(self, state: State, step: int) -> bool: + def terminal(self, state: StateType, step: int) -> bool: """Is the state terminal?""" @abc.abstractmethod - def obs_from_state(self, state: State) -> Observation: + def obs_from_state(self, state: StateType) -> ObsType: """Sample observation for given state.""" - @property - def state_space(self) -> gym.Space: - """State space. Often same as observation_space, but differs in POMDPs.""" - return self._state_space - - @property - def observation_space(self) -> gym.Space: - """Observation space. Return type of reset() and component of step().""" - return self._observation_space - - @property - def action_space(self) -> gym.Space: - """Action space. Parameter type of step().""" - return self._action_space - @property def n_actions_taken(self) -> int: """Number of steps taken so far.""" @@ -93,34 +64,36 @@ def n_actions_taken(self) -> int: return self._n_actions_taken @property - def state(self) -> State: + def state(self) -> StateType: """Current state.""" assert self._cur_state is not None return self._cur_state @state.setter - def state(self, state: State): + def state(self, state: StateType): """Set current state.""" if state not in self.state_space: raise ValueError(f"{state} not in {self.state_space}") self._cur_state = state - def seed(self, seed=None) -> Sequence[int]: - """Set random seed.""" - if seed is None: - # Gym API wants list of seeds to be returned for some reason, so - # generate a seed explicitly in this case - seed = np.random.randint(0, 1 << 31) - self.rand_state = np.random.RandomState(seed) - return [seed] - - def reset(self) -> Observation: + def reset( + self, + *, + seed: int | None = None, + options: dict[str, Any] | None = None, + ) -> tuple[ObsType, dict[str, Any]]: # type: ignore """Reset episode and return initial observation.""" + if options is not None: + raise ValueError("Options not supported.") + + super().reset(seed=seed) self.state = self.initial_state() self._n_actions_taken = 0 - return self.obs_from_state(self.state) + obs = self.obs_from_state(self.state) + info: dict[str, Any] = dict() + return obs, info - def step(self, action: Action) -> Tuple[Observation, float, bool, dict]: + def step(self, action: ActType) -> Tuple[ObsType, float, bool, bool, dict]: """Transition state using given action.""" if self._cur_state is None or self._n_actions_taken is None: raise ValueError("Need to call reset() before first step()") @@ -133,16 +106,30 @@ def step(self, action: Action) -> Tuple[Observation, float, bool, dict]: assert obs in self.observation_space reward = self.reward(old_state, action, self.state) self._n_actions_taken += 1 - done = self.terminal(self.state, 
self.n_actions_taken) + terminated = self.terminal(self.state, self.n_actions_taken) + truncated = False infos = {"old_state": old_state, "new_state": self._cur_state} - return obs, reward, done, infos + return obs, reward, terminated, truncated, infos + + @property + def rand_state(self) -> np.random.Generator: + """Random state.""" + rand_state = self._np_random + if rand_state is None: + raise ValueError("Need to call reset() before accessing rand_state") + return rand_state -class ExposePOMDPStateWrapper(gym.Wrapper, Generic[State, Observation, Action]): +class ExposePOMDPStateWrapper( + gym.Wrapper[StateType, ActType, ObsType, ActType], + Generic[StateType, ObsType, ActType], +): """A wrapper that exposes the current state of the POMDP as the observation.""" - def __init__(self, env: ResettablePOMDP[State, Observation, Action]) -> None: + env: ResettablePOMDP[StateType, ObsType, ActType] + + def __init__(self, env: ResettablePOMDP[StateType, ObsType, ActType]) -> None: """Build wrapper. Args: @@ -151,51 +138,50 @@ def __init__(self, env: ResettablePOMDP[State, Observation, Action]) -> None: super().__init__(env) self._observation_space = env.state_space - def reset(self) -> State: + def reset( + self, seed: int | None = None, options: dict[str, Any] | None = None + ) -> Tuple[StateType, dict[str, Any]]: """Reset environment and return initial state.""" - self.env.reset() - return self.env.state + _, info = self.env.reset(seed=seed, options=options) + return self.env.state, info - def step(self, action) -> Tuple[State, float, bool, dict]: + def step(self, action) -> Tuple[StateType, float, bool, bool, dict]: """Transition state using given action.""" - obs, reward, done, info = self.env.step(action) - return self.env.state, reward, done, info + _, reward, terminated, truncated, info = self.env.step(action) + return self.env.state, reward, terminated, truncated, info class ResettableMDP( - ResettablePOMDP[State, State, Action], + ResettablePOMDP[StateType, StateType, ActType], abc.ABC, - Generic[State, Action], + Generic[StateType, ActType], ): """ABC for MDPs that are resettable.""" - def __init__( - self, - *, - state_space: gym.Space, - action_space: gym.Space, - ): - """Build resettable MDP. + @property + def observation_space(self) -> spaces.Space[StateType]: + """Observation space.""" + return self.state_space - Args: - state_space: gym.Space containing possible states. - action_space: gym.Space containing possible actions. - """ - super().__init__( - state_space=state_space, - observation_space=state_space, - action_space=action_space, - ) + @observation_space.setter + def observation_space(self, space: spaces.Space[StateType]): + """Set observation space.""" + self.state_space = space - def obs_from_state(self, state: State) -> State: + def obs_from_state(self, state: StateType) -> StateType: """Identity since observation == state in an MDP.""" return state +DiscreteSpaceInt = np.int64 + + # TODO(juan) this does not implement the .render() method, # so in theory it should not be instantiated directly. # Not sure why this is not raising an error? -class BaseTabularModelPOMDP(ResettablePOMDP[int, Observation, int]): +class BaseTabularModelPOMDP( + ResettablePOMDP[DiscreteSpaceInt, ObsType, DiscreteSpaceInt], Generic[ObsType] +): """Base class for tabular environments with known dynamics. 
This is the general class that also allows subclassing for creating @@ -236,6 +222,8 @@ def __init__( ValueError: `transition_matrix`, `reward_matrix` or `initial_state_dist` have shapes different to specified above. """ + super().__init__() + # The following matrices should conform to the shapes below: # transition matrix: n_states x n_actions x n_states @@ -278,43 +266,42 @@ def __init__( self.horizon = horizon self.initial_state_dist = initial_state_dist - super().__init__( - state_space=self._construct_state_space(), - action_space=self._construct_action_space(), - observation_space=self._construct_observation_space(), - ) - - def _construct_state_space(self) -> gym.Space: - return spaces.Discrete(self.state_dim) - - def _construct_action_space(self) -> gym.Space: - return spaces.Discrete(self.action_dim) - - @abc.abstractmethod - def _construct_observation_space(self) -> gym.Space: - pass # pragma: no cover + self.state_space = spaces.Discrete(self.state_dim) + self.action_space = spaces.Discrete(self.action_dim) - def initial_state(self) -> int: + def initial_state(self) -> DiscreteSpaceInt: """Samples from the initial state distribution.""" - return util.sample_distribution( - self.initial_state_dist, - random=self.rand_state, + return DiscreteSpaceInt( + util.sample_distribution( + self.initial_state_dist, + random=self.rand_state, + ) ) - def transition(self, state: int, action: int) -> int: + def transition( + self, state: DiscreteSpaceInt, action: DiscreteSpaceInt + ) -> DiscreteSpaceInt: """Samples from transition distribution.""" - return util.sample_distribution( - self.transition_matrix[state, action], - random=self.rand_state, + return DiscreteSpaceInt( + util.sample_distribution( + self.transition_matrix[state, action], + random=self.rand_state, + ) ) - def reward(self, state: int, action: int, new_state: int) -> float: + def reward( + self, + state: DiscreteSpaceInt, + action: DiscreteSpaceInt, + new_state: DiscreteSpaceInt, + ) -> float: """Computes reward for a given transition.""" inputs = (state, action, new_state)[: len(self.reward_matrix.shape)] return self.reward_matrix[inputs] - def terminal(self, state: int, n_actions_taken: int) -> bool: + def terminal(self, state: DiscreteSpaceInt, n_actions_taken: int) -> bool: """Checks if state is terminal.""" + del state return self.horizon is not None and n_actions_taken >= self.horizon @property @@ -323,7 +310,7 @@ def feature_matrix(self): # Construct lazily to save memory in algorithms that don't need features. if self._feature_matrix is None: n_states = self.state_space.n - self._feature_matrix = np.eye(n_states) + self._feature_matrix = np.eye(int(n_states)) return self._feature_matrix @property @@ -337,7 +324,12 @@ def action_dim(self) -> int: return self.transition_matrix.shape[1] -class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]): +ObsEntryType = TypeVar( + "ObsEntryType", bound=np.floating[Any] | np.integer[Any], covariant=True +) + + +class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray], Generic[ObsEntryType]): """Tabular model POMDP. This class is specifically for environments where observation != state, @@ -349,13 +341,14 @@ class TabularModelPOMDP(BaseTabularModelPOMDP[np.ndarray]): a vector with self.obs_dim entries. 
""" - observation_matrix: np.ndarray + observation_matrix: npt.NDArray[ObsEntryType] + observation_space: spaces.Box def __init__( self, *, transition_matrix: np.ndarray, - observation_matrix: np.ndarray, + observation_matrix: npt.NDArray[ObsEntryType], reward_matrix: np.ndarray, horizon: Optional[int] = None, initial_state_dist: Optional[np.ndarray] = None, @@ -377,7 +370,6 @@ def __init__( f"observation_matrix.shape[0]: {observation_matrix.shape[0]}", ) - def _construct_observation_space(self) -> gym.Space: min_val: float max_val: float try: @@ -386,14 +378,14 @@ def _construct_observation_space(self) -> gym.Space: except ValueError: min_val = -np.inf max_val = np.inf - return spaces.Box( + self.observation_space = spaces.Box( low=min_val, high=max_val, shape=(self.obs_dim,), dtype=self.obs_dtype, ) - def obs_from_state(self, state: int) -> np.ndarray: + def obs_from_state(self, state: DiscreteSpaceInt) -> npt.NDArray[ObsEntryType]: """Computes observation from state.""" # Copy so it can't be mutated in-place (updates will be reflected in # self.observation_matrix!) @@ -407,12 +399,12 @@ def obs_dim(self) -> int: return self.observation_matrix.shape[1] @property - def obs_dtype(self) -> int: + def obs_dtype(self) -> np.dtype[ObsEntryType]: """Data type of observation vectors (e.g. np.float32).""" return self.observation_matrix.dtype -class TabularModelMDP(BaseTabularModelPOMDP[int]): +class TabularModelMDP(BaseTabularModelPOMDP[DiscreteSpaceInt]): """Tabular model MDP. A tabular model MDP is a tabular MDP where the transition and reward @@ -444,9 +436,6 @@ def __init__( initial_state_dist=initial_state_dist, ) - def obs_from_state(self, state: int) -> int: + def obs_from_state(self, state: DiscreteSpaceInt) -> DiscreteSpaceInt: """Identity since observation == state in an MDP.""" return state - - def _construct_observation_space(self) -> gym.Space: - return self._construct_state_space() diff --git a/src/seals/classic_control.py b/src/seals/classic_control.py index 9854095..7b1d3c9 100644 --- a/src/seals/classic_control.py +++ b/src/seals/classic_control.py @@ -2,14 +2,15 @@ import warnings -from gym import spaces -import gym.envs.classic_control import numpy as np +from gymnasium import spaces +from gymnasium.envs import classic_control + from seals import util -class FixedHorizonCartPole(gym.envs.classic_control.CartPoleEnv): +class FixedHorizonCartPole(classic_control.CartPoleEnv): """Fixed-length variant of CartPole-v1. Reward is 1.0 whenever the CartPole is an "ok" state (i.e. 
the pole is upright @@ -20,7 +21,7 @@ class FixedHorizonCartPole(gym.envs.classic_control.CartPoleEnv): """ def __init__(self): - """Builds FixedHorizonCartPole, modifying observation_space from Gym parent.""" + """Builds FixedHorizonCartPole, modifying observation_space from gym parent.""" super().__init__() high = [ diff --git a/src/seals/diagnostics/__init__.py b/src/seals/diagnostics/__init__.py index 74be80f..b3c9b28 100644 --- a/src/seals/diagnostics/__init__.py +++ b/src/seals/diagnostics/__init__.py @@ -1,6 +1,6 @@ """Simple diagnostic environments.""" -import gym +import gymnasium as gym gym.register( id="seals/Branching-v0", diff --git a/src/seals/diagnostics/early_term.py b/src/seals/diagnostics/early_term.py index 8c672bb..87287bd 100644 --- a/src/seals/diagnostics/early_term.py +++ b/src/seals/diagnostics/early_term.py @@ -53,7 +53,7 @@ def __init__(self, is_reward_positive: bool = True): reward_matrix=reward_matrix, ) - def terminal(self, state: int, n_actions_taken: int) -> bool: + def terminal(self, state: base_envs.DiscreteSpaceInt, n_actions_taken: int) -> bool: """Returns True if (and only if) in state 2.""" return bool(state == 2) diff --git a/src/seals/diagnostics/init_shift.py b/src/seals/diagnostics/init_shift.py index 086f561..1e47477 100644 --- a/src/seals/diagnostics/init_shift.py +++ b/src/seals/diagnostics/init_shift.py @@ -27,7 +27,7 @@ class InitShiftEnv(base_envs.TabularModelMDP): disambiguate this case. """ - def __init__(self, initial_state: int): + def __init__(self, initial_state: base_envs.DiscreteSpaceInt): """Constructs environment. Args: @@ -63,7 +63,7 @@ def __init__(self, initial_state: int): reward_matrix=reward_matrix, ) - def initial_state(self) -> int: + def initial_state(self) -> base_envs.DiscreteSpaceInt: """Returns initial state defined in constructor.""" return self._initial_state diff --git a/src/seals/diagnostics/largest_sum.py b/src/seals/diagnostics/largest_sum.py index d55021f..688588c 100644 --- a/src/seals/diagnostics/largest_sum.py +++ b/src/seals/diagnostics/largest_sum.py @@ -1,7 +1,7 @@ """Environment testing scalability to high-dimensional tasks.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs @@ -23,12 +23,10 @@ def __init__(self, length: int = 50): Args: length: dimensionality of state space vector. 
""" + super().__init__() self._length = length - state_space = spaces.Box(low=0.0, high=1.0, shape=(length,)) - super().__init__( - state_space=state_space, - action_space=spaces.Discrete(2), - ) + self.state_space = spaces.Box(low=0.0, high=1.0, shape=(length,)) + self.action_space = spaces.Discrete(2) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: """Always returns True, since this task should have a 1-timestep horizon.""" @@ -36,7 +34,7 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Returns vector sampled uniformly in [0, 1]**L.""" - init_state = self.rand_state.rand(self._length) + init_state = self.rand_state.random((self._length,)) return init_state.astype(self.observation_space.dtype) def reward(self, state: np.ndarray, act: int, next_state: np.ndarray) -> float: diff --git a/src/seals/diagnostics/noisy_obs.py b/src/seals/diagnostics/noisy_obs.py index 8588416..a02f9a5 100644 --- a/src/seals/diagnostics/noisy_obs.py +++ b/src/seals/diagnostics/noisy_obs.py @@ -1,7 +1,7 @@ """Environment testing for robustness to noise.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs, util @@ -23,6 +23,8 @@ def __init__(self, *, size: int = 5, noise_length: int = 20): size: width and height of gridworld. noise_length: dimension of noise vector in observation. """ + super().__init__() + self._size = size self._noise_length = noise_length self._goal = np.array([self._size // 2, self._size // 2]) @@ -34,14 +36,12 @@ def __init__(self, *, size: int = 5, noise_length: int = 20): ([size - 1, size - 1], np.full(self._noise_length, np.inf)), # type: ignore ) - super().__init__( - state_space=spaces.MultiDiscrete([size, size]), - action_space=spaces.Discrete(5), - observation_space=spaces.Box( - low=obs_box_low, - high=obs_box_high, - dtype=np.float32, - ), + self.state_space = spaces.MultiDiscrete([size, size]) + self.action_space = spaces.Discrete(5) + self.observation_space = spaces.Box( + low=obs_box_low, + high=obs_box_high, + dtype=np.float32, ) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: @@ -52,7 +52,7 @@ def initial_state(self) -> np.ndarray: """Returns one of the grid's corners.""" n = self._size corners = np.array([[0, 0], [n - 1, 0], [0, n - 1], [n - 1, n - 1]]) - return corners[self.rand_state.randint(4)] + return corners[self.rand_state.integers(4)] def reward(self, state: np.ndarray, action: int, new_state: np.ndarray) -> float: """Returns +1.0 reward if state is the goal and 0.0 otherwise.""" diff --git a/src/seals/diagnostics/parabola.py b/src/seals/diagnostics/parabola.py index 07c9b8b..053eff1 100644 --- a/src/seals/diagnostics/parabola.py +++ b/src/seals/diagnostics/parabola.py @@ -1,7 +1,7 @@ """Environment testing for generalization in continuous spaces.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs @@ -24,16 +24,15 @@ def __init__(self, x_step: float = 0.05, bounds: float = 5): bounds: limits coordinates, useful for keeping rewards in a small bounded range. 
""" + super().__init__() self._x_step = x_step self._bounds = bounds state_high = np.array([bounds, bounds, 1.0, 1.0, 1.0]) state_low = (-1) * state_high - super().__init__( - state_space=spaces.Box(low=state_low, high=state_high), - action_space=spaces.Box(low=(-2) * bounds, high=2 * bounds, shape=()), - ) + self.state_space = spaces.Box(low=state_low, high=state_high) + self.action_space = spaces.Box(low=(-2) * bounds, high=2 * bounds, shape=()) def terminal(self, state: int, n_actions_taken: int) -> bool: """Always returns False.""" @@ -41,7 +40,7 @@ def terminal(self, state: int, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Get state by sampling a random parabola.""" - a, b, c = -1 + 2 * self.rand_state.rand(3) + a, b, c = -1 + 2 * self.rand_state.random((3,)) x, y = 0, c return np.array([x, y, a, b, c], dtype=self.state_space.dtype) diff --git a/src/seals/diagnostics/proc_goal.py b/src/seals/diagnostics/proc_goal.py index e534b72..706a939 100644 --- a/src/seals/diagnostics/proc_goal.py +++ b/src/seals/diagnostics/proc_goal.py @@ -1,7 +1,7 @@ """Large gridworld with random agent and goal position.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs, util @@ -28,13 +28,11 @@ def __init__(self, bounds: int = 100, distance: int = 10): generalization harder. distance: initial distance between agent and goal. """ + super().__init__() self._bounds = bounds self._distance = distance - - super().__init__( - state_space=spaces.Box(low=-np.inf, high=np.inf, shape=(4,)), - action_space=spaces.Discrete(5), - ) + self.state_space = spaces.Box(low=-np.inf, high=np.inf, shape=(4,)) + self.action_space = spaces.Discrete(5) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: """Always returns False.""" @@ -42,11 +40,11 @@ def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: def initial_state(self) -> np.ndarray: """Samples random agent position and random goal.""" - pos = self.rand_state.randint(low=-self._bounds, high=self._bounds, size=(2,)) + pos = self.rand_state.integers(low=-self._bounds, high=self._bounds, size=(2,)) - x_dist = self.rand_state.randint(self._distance) + x_dist = self.rand_state.integers(self._distance) y_dist = self._distance - x_dist - random_signs = 2 * self.rand_state.randint(2, size=2) - 1 + random_signs = 2 * self.rand_state.integers(2, size=2) - 1 goal = pos + random_signs * (x_dist, y_dist) return np.concatenate([pos, goal]).astype(self.observation_space.dtype) diff --git a/src/seals/diagnostics/random_trans.py b/src/seals/diagnostics/random_trans.py index 09eb6b4..13dae26 100644 --- a/src/seals/diagnostics/random_trans.py +++ b/src/seals/diagnostics/random_trans.py @@ -46,7 +46,7 @@ def __init__( """ # this generator is ONLY for constructing the MDP, not for controlling # random outcomes during rollouts - rand_gen = np.random.RandomState(generator_seed) + rand_gen = np.random.default_rng(generator_seed) if random_obs: if obs_dim is None: @@ -87,7 +87,7 @@ def make_random_trans_mat( n_states, n_actions, max_branch_factor, - rand_state: Optional[np.random.RandomState] = None, + rand_state: Optional[np.random.Generator] = None, ) -> np.ndarray: """Make a 'random' transition matrix. @@ -110,14 +110,14 @@ def make_random_trans_mat( of transitioning to `next_s` after taking action `a` in state `s`. 
""" if rand_state is None: - rand_state = np.random.RandomState() + rand_state = np.random.default_rng() assert rand_state is not None out_mat = np.zeros((n_states, n_actions, n_states), dtype="float32") for start_state in range(n_states): for action in range(n_actions): # uniformly sample a number of successors in [1,max_branch_factor] # for this action - successors = rand_state.randint(1, max_branch_factor + 1) + successors = rand_state.integers(1, max_branch_factor + 1) next_states = rand_state.choice( n_states, size=(successors,), @@ -133,7 +133,7 @@ def make_random_trans_mat( def make_random_state_dist( n_avail: int, n_states: int, - rand_state: Optional[np.random.RandomState] = None, + rand_state: Optional[np.random.Generator] = None, ) -> np.ndarray: """Make a random initial state distribution over n_states. @@ -152,7 +152,7 @@ def make_random_state_dist( ValueError: If `n_avail` is not in the range `(0, n_states]`. """ # noqa: DAR402 if rand_state is None: - rand_state = np.random.RandomState() + rand_state = np.random.default_rng() assert rand_state is not None assert 0 < n_avail <= n_states init_dist = np.zeros((n_states,)) @@ -168,7 +168,7 @@ def make_obs_mat( n_states: int, is_random: bool, obs_dim: Optional[int] = None, - rand_state: Optional[np.random.RandomState] = None, + rand_state: Optional[np.random.Generator] = None, ) -> np.ndarray: """Makes an observation matrix with a single observation for each state. @@ -179,7 +179,7 @@ def make_obs_mat( If `False`, are unique one-hot vectors for each state. obs_dim (int or NoneType): Must be `None` if `is_random == False`. Otherwise, this must be set to the size of the random vectors. - rand_state (np.random.RandomState): Random number generator. + rand_state (np.random.Generator): Random number generator. Returns: A matrix of shape `(n_states, obs_dim if is_random else n_states)`. @@ -188,7 +188,7 @@ def make_obs_mat( ValueError: If ``is_random == False`` and ``obs_dim is not None``. """ if rand_state is None: - rand_state = np.random.RandomState() + rand_state = np.random.default_rng() assert rand_state is not None if is_random: if obs_dim is None: diff --git a/src/seals/diagnostics/sort.py b/src/seals/diagnostics/sort.py index d1c4313..6b1a88e 100644 --- a/src/seals/diagnostics/sort.py +++ b/src/seals/diagnostics/sort.py @@ -1,7 +1,7 @@ """Environment to sort a list using swap actions.""" -from gym import spaces import numpy as np +from gymnasium import spaces from seals import base_envs @@ -21,10 +21,9 @@ def __init__(self, length: int = 4): """ self._length = length - super().__init__( - state_space=spaces.Box(low=0, high=1.0, shape=(length,)), - action_space=spaces.MultiDiscrete([length, length]), - ) + super().__init__() + self.state_space = spaces.Box(low=0, high=1.0, shape=(length,)) + self.action_space = spaces.MultiDiscrete([length, length]) def terminal(self, state: np.ndarray, n_actions_taken: int) -> bool: """Always returns False.""" @@ -42,6 +41,7 @@ def reward( new_state: np.ndarray, ) -> float: """Rewards fully sorted lists, and new correct positions.""" + del action # This is not meant to be a potential shaping in the formal sense, # as it changes the trajectory returns (since we do not return # a fixed-potential state at termination). 
diff --git a/src/seals/mujoco.py b/src/seals/mujoco.py
index 5cd2c5f..0e539ab 100644
--- a/src/seals/mujoco.py
+++ b/src/seals/mujoco.py
@@ -2,7 +2,7 @@
 
 import functools
 
-from gym.envs.mujoco import (
+from gymnasium.envs.mujoco import (
     ant_v3,
     half_cheetah_v3,
     hopper_v3,
diff --git a/src/seals/testing/envs.py b/src/seals/testing/envs.py
index a927a50..58bb7e9 100644
--- a/src/seals/testing/envs.py
+++ b/src/seals/testing/envs.py
@@ -17,12 +17,12 @@
     Tuple,
 )
 
-import gym
+import gymnasium as gym
 import numpy as np
 
-Step = Tuple[Any, Optional[float], bool, Mapping[str, Any]]
+Step = Tuple[Any, Optional[float], bool, bool, Mapping[str, Any]]
 Rollout = Sequence[Step]
-"""A sequence of 4-tuples (obs, rew, done, info) as returned by `get_rollout`."""
+"""A sequence of 5-tuples (obs, rew, terminated, truncated, info) as returned by `get_rollout`."""
 
 
 def make_env_fixture(
@@ -95,9 +95,9 @@ def get_rollout(env: gym.Env, actions: Iterable[Any]) -> Rollout:
         actions: the actions to perform.
 
     Returns:
-        A sequence of 4-tuples (obs, rew, done, info).
+        A sequence of 5-tuples (obs, rew, terminated, truncated, info).
     """
-    ret: List[Step] = [(env.reset(), None, False, {})]
+    ret: List[Step] = [(env.reset(), None, False, False, {})]
     for act in actions:
         ret.append(env.step(act))
     return ret
@@ -110,11 +110,12 @@ def assert_equal_rollout(rollout_a: Rollout, rollout_b: Rollout) -> None:
         AssertionError if they are not equal.
     """
     for step_a, step_b in zip(rollout_a, rollout_b):
-        ob_a, rew_a, done_a, info_a = step_a
-        ob_b, rew_b, done_b, info_b = step_b
+        ob_a, rew_a, terminated_a, truncated_a, info_a = step_a
+        ob_b, rew_b, terminated_b, truncated_b, info_b = step_b
         np.testing.assert_equal(ob_a, ob_b)
         assert rew_a == rew_b
-        assert done_a == done_b
+        assert terminated_a == terminated_b
+        assert truncated_a == truncated_b
         np.testing.assert_equal(info_a, info_b)
 
 
@@ -155,13 +156,13 @@ def test_seed(
     env.action_space.seed(0)
     actions = [env.action_space.sample() for _ in range(rollout_len)]
     # With the same seed, should always get the same result
-    seeds = env.seed(42)
+    seeds = env.reset(seed=42)
     # output of env.seed should be a list, but atari environments return a tuple.
     assert isinstance(seeds, (list, tuple))
     assert len(seeds) > 0
     rollout_a = get_rollout(env, actions)
 
-    env.seed(42)
+    env.reset(seed=42)
     rollout_b = get_rollout(env, actions)
     assert_equal_rollout(rollout_a, rollout_b)
 
@@ -171,9 +172,9 @@ def test_seed(
     # seeds should produce the same starting state.
def different_seeds_same_rollout(seed1, seed2): new_actions = [env.action_space.sample() for _ in range(rollout_len)] - env.seed(seed1) + env.reset(seed=seed1) new_rollout_1 = get_rollout(env, new_actions) - env.seed(seed2) + env.reset(seed=seed2) new_rollout_2 = get_rollout(env, new_actions) return has_same_observations(new_rollout_1, new_rollout_2) @@ -192,15 +193,16 @@ def _check_obs(obs: np.ndarray, obs_space: gym.Space) -> None: assert obs in obs_space -def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> bool: +def _sample_and_check(env: gym.Env, obs_space: gym.Space) -> Tuple[bool, bool]: """Sample from env and check return value is of valid type.""" act = env.action_space.sample() - obs, rew, done, info = env.step(act) + obs, rew, terminated, truncated, info = env.step(act) _check_obs(obs, obs_space) assert isinstance(rew, float) - assert isinstance(done, bool) + assert isinstance(terminated, bool) + assert isinstance(truncated, bool) assert isinstance(info, dict) - return done + return terminated, truncated def _is_mujoco_env(env: gym.Env) -> bool: @@ -228,16 +230,17 @@ def test_rollout_schema( AssertionError if test fails. """ obs_space = env.observation_space - obs = env.reset() + obs, _ = env.reset() _check_obs(obs, obs_space) + terminated = False for _ in range(max_steps): - done = _sample_and_check(env, obs_space) - if done: + terminated, _ = _sample_and_check(env, obs_space) + if terminated: break if check_episode_ends: - assert done, "did not get to end of episode" + assert terminated, "did not get to end of episode" for _ in range(steps_after_done): _sample_and_check(env, obs_space) @@ -352,5 +355,5 @@ def step(self, action): t, self.timestep = self.timestep, self.timestep + 1 obs = np.array(t, dtype=self.observation_space.dtype) rew = t * 10.0 - done = t == self.episode_length - return obs, rew, done, {} + terminated = t == self.episode_length + return obs, rew, terminated, False, {} diff --git a/src/seals/util.py b/src/seals/util.py index 60c0323..a7ffb1e 100644 --- a/src/seals/util.py +++ b/src/seals/util.py @@ -1,13 +1,17 @@ """Miscellaneous utilities.""" from dataclasses import dataclass -from typing import List, Optional, Sequence, Tuple, Union +from typing import Any, Generic, List, Optional, Sequence, SupportsFloat, Tuple, Union -import gym +import gymnasium as gym import numpy as np +import numpy.typing as npt +from gymnasium.core import ActType, ObsType, WrapperActType, WrapperObsType -class AutoResetWrapper(gym.Wrapper): +class AutoResetWrapper( + gym.Wrapper, Generic[WrapperObsType, WrapperActType, ObsType, ActType] +): """Hides done=True and auto-resets at the end of each episode. Depending on the flag 'discard_terminal_observation', either discards the terminal @@ -37,7 +41,9 @@ def __init__(self, env, discard_terminal_observation=True, reset_reward=0.0): self.reset_reward = reset_reward self.previous_done = False # Whether the previous step returned done=True. - def step(self, action): + def step( + self, action: WrapperActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """When done=True, returns done=False, then reset depending on flag. Depending on whether we are discarding the terminal observation, @@ -50,7 +56,9 @@ def step(self, action): else: return self._step_pad(action) - def _step_pad(self, action): + def _step_pad( + self, action: WrapperActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """When done=True, return done=False instead and return the terminal obs. 
The agent will then usually be asked to perform an action based on @@ -67,26 +75,31 @@ def _step_pad(self, action): """ if self.previous_done: self.previous_done = False + reset_obs, reset_info_dict = self.env.reset() + info = {"reset_info_dict": reset_info_dict} # This transition will only reset the environment, the action is ignored. - return self.env.reset(), self.reset_reward, False, {} + return reset_obs, self.reset_reward, False, False, info - obs, rew, done, info = self.env.step(action) - if done: + obs, rew, terminated, truncated, info = self.env.step(action) + if terminated: self.previous_done = True - return obs, rew, False, info + return obs, rew, False, truncated, info - def _step_discard(self, action): + def _step_discard( + self, action: WrapperActType + ) -> tuple[ObsType, SupportsFloat, bool, bool, dict[str, Any]]: """When done=True, returns done=False instead and automatically resets. When an automatic reset happens, the observation from reset is returned, and the overridden observation is stored in `info["terminal_observation"]`. """ - obs, rew, done, info = self.env.step(action) - if done: + obs, rew, terminated, truncated, info = self.env.step(action) + if terminated: info["terminal_observation"] = obs - obs = self.env.reset() - return obs, rew, False, info + obs, reset_info_dict = self.env.reset() + info["reset_info_dict"] = reset_info_dict + return obs, rew, False, truncated, info @dataclass @@ -100,7 +113,10 @@ class BoxRegion: MaskedRegionSpecifier = List[BoxRegion] -class MaskScoreWrapper(gym.Wrapper): +class MaskScoreWrapper( + gym.Wrapper[npt.NDArray, ActType, npt.NDArray, ActType], + Generic[ActType], +): """Mask a list of box-shaped regions in the observation to hide reward info. Intended for environments whose observations are raw pixels (like Atari @@ -130,19 +146,22 @@ def __init__( super().__init__(env) self.fill_value = np.array(fill_value, env.observation_space.dtype) + assert env.observation_space.shape is not None self.mask = np.ones(env.observation_space.shape, dtype=bool) for r in score_regions: if r.x[0] >= r.x[1] or r.y[0] >= r.y[1]: raise ValueError('Invalid region: "x" and "y" must be increasing.') self.mask[r.x[0] : r.x[1], r.y[0] : r.y[1]] = 0 - def _mask_obs(self, obs): + def _mask_obs(self, obs) -> npt.NDArray: return np.where(self.mask, obs, self.fill_value) - def step(self, action): - """Returns (obs, rew, done, info) with masked obs.""" - obs, rew, done, info = self.env.step(action) - return self._mask_obs(obs), rew, done, info + def step( + self, action: ActType + ) -> tuple[npt.NDArray, SupportsFloat, bool, bool, dict[str, Any]]: + """Returns (obs, rew, terminated, truncated, info) with masked obs.""" + obs, rew, terminated, truncated, info = self.env.step(action) + return self._mask_obs(obs), rew, terminated, truncated, info def reset(self, **kwargs): """Returns masked reset observation.""" @@ -174,9 +193,9 @@ def reset(self): return super().reset().astype(self.dtype) def step(self, action): - """Returns (obs, rew, done, info) with obs cast to self.dtype.""" - obs, rew, done, info = super().step(action) - return obs.astype(self.dtype), rew, done, info + """Returns (obs, rew, terminated, truncated, info) with obs cast to self.dtype.""" + obs, rew, terminated, truncated, info = super().step(action) + return obs.astype(self.dtype), rew, terminated, truncated, info class AbsorbAfterDoneWrapper(gym.Wrapper): @@ -228,8 +247,10 @@ def step(self, action): `rew` depend on initialization arguments. `info` is always an empty dictionary. 
""" if not self.at_absorb_state: - inner_obs, inner_rew, done, inner_info = self.env.step(action) - if done: + inner_obs, inner_rew, terminated, truncated, inner_info = self.env.step( + action + ) + if terminated: # Initialize the artificial absorb state, which we will repeatedly use # starting on the next call to `step()`. self.at_absorb_state = True @@ -245,26 +266,27 @@ def step(self, action): obs = self.absorb_obs_this_episode rew = self.absorb_reward info = {} + truncated = False - return obs, rew, False, info + return obs, rew, False, truncated, info def make_env_no_wrappers(env_name: str, **kwargs) -> gym.Env: - """Gym sometimes wraps envs in TimeLimit before returning from gym.make(). + """Gym sometimes wraps envs in TimeLimit before returning from gymnasium.make(). This helper method builds directly from spec to avoid this wrapper. """ - return gym.envs.registry.env_specs[env_name].make(**kwargs) + return gym.spec(env_name).make(**kwargs) def get_gym_max_episode_steps(env_name: str) -> Optional[int]: """Get the `max_episode_steps` attribute associated with a gym Spec.""" - return gym.envs.registry.env_specs[env_name].max_episode_steps + return gym.spec(env_name).max_episode_steps def sample_distribution( p: np.ndarray, - random: np.random.RandomState, + random: np.random.Generator, ) -> int: """Samples an integer with probabilities given by p.""" return random.choice(np.arange(len(p)), p=p) diff --git a/tests/test_base_env.py b/tests/test_base_env.py index 59a3e32..c91fe22 100644 --- a/tests/test_base_env.py +++ b/tests/test_base_env.py @@ -4,10 +4,9 @@ so the tests in this file focus on features unique to classes in `base_envs`. """ -import gym +import gymnasium as gym import numpy as np import pytest - from seals import base_envs from seals.testing import envs @@ -19,9 +18,9 @@ def __init__(self): """Build environment.""" nS = 3 nA = 2 - transition_matrix = np.random.rand(nS, nA, nS) + transition_matrix = np.random.random((nS, nA, nS)) transition_matrix /= transition_matrix.sum(axis=2)[:, :, None] - reward_matrix = np.random.rand(nS) + reward_matrix = np.random.random((nS,)) super().__init__( transition_matrix=transition_matrix, reward_matrix=reward_matrix, diff --git a/tests/test_envs.py b/tests/test_envs.py index b2e86d4..4f59762 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -2,11 +2,10 @@ from typing import List -import gym -from gym.envs import registration +import gymnasium as gym import pytest - import seals # noqa: F401 required for env registration +from gymnasium.envs import registration from seals.atari import SCORE_REGIONS, _get_score_region, _seals_name, make_atari_env from seals.testing import envs diff --git a/tests/test_mujoco_rl.py b/tests/test_mujoco_rl.py index b0e362c..55f66b6 100644 --- a/tests/test_mujoco_rl.py +++ b/tests/test_mujoco_rl.py @@ -2,7 +2,7 @@ from typing import Tuple -import gym +import gymnasium as gym import pytest import stable_baselines3 from stable_baselines3.common import evaluation diff --git a/tests/test_util.py b/tests/test_util.py index 5829ebb..9ed8ceb 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -2,10 +2,9 @@ import collections -import gym +import gymnasium as gym import numpy as np import pytest - from seals import GYM_ATARI_ENV_SPECS, util @@ -22,11 +21,11 @@ def test_mask_score_wrapper_enforces_spec(): def test_sample_distribution(): """Test util.sample_distribution.""" distr_size = 5 - distr = np.random.rand(distr_size) + distr = np.random.random((distr_size,)) distr /= distr.sum() n_samples = 1000 - 
rng = np.random.RandomState() + rng = np.random.default_rng() sample_count = collections.Counter( util.sample_distribution(distr, rng) for _ in range(n_samples) ) @@ -39,8 +38,8 @@ def test_sample_distribution(): # Same seed gives same samples assert all( - util.sample_distribution(distr, random=np.random.RandomState(seed)) - == util.sample_distribution(distr, random=np.random.RandomState(seed)) + util.sample_distribution(distr, random=np.random.default_rng(seed)) + == util.sample_distribution(distr, random=np.random.default_rng(seed)) for seed in range(20) ) diff --git a/tests/test_wrappers.py b/tests/test_wrappers.py index 06fac22..a9c1906 100644 --- a/tests/test_wrappers.py +++ b/tests/test_wrappers.py @@ -2,7 +2,6 @@ import numpy as np import pytest - from seals import util from seals.testing import envs @@ -31,10 +30,11 @@ def test_auto_reset_wrapper_pad(episode_length=3, n_steps=100, n_manual_reset=2) next_episode_end = episode_length for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, info = env.step(act) + obs, rew, terminated, truncated, info = env.step(act) # AutoResetWrapper overrides all done signals. - assert done is False + assert terminated is False + assert truncated is False if t == next_episode_end: # Unlike the AutoResetWrapper that discards terminal observations, @@ -80,11 +80,12 @@ def test_auto_reset_wrapper_discard(episode_length=3, n_steps=100, n_manual_rese for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, info = env.step(act) + obs, rew, terminated, truncated, info = env.step(act) expected_obs = t % episode_length assert obs == expected_obs - assert done is False + assert terminated is False + assert truncated is False if expected_obs == 0: # End of episode assert info.get("terminal_observation", None) == episode_length @@ -113,8 +114,9 @@ def test_absorb_repeat_custom_state( env.reset() for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, _ = env.step(act) - assert done is False + obs, rew, terminated, truncated, _ = env.step(act) + assert terminated is False + assert truncated is False if t > episode_length: expected_obs = absorb_obs expected_rew = absorb_reward @@ -134,8 +136,9 @@ def test_absorb_repeat_final_state(episode_length=6, n_steps=100, n_manual_reset env.reset() for t in range(1, n_steps + 1): act = env.action_space.sample() - obs, rew, done, _ = env.step(act) - assert done is False + obs, rew, terminated, truncated, _ = env.step(act) + assert terminated is False + assert truncated is False if t > episode_length: expected_obs = episode_length expected_rew = -1 @@ -161,8 +164,9 @@ def test_obs_cast(dtype: np.dtype, episode_length: int = 5): assert obs == 0 for t in range(1, episode_length + 1): act = env.action_space.sample() - obs, rew, done, _ = env.step(act) - assert done == (t == episode_length) + obs, rew, terminated, truncated, _ = env.step(act) + assert terminated == (t == episode_length) + assert truncated is False assert obs.dtype == dtype assert obs == t assert rew == t * 10.0