diff --git a/.github/ISSUE_TEMPLATE/custom_env.yml b/.github/ISSUE_TEMPLATE/custom_env.yml index cf624c03b..f90210858 100644 --- a/.github/ISSUE_TEMPLATE/custom_env.yml +++ b/.github/ISSUE_TEMPLATE/custom_env.yml @@ -49,15 +49,16 @@ body: self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(14,)) self.action_space = spaces.Box(low=-1, high=1, shape=(6,)) - def reset(self): - return self.observation_space.sample() + def reset(self, seed=None): + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = 1.0 - done = False + terminated = False + truncated = False info = {} - return obs, reward, done, info + return obs, reward, terminated, truncated, info env = CustomEnv() check_env(env) diff --git a/Dockerfile b/Dockerfile index 8dfbbbf4c..712a795d1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,41 +1,25 @@ ARG PARENT_IMAGE FROM $PARENT_IMAGE ARG PYTORCH_DEPS=cpuonly -ARG PYTHON_VERSION=3.7 +ARG PYTHON_VERSION=3.8 +ARG MAMBA_DOCKERFILE_ACTIVATE=1 # (otherwise python will not be found) -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - cmake \ - git \ - curl \ - ca-certificates \ - libjpeg-dev \ - libpng-dev \ - libglib2.0-0 && \ - rm -rf /var/lib/apt/lists/* - -# Install Anaconda and dependencies -RUN curl -o ~/miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh && \ - chmod +x ~/miniconda.sh && \ - ~/miniconda.sh -b -p /opt/conda && \ - rm ~/miniconda.sh && \ - /opt/conda/bin/conda install -y python=$PYTHON_VERSION numpy pyyaml scipy ipython mkl mkl-include && \ - /opt/conda/bin/conda install -y pytorch $PYTORCH_DEPS -c pytorch && \ - /opt/conda/bin/conda clean -ya -ENV PATH /opt/conda/bin:$PATH - -ENV CODE_DIR /root/code +ENV CODE_DIR /home/$MAMBA_USER # Copy setup file only to install dependencies -COPY ./setup.py ${CODE_DIR}/stable-baselines3/setup.py -COPY ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt +COPY --chown=$MAMBA_USER:$MAMBA_USER ./setup.py ${CODE_DIR}/stable-baselines3/setup.py +COPY --chown=$MAMBA_USER:$MAMBA_USER ./stable_baselines3/version.txt ${CODE_DIR}/stable-baselines3/stable_baselines3/version.txt + +# Install micromamba env and dependencies +RUN micromamba install -n base -y python=$PYTHON_VERSION \ + pytorch $PYTORCH_DEPS -c conda-forge -c pytorch -c nvidia && \ + micromamba clean --all --yes -RUN \ - cd ${CODE_DIR}/stable-baselines3 3&& \ +RUN cd ${CODE_DIR}/stable-baselines3 && \ pip install -e .[extra,tests,docs] && \ # Use headless version for docker pip uninstall -y opencv-python && \ pip install opencv-python-headless && \ - rm -rf $HOME/.cache/pip + pip cache purge CMD /bin/bash diff --git a/Makefile b/Makefile index 29ac5e70e..4f477d066 100644 --- a/Makefile +++ b/Makefile @@ -10,6 +10,12 @@ pytype: mypy: mypy ${LINT_PATHS} +missing-annotations: + mypy --disallow-untyped-calls --disallow-untyped-defs --ignore-missing-imports stable_baselines3 + +# missing docstrings +# pylint -d R,C,W,E -e C0116 stable_baselines3 -j 4 + type: pytype mypy lint: diff --git a/docs/conda_env.yml b/docs/conda_env.yml index 98a550820..7b89ba92b 100644 --- a/docs/conda_env.yml +++ b/docs/conda_env.yml @@ -4,11 +4,11 @@ channels: - defaults dependencies: - cpuonly=1.0=0 - - pip=21.1 + - pip=22.1.1 - python=3.7 - - pytorch=1.11=py3.7_cpu_0 + - pytorch=1.11.0=py3.7_cpu_0 - pip: - - gym==0.21 + - gym==0.26 - cloudpickle - opencv-python-headless - pandas diff --git a/docs/guide/checking_nan.rst 
b/docs/guide/checking_nan.rst index ef3762c41..7395fbd8b 100644 --- a/docs/guide/checking_nan.rst +++ b/docs/guide/checking_nan.rst @@ -129,7 +129,7 @@ It will monitor the actions, observations, and rewards, indicating what action o def reset(self): return [0.0] - def render(self, mode="human", close=False): + def render(self, close=False): pass # Create environment diff --git a/docs/guide/custom_env.rst b/docs/guide/custom_env.rst index d2878c376..2392bbb31 100644 --- a/docs/guide/custom_env.rst +++ b/docs/guide/custom_env.rst @@ -54,7 +54,7 @@ That is to say, your environment must implement the following methods (and inher ... return observation # reward, done, info can't be included - def render(self, mode="human"): + def render(self): ... def close(self): diff --git a/docs/guide/examples.rst b/docs/guide/examples.rst index a3f1dc6f0..6ed72d4cc 100644 --- a/docs/guide/examples.rst +++ b/docs/guide/examples.rst @@ -467,19 +467,19 @@ The parking env is a goal-conditioned continuous control task, in which the vehi # HER must be loaded with the env model = SAC.load("her_sac_highway", env=env) - obs = env.reset() + obs, info = env.reset() # Evaluate the agent episode_reward = 0 for _ in range(100): action, _ = model.predict(obs, deterministic=True) - obs, reward, done, info = env.step(action) + obs, reward, terminated, truncated, info = env.step(action) env.render() episode_reward += reward - if done or info.get("is_success", False): + if terminated or truncated or info.get("is_success", False): print("Reward:", episode_reward, "Success?", info.get("is_success", False)) episode_reward = 0.0 - obs = env.reset() + obs, info = env.reset() Learning Rate Schedule diff --git a/docs/guide/quickstart.rst b/docs/guide/quickstart.rst index 5d1055ac9..e809f1ba0 100644 --- a/docs/guide/quickstart.rst +++ b/docs/guide/quickstart.rst @@ -4,6 +4,12 @@ Getting Started =============== +.. note:: + + Stable-Baselines3 (SB3) uses :ref:`vectorized environments (VecEnv) ` internally. + Please read the associated section to learn more about its features and differences compared to a single Gym environment. + + Most of the library tries to follow a sklearn-like syntax for the Reinforcement Learning algorithms. Here is a quick example of how to train and run A2C on a CartPole environment: diff --git a/docs/guide/vec_envs.rst b/docs/guide/vec_envs.rst index d84781122..ea99444d1 100644 --- a/docs/guide/vec_envs.rst +++ b/docs/guide/vec_envs.rst @@ -44,6 +44,58 @@ SubprocVecEnv ✔️ ✔️ ✔️ ✔️ ✔️ For more information, see Python's `multiprocessing guidelines `_. +VecEnv API vs Gym API +--------------------- + +For consistency across Stable-Baselines3 (SB3) versions and because of its special requirements and features, +the SB3 VecEnv API is not the same as the Gym API. +The SB3 VecEnv API is actually close to the Gym 0.21 API but differs from the Gym 0.26+ API (a short usage sketch is given after this list): + +- the ``reset()`` method only returns the observation (``obs = vec_env.reset()``) and not a tuple, the infos at reset are stored in ``vec_env.reset_infos``. + +- only the initial call to ``vec_env.reset()`` is required, environments are reset automatically afterward (and ``reset_infos`` is updated automatically). + +- the ``vec_env.step(actions)`` method expects an array as input + (with a batch size corresponding to the number of environments) and returns a 4-tuple (and not a 5-tuple): ``obs, rewards, dones, infos`` instead of ``obs, reward, terminated, truncated, info`` + where ``dones = terminated or truncated`` (for each env).
+ ``obs, rewards, dones`` are numpy arrays with shape ``(n_envs, shape_for_single_env)`` (so with a batch dimension). + Additional information is passed via the ``infos`` value which is a list of dictionaries. + +- at the end of an episode, ``infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated`` + tells the user if an episode was truncated or not: + you should bootstrap if ``infos[env_idx]["TimeLimit.truncated"] is True`` (episode over due to a timeout/truncation) + or ``dones[env_idx] is False`` (episode not finished). + Note: compared to Gym 0.26+, ``infos[env_idx]["TimeLimit.truncated"]`` and ``terminated`` `are mutually exclusive `_. + The conversion from SB3 to Gym API is + + .. code-block:: python + + # done is True at the end of an episode + # dones[env_idx] = terminated[env_idx] or truncated[env_idx] + # In SB3, truncated and terminated are mutually exclusive + # infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated + # terminated[env_idx] tells you whether you should bootstrap or not: + # when the episode has not ended or when the termination was a timeout/truncation + terminated[env_idx] = dones[env_idx] and not infos[env_idx]["TimeLimit.truncated"] + should_bootstrap[env_idx] = not terminated[env_idx] + + +- at the end of an episode, because the environment resets automatically, + we provide ``infos[env_idx]["terminal_observation"]`` which contains the last observation + of an episode (and can be used when bootstrapping, see note in the previous section) + +- to overcome the current Gymnasium limitation (only one render mode allowed per env instance, see `issue #100 `_), + we recommend using ``render_mode="rgb_array"`` since we can both have the image as a numpy array and display it with OpenCV. + If no mode is passed or ``mode="rgb_array"`` is passed when calling ``vec_env.render`` then we use the default mode, otherwise, we use the OpenCV display. + Note that if ``render_mode != "rgb_array"``, you can only call ``vec_env.render()`` (without argument or with ``mode=env.render_mode``). + +- the ``reset()`` method doesn't take any parameters. If you want to seed the pseudo-random generator, + you should call ``vec_env.seed(seed=seed)`` and ``obs = vec_env.reset()`` afterward. + +- methods and attributes of the underlying Gym envs can be accessed, called and set using ``vec_env.get_attr("attribute_name")``, + ``vec_env.env_method("method_name", args1, args2, kwargs1=kwargs1)`` and ``vec_env.set_attr("attribute_name", new_value)``.
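As an illustration of the bullet points above, here is a minimal usage sketch (editor's illustration, not part of the original patch) that assumes the standard ``CartPole-v1`` Gym env and SB3's ``make_vec_env`` helper:

.. code-block:: python

    import numpy as np

    from stable_baselines3.common.env_util import make_vec_env

    vec_env = make_vec_env("CartPole-v1", n_envs=2)

    # reset() returns only the observations (batch dimension first),
    # the reset infos are stored in vec_env.reset_infos
    obs = vec_env.reset()

    for _ in range(10):
        # one action per environment
        actions = np.array([vec_env.action_space.sample() for _ in range(vec_env.num_envs)])
        # 4-tuple: dones[i] = terminated or truncated, sub-envs are reset automatically when done
        obs, rewards, dones, infos = vec_env.step(actions)

    # access attributes/methods of the underlying Gym envs
    max_episode_steps = vec_env.get_attr("spec")[0].max_episode_steps
    vec_env.close()

Calling ``vec_env.render()`` without arguments then uses the render mode chosen when the environments were created, as described above.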
+ + Vectorized Environments Wrappers -------------------------------- diff --git a/docs/misc/changelog.rst b/docs/misc/changelog.rst index 5c4dc426a..ada1c0e00 100644 --- a/docs/misc/changelog.rst +++ b/docs/misc/changelog.rst @@ -97,6 +97,7 @@ Breaking Changes: please use an ``EvalCallback`` instead - Removed deprecated ``sde_net_arch`` parameter - Removed ``ret`` attributes in ``VecNormalize``, please use ``returns`` instead +- Switched minimum Gym version to 0.26 (@carlosluis, @arjun-kg, @tlpss) - ``VecNormalize`` now updates the observation space when normalizing images New Features: @@ -280,6 +281,7 @@ Breaking Changes: New Features: ^^^^^^^^^^^^^ +- ``noop_max`` and ``frame_skip`` are now allowed to be equal to zero when using ``AtariWrapper`` `SB3-Contrib`_ ^^^^^^^^^^^^^^ @@ -305,6 +307,7 @@ Deprecations: Others: ^^^^^^^ - Upgraded to Python 3.7+ syntax using ``pyupgrade`` +- Updated docker base image to Ubuntu 20.04 and cuda 11.3 - Removed redundant double-check for nested observations from ``BaseAlgorithm._wrap_env`` (@TibiGG) Documentation: @@ -326,7 +329,7 @@ Release 1.5.0 (2022-03-25) Breaking Changes: ^^^^^^^^^^^^^^^^^ -- Switched minimum Gym version to 0.21.0. +- Switched minimum Gym version to 0.21.0 New Features: ^^^^^^^^^^^^^ @@ -1251,6 +1254,7 @@ And all the contributors: @eleurent @ac-93 @cove9988 @theDebugger811 @hsuehch @Demetrio92 @thomasgubler @IperGiove @ScheiklP @simoninithomas @armandpl @manuel-delverme @Gautam-J @gianlucadecola @buoyancy99 @caburu @xy9485 @Gregwar @ycheng517 @quantitative-technologies @bcollazo @git-thor @TibiGG @cool-RR @MWeltevrede -@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @yuanmingqi +@carlosluis @arjun-kg @tlpss +@Melanol @qgallouedec @francescoluciano @jlp-ue @burakdmb @timothe-chaumont @honglu2875 @anand-bala @hughperkins @sidney-tio @AlexPasqua @dominicgkerr @Akhilez @Rocamonde @tobirohrer @ZikangXiong @DavyMorgan @luizapozzobon @Bonifatius94 @theSquaredError @harveybellini @DavyMorgan @FieteO diff --git a/pyproject.toml b/pyproject.toml index 7941679ea..b13f568cb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,11 +83,8 @@ filterwarnings = [ # Tensorboard warnings "ignore::DeprecationWarning:tensorboard", # Gym warnings - "ignore:Parameters to load are deprecated.:DeprecationWarning", - "ignore:the imp module is deprecated in favour of importlib:PendingDeprecationWarning", "ignore::UserWarning:gym", - "ignore:SelectableGroups dict interface is deprecated.:DeprecationWarning", - "ignore:`np.bool` is a deprecated alias for the builtin `bool`:DeprecationWarning", + "ignore::DeprecationWarning:.*passive_env_checker.*", ] markers = [ "expensive: marks tests as expensive (deselect with '-m \"not expensive\"')" diff --git a/scripts/build_docker.sh b/scripts/build_docker.sh index 13ac86b17..c1a4a5608 100755 --- a/scripts/build_docker.sh +++ b/scripts/build_docker.sh @@ -1,14 +1,14 @@ #!/bin/bash -CPU_PARENT=ubuntu:18.04 -GPU_PARENT=nvidia/cuda:10.1-cudnn7-runtime-ubuntu18.04 +CPU_PARENT=mambaorg/micromamba:1.4-kinetic +GPU_PARENT=mambaorg/micromamba:1.4.1-focal-cuda-11.7.1 TAG=stablebaselines/stable-baselines3 VERSION=$(cat ./stable_baselines3/version.txt) if [[ ${USE_GPU} == "True" ]]; then PARENT=${GPU_PARENT} - PYTORCH_DEPS="cudatoolkit=10.1" + PYTORCH_DEPS="pytorch-cuda=11.7" else PARENT=${CPU_PARENT} PYTORCH_DEPS="cpuonly" diff --git a/scripts/run_docker_cpu.sh b/scripts/run_docker_cpu.sh index 6dfafd2b9..db6c6493b 100755 --- a/scripts/run_docker_cpu.sh +++ b/scripts/run_docker_cpu.sh @@ 
-7,5 +7,5 @@ echo "Executing in the docker (cpu image):" echo $cmd_line docker run -it --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \ - bash -c "cd /root/code/stable-baselines3/ && $cmd_line" + --mount src=$(pwd),target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3-cpu:latest \ + bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" diff --git a/scripts/run_docker_gpu.sh b/scripts/run_docker_gpu.sh index 19e16067a..fa8aae9c4 100755 --- a/scripts/run_docker_gpu.sh +++ b/scripts/run_docker_gpu.sh @@ -15,5 +15,5 @@ else fi docker run -it ${NVIDIA_ARG} --rm --network host --ipc=host \ - --mount src=$(pwd),target=/root/code/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \ - bash -c "cd /root/code/stable-baselines3/ && $cmd_line" + --mount src=$(pwd),target=/home/mamba/stable-baselines3,type=bind stablebaselines/stable-baselines3:latest \ + bash -c "cd /home/mamba/stable-baselines3/ && $cmd_line" diff --git a/setup.py b/setup.py index 7e0043320..4e1aea918 100644 --- a/setup.py +++ b/setup.py @@ -76,6 +76,9 @@ extra_no_roms = [ # For render "opencv-python", + 'pygame; python_version >= "3.8.0"', + # See https://github.com/pygame/pygame/issues/3572 + 'pygame>=2.0,<2.1.3; python_version < "3.8.0"', # Tensorboard support "tensorboard>=2.9.1", # Checking memory taken by replay buffer @@ -84,13 +87,13 @@ "tqdm", "rich", # For atari games, - "ale-py==0.7.4", + "ale-py==0.8.0", "pillow", ] extra_packages = extra_no_roms + [ # noqa: RUF005 # For atari roms, - "autorom[accept-rom-license]~=0.5.5", + "autorom[accept-rom-license]~=0.6.0", ] @@ -99,7 +102,7 @@ packages=[package for package in find_packages() if package.startswith("stable_baselines3")], package_data={"stable_baselines3": ["py.typed", "version.txt"]}, install_requires=[ - "gym==0.21", # Fixed version due to breaking changes in 0.22 + "gym==0.26.2", "numpy", "torch>=1.11", 'typing_extensions>=4.0,<5; python_version < "3.8.0"', @@ -128,8 +131,6 @@ "isort>=5.0", # Reformat "black", - # For toy text Gym envs - "scipy>=1.4.1", ], "docs": [ "sphinx", @@ -138,7 +139,7 @@ # For spelling "sphinxcontrib.spelling", # Type hints support - "sphinx-autodoc-typehints==1.21.1", # TODO: remove version constraint, see #1290 + "sphinx-autodoc-typehints", # Copy button for code snippets "sphinx_copybutton", ], diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py index 680e25453..0775a8ec5 100644 --- a/stable_baselines3/__init__.py +++ b/stable_baselines3/__init__.py @@ -1,7 +1,5 @@ import os -import numpy as np - from stable_baselines3.a2c import A2C from stable_baselines3.common.utils import get_system_info from stable_baselines3.ddpg import DDPG @@ -11,10 +9,6 @@ from stable_baselines3.sac import SAC from stable_baselines3.td3 import TD3 -# Small monkey patch so gym 0.21 is compatible with numpy >= 1.24 -# TODO: remove when upgrading to gym 0.26 -np.bool = bool # type: ignore[attr-defined] - # Read version from file version_file = os.path.join(os.path.dirname(__file__), "version.txt") with open(version_file) as file_handler: diff --git a/stable_baselines3/common/atari_wrappers.py b/stable_baselines3/common/atari_wrappers.py index ad29a3142..1264d27fa 100644 --- a/stable_baselines3/common/atari_wrappers.py +++ b/stable_baselines3/common/atari_wrappers.py @@ -1,3 +1,5 @@ +from typing import Dict, Tuple + import gym import numpy as np from gym import spaces @@ -9,7 +11,7 @@ except 
ImportError: cv2 = None -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn class StickyActionEnv(gym.Wrapper): @@ -26,13 +28,13 @@ class StickyActionEnv(gym.Wrapper): def __init__(self, env: gym.Env, action_repeat_probability: float) -> None: super().__init__(env) self.action_repeat_probability = action_repeat_probability - assert env.unwrapped.get_action_meanings()[0] == "NOOP" + assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self._sticky_action = 0 # NOOP return self.env.reset(**kwargs) - def step(self, action: int) -> GymStepReturn: + def step(self, action: int) -> Gym26StepReturn: if self.np_random.random() >= self.action_repeat_probability: self._sticky_action = action return self.env.step(self._sticky_action) @@ -52,21 +54,22 @@ def __init__(self, env: gym.Env, noop_max: int = 30) -> None: self.noop_max = noop_max self.override_num_noops = None self.noop_action = 0 - assert env.unwrapped.get_action_meanings()[0] == "NOOP" + assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined] - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) if self.override_num_noops is not None: noops = self.override_num_noops else: - noops = self.unwrapped.np_random.randint(1, self.noop_max + 1) + noops = self.unwrapped.np_random.integers(1, self.noop_max + 1) assert noops > 0 obs = np.zeros(0) + info: Dict = {} for _ in range(noops): - obs, _, done, _ = self.env.step(self.noop_action) - if done: - obs = self.env.reset(**kwargs) - return obs + obs, _, terminated, truncated, info = self.env.step(self.noop_action) + if terminated or truncated: + obs, info = self.env.reset(**kwargs) + return obs, info class FireResetEnv(gym.Wrapper): @@ -78,18 +81,18 @@ class FireResetEnv(gym.Wrapper): def __init__(self, env: gym.Env) -> None: super().__init__(env) - assert env.unwrapped.get_action_meanings()[1] == "FIRE" - assert len(env.unwrapped.get_action_meanings()) >= 3 + assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined] + assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined] - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(1) - if done: + obs, _, terminated, truncated, _ = self.env.step(1) + if terminated or truncated: self.env.reset(**kwargs) - obs, _, done, _ = self.env.step(2) - if done: + obs, _, terminated, truncated, _ = self.env.step(2) + if terminated or truncated: self.env.reset(**kwargs) - return obs + return obs, {} class EpisodicLifeEnv(gym.Wrapper): @@ -105,21 +108,21 @@ def __init__(self, env: gym.Env) -> None: self.lives = 0 self.was_real_done = True - def step(self, action: int) -> GymStepReturn: - obs, reward, done, info = self.env.step(action) - self.was_real_done = done + def step(self, action: int) -> Gym26StepReturn: + obs, reward, terminated, truncated, info = self.env.step(action) + self.was_real_done = terminated or truncated # check current lives, make loss of life terminal, # then update lives to handle bonus lives - lives = self.env.unwrapped.ale.lives() + lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] if 0 < lives < self.lives: # for Qbert sometimes we stay in lives == 0 condition for a 
few frames # so its important to keep lives > 0, so that we only reset once # the environment advertises done. - done = True + terminated = True self.lives = lives - return obs, reward, done, info + return obs, reward, terminated, truncated, info - def reset(self, **kwargs) -> np.ndarray: + def reset(self, **kwargs) -> Tuple[np.ndarray, Dict]: """ Calls the Gym environment reset, only when lives are exhausted. This way all states are still reachable even though lives are episodic, @@ -129,18 +132,18 @@ def reset(self, **kwargs) -> np.ndarray: :return: the first observation of the environment """ if self.was_real_done: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) else: # no-op step to advance from terminal/lost life state - obs, _, done, _ = self.env.step(0) + obs, _, terminated, truncated, info = self.env.step(0) # The no-op step can lead to a game over, so we need to check it again # to see if we should reset the environment and avoid the # monitor.py `RuntimeError: Tried to step environment that needs reset` - if done: - obs = self.env.reset(**kwargs) - self.lives = self.env.unwrapped.ale.lives() - return obs + if terminated or truncated: + obs, info = self.env.reset(**kwargs) + self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined] + return obs, info class MaxAndSkipEnv(gym.Wrapper): @@ -156,21 +159,24 @@ class MaxAndSkipEnv(gym.Wrapper): def __init__(self, env: gym.Env, skip: int = 4) -> None: super().__init__(env) # most recent raw observations (for max pooling across time steps) + assert env.observation_space.dtype is not None, "No dtype specified for the observation space" + assert env.observation_space.shape is not None, "No shape defined for the observation space" self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype) self._skip = skip - def step(self, action: int) -> GymStepReturn: + def step(self, action: int) -> Gym26StepReturn: """ Step the environment with the given action Repeat action, sum reward, and max over last observations. 
:param action: the action - :return: observation, reward, done, information + :return: observation, reward, terminated, truncated, information """ total_reward = 0.0 - done = False + terminated = truncated = False for i in range(self._skip): - obs, reward, done, info = self.env.step(action) + obs, reward, terminated, truncated, info = self.env.step(action) + done = terminated or truncated if i == self._skip - 2: self._obs_buffer[0] = obs if i == self._skip - 1: @@ -182,7 +188,7 @@ def step(self, action: int) -> GymStepReturn: # doesn't matter max_frame = self._obs_buffer.max(axis=0) - return max_frame, total_reward, done, info + return max_frame, total_reward, terminated, truncated, info class ClipRewardEnv(gym.RewardWrapper): @@ -219,8 +225,13 @@ def __init__(self, env: gym.Env, width: int = 84, height: int = 84) -> None: super().__init__(env) self.width = width self.height = height + assert isinstance(env.observation_space, spaces.Box), f"Expected Box space, got {env.observation_space}" + self.observation_space = spaces.Box( - low=0, high=255, shape=(self.height, self.width, 1), dtype=env.observation_space.dtype + low=0, + high=255, + shape=(self.height, self.width, 1), + dtype=env.observation_space.dtype, # type: ignore[arg-type] ) def observation(self, frame: np.ndarray) -> np.ndarray: @@ -285,7 +296,7 @@ def __init__( env = MaxAndSkipEnv(env, skip=frame_skip) if terminal_on_life_loss: env = EpisodicLifeEnv(env) - if "FIRE" in env.unwrapped.get_action_meanings(): + if "FIRE" in env.unwrapped.get_action_meanings(): # type: ignore[attr-defined] env = FireResetEnv(env) env = WarpFrame(env, width=screen_size, height=screen_size) if clip_reward: diff --git a/stable_baselines3/common/base_class.py b/stable_baselines3/common/base_class.py index 0ab03d8b0..e02868f26 100644 --- a/stable_baselines3/common/base_class.py +++ b/stable_baselines3/common/base_class.py @@ -53,7 +53,11 @@ def maybe_make_env(env: Union[GymEnv, str, None], verbose: int) -> Optional[GymE if isinstance(env, str): if verbose >= 1: print(f"Creating environment from the given name '{env}'") - env = gym.make(env) + # Set render_mode to `rgb_array` as default, so we can record video + try: + env = gym.make(env, render_mode="rgb_array") + except TypeError: + env = gym.make(env) return env diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py index b71454b1c..0cda5e508 100644 --- a/stable_baselines3/common/env_checker.py +++ b/stable_baselines3/common/env_checker.py @@ -60,9 +60,15 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act if isinstance(observation_space, spaces.Dict): nested_dict = False - for space in observation_space.spaces.values(): + for key, space in observation_space.spaces.items(): if isinstance(space, spaces.Dict): nested_dict = True + if isinstance(space, spaces.Discrete) and space.start != 0: + warnings.warn( + f"Discrete observation space (key '{key}') with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your observation space." + ) + if nested_dict: warnings.warn( "Nested observation spaces are not supported by Stable Baselines3 " @@ -81,6 +87,18 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act "which is supported by SB3." ) + if isinstance(observation_space, spaces.Discrete) and observation_space.start != 0: + warnings.warn( + "Discrete observation space with a non-zero start is not supported by Stable-Baselines3. 
" + "You can use a wrapper or update your observation space." + ) + + if isinstance(action_space, spaces.Discrete) and action_space.start != 0: + warnings.warn( + "Discrete action space with a non-zero start is not supported by Stable-Baselines3. " + "You can use a wrapper or update your action space." + ) + if not _is_numpy_array_space(action_space): warnings.warn( "The action space is not based off a numpy array. Typically this means it's either a Dict or Tuple space. " @@ -101,9 +119,8 @@ def _is_goal_env(env: gym.Env) -> bool: """ Check if the env uses the convention for goal-conditioned envs (previously, the gym.GoalEnv interface) """ - if isinstance(env, gym.Wrapper): # We need to unwrap the env since gym.Wrapper has the compute_reward method - return _is_goal_env(env.unwrapped) - return hasattr(env, "compute_reward") + # We need to unwrap the env since gym.Wrapper has the compute_reward method + return hasattr(env.unwrapped, "compute_reward") def _check_goal_env_obs(obs: dict, observation_space: spaces.Dict, method_name: str) -> None: @@ -131,7 +148,7 @@ def _check_goal_env_compute_reward( env: gym.Env, reward: float, info: Dict[str, Any], -): +) -> None: """ Check that reward is computed with `compute_reward` and that the implementation is vectorized. @@ -174,27 +191,32 @@ def _check_obs(obs: Union[tuple, dict, np.ndarray, int], observation_space: spac # check obs dimensions, dtype and bounds assert observation_space.shape == obs.shape, ( f"The observation returned by the `{method_name}()` method does not match the shape " - f"of the given observation space. Expected: {observation_space.shape}, actual shape: {obs.shape}" + f"of the given observation space {observation_space}. " + f"Expected: {observation_space.shape}, actual shape: {obs.shape}" ) - assert observation_space.dtype == obs.dtype, ( - f"The observation returned by the `{method_name}()` method does not match the data type " - f"of the given observation space. Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" + assert np.can_cast(obs.dtype, observation_space.dtype), ( + f"The observation returned by the `{method_name}()` method does not match the data type (cannot cast) " + f"of the given observation space {observation_space}. " + f"Expected: {observation_space.dtype}, actual dtype: {obs.dtype}" ) if isinstance(observation_space, spaces.Box): assert np.all(obs >= observation_space.low), ( f"The observation returned by the `{method_name}()` method does not match the lower bound " - f"of the given observation space. Expected: obs >= {np.min(observation_space.low)}, " + f"of the given observation space {observation_space}." + f"Expected: obs >= {np.min(observation_space.low)}, " f"actual min value: {np.min(obs)} at index {np.argmin(obs)}" ) assert np.all(obs <= observation_space.high), ( f"The observation returned by the `{method_name}()` method does not match the upper bound " - f"of the given observation space. Expected: obs <= {np.max(observation_space.high)}, " + f"of the given observation space {observation_space}. 
" + f"Expected: obs <= {np.max(observation_space.high)}, " f"actual max value: {np.max(obs)} at index {np.argmax(obs)}" ) - assert observation_space.contains( - obs - ), f"The observation returned by the `{method_name}()` method does not match the given observation space" + assert observation_space.contains(obs), ( + f"The observation returned by the `{method_name}()` method " + f"does not match the given observation space {observation_space}" + ) def _check_box_obs(observation_space: spaces.Box, key: str = "") -> None: @@ -222,7 +244,11 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action Check the returned values by the env when calling `.reset()` or `.step()` methods. """ # because env inherits from gym.Env, we assume that `reset()` and `step()` methods exists - obs = env.reset() + reset_returns = env.reset() + assert isinstance(reset_returns, tuple), "`reset()` must return a tuple (obs, info)" + assert len(reset_returns) == 2, f"`reset()` must return a tuple of size 2 (obs, info), not {len(reset_returns)}" + obs, info = reset_returns + assert isinstance(info, dict), "The second element of the tuple return by `reset()` must be a dictionary" if _is_goal_env(env): # Make mypy happy, already checked @@ -249,19 +275,21 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action action = action_space.sample() data = env.step(action) - assert len(data) == 4, "The `step()` method must return four values: obs, reward, done, info" + assert len(data) == 5, "The `step()` method must return four values: obs, reward, terminated, truncated, info" # Unpack - obs, reward, done, info = data + obs, reward, terminated, truncated, info = data - if _is_goal_env(env): - # Make mypy happy, already checked - assert isinstance(observation_space, spaces.Dict) - _check_goal_env_obs(obs, observation_space, "step") - _check_goal_env_compute_reward(obs, env, reward, info) - elif isinstance(observation_space, spaces.Dict): + if isinstance(observation_space, spaces.Dict): assert isinstance(obs, dict), "The observation returned by `step()` must be a dictionary" + # Additional checks for GoalEnvs + if _is_goal_env(env): + # Make mypy happy, already checked + assert isinstance(observation_space, spaces.Dict) + _check_goal_env_obs(obs, observation_space, "step") + _check_goal_env_compute_reward(obs, env, reward, info) + if not obs.keys() == observation_space.spaces.keys(): raise AssertionError( "The observation keys returned by `step()` must match the observation " @@ -279,11 +307,14 @@ def _check_returned_values(env: gym.Env, observation_space: spaces.Space, action # We also allow int because the reward will be cast to float assert isinstance(reward, (float, int)), "The reward returned by `step()` must be a float" - assert isinstance(done, bool), "The `done` signal must be a boolean" + assert isinstance(terminated, bool), "The `terminated` signal must be a boolean" + assert isinstance(truncated, bool), "The `truncated` signal must be a boolean" assert isinstance(info, dict), "The `info` returned by `step()` must be a python dictionary" # Goal conditioned env if _is_goal_env(env): + # for mypy, env.unwrapped was checked by _is_goal_env() + assert hasattr(env, "compute_reward") assert reward == env.compute_reward(obs["achieved_goal"], obs["desired_goal"], info) @@ -309,9 +340,9 @@ def _check_spaces(env: gym.Env) -> None: # Check render cannot be covered by CI -def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> None: # pragma: no cover +def 
_check_render(env: gym.Env, warn: bool = False) -> None: # pragma: no cover """ - Check the declared render modes and the `render()`/`close()` + Check the instantiated render mode (if any) by calling the `render()`/`close()` method of the environment. :param env: The environment to check @@ -319,24 +350,20 @@ def _check_render(env: gym.Env, warn: bool = True, headless: bool = False) -> No :param headless: Whether to disable render modes that require a graphical interface. False by default. """ - render_modes = env.metadata.get("render.modes") + render_modes = env.metadata.get("render_modes") if render_modes is None: if warn: warnings.warn( "No render modes was declared in the environment " - " (env.metadata['render.modes'] is None or not defined), " + "(env.metadata['render_modes'] is None or not defined), " "you may have trouble when calling `.render()`" ) - else: - # Don't check render mode that require a - # graphical interface (useful for CI) - if headless and "human" in render_modes: - render_modes.remove("human") - # Check all declared render modes - for render_mode in render_modes: - env.render(mode=render_mode) - env.close() + # TODO: if we want to check all declared render modes, + # we need to initialize new environments so the class should be passed as argument. + if env.render_mode: + env.render() + env.close() def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) -> None: @@ -401,7 +428,7 @@ def check_env(env: gym.Env, warn: bool = True, skip_render_check: bool = True) - # ==== Check the render method and the declared render modes ==== if not skip_render_check: - _check_render(env, warn=warn) # pragma: no cover + _check_render(env, warn) # pragma: no cover try: check_for_nested_spaces(env.observation_space) diff --git a/stable_baselines3/common/env_util.py b/stable_baselines3/common/env_util.py index c85d1472b..cf1024649 100644 --- a/stable_baselines3/common/env_util.py +++ b/stable_baselines3/common/env_util.py @@ -5,6 +5,7 @@ from stable_baselines3.common.atari_wrappers import AtariWrapper from stable_baselines3.common.monitor import Monitor +from stable_baselines3.common.utils import compat_gym_seed from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecEnv @@ -24,7 +25,7 @@ def unwrap_wrapper(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> Optional[g return None -def is_wrapped(env: Type[gym.Env], wrapper_class: Type[gym.Wrapper]) -> bool: +def is_wrapped(env: gym.Env, wrapper_class: Type[gym.Wrapper]) -> bool: """ Check if a given environment has been wrapped with a given wrapper. @@ -72,25 +73,37 @@ def make_vec_env( :param wrapper_kwargs: Keyword arguments to pass to the ``Wrapper`` class constructor. 
:return: The wrapped environment """ - env_kwargs = {} if env_kwargs is None else env_kwargs - vec_env_kwargs = {} if vec_env_kwargs is None else vec_env_kwargs - monitor_kwargs = {} if monitor_kwargs is None else monitor_kwargs - wrapper_kwargs = {} if wrapper_kwargs is None else wrapper_kwargs + env_kwargs = env_kwargs or {} + vec_env_kwargs = vec_env_kwargs or {} + monitor_kwargs = monitor_kwargs or {} + wrapper_kwargs = wrapper_kwargs or {} + assert vec_env_kwargs is not None # for mypy + + def make_env(rank: int) -> Callable[[], gym.Env]: + def _init() -> gym.Env: + # For type checker: + assert monitor_kwargs is not None + assert wrapper_kwargs is not None + assert env_kwargs is not None - def make_env(rank): - def _init(): if isinstance(env_id, str): - env = gym.make(env_id, **env_kwargs) + # if the render mode was not specified, we set it to `rgb_array` as default. + kwargs = {"render_mode": "rgb_array"} + kwargs.update(env_kwargs) + try: + env = gym.make(env_id, **kwargs) + except TypeError: + env = gym.make(env_id, **env_kwargs) else: env = env_id(**env_kwargs) if seed is not None: - env.seed(seed + rank) + compat_gym_seed(env, seed=seed + rank) env.action_space.seed(seed + rank) # Wrap the env in a Monitor wrapper # to have additional training information monitor_path = os.path.join(monitor_dir, str(rank)) if monitor_dir is not None else None # Create the monitor folder if needed - if monitor_path is not None: + if monitor_path is not None and monitor_dir is not None: os.makedirs(monitor_dir, exist_ok=True) env = Monitor(env, filename=monitor_path, **monitor_kwargs) # Optionally, wrap the environment with the provided wrapper diff --git a/stable_baselines3/common/envs/bit_flipping_env.py b/stable_baselines3/common/envs/bit_flipping_env.py index d6724c9cc..090985dcc 100644 --- a/stable_baselines3/common/envs/bit_flipping_env.py +++ b/stable_baselines3/common/envs/bit_flipping_env.py @@ -1,11 +1,11 @@ from collections import OrderedDict -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Tuple, Union import numpy as np from gym import Env, spaces from gym.envs.registration import EnvSpec -from stable_baselines3.common.type_aliases import GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn class BitFlippingEnv(Env): @@ -25,7 +25,7 @@ class BitFlippingEnv(Env): :param channel_first: Whether to use channel-first or last image. """ - spec = EnvSpec("BitFlippingEnv-v0") + spec = EnvSpec("BitFlippingEnv-v0", "no-entry-point") def __init__( self, @@ -96,7 +96,7 @@ def __init__( self.discrete_obs_space = discrete_obs_space self.image_obs_space = image_obs_space self.state = None - self.desired_goal = np.ones((n_bits,)) + self.desired_goal = np.ones((n_bits,), dtype=self.observation_space["desired_goal"].dtype) if max_steps is None: max_steps = n_bits self.max_steps = max_steps @@ -157,24 +157,34 @@ def _get_obs(self) -> Dict[str, Union[int, np.ndarray]]: ] ) - def reset(self) -> Dict[str, Union[int, np.ndarray]]: + def reset( + self, *, seed: Optional[int] = None, options: Optional[Dict] = None + ) -> Tuple[Dict[str, Union[int, np.ndarray]], Dict]: + if seed is not None: + self.obs_space.seed(seed) self.current_step = 0 self.state = self.obs_space.sample() - return self._get_obs() + return self._get_obs(), {} - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: + """ + Step into the env. 
+ + :param action: + :return: + """ if self.continuous: self.state[action > 0] = 1 - self.state[action > 0] else: self.state[action] = 1 - self.state[action] obs = self._get_obs() reward = float(self.compute_reward(obs["achieved_goal"], obs["desired_goal"], None)) - done = reward == 0 + terminated = reward == 0 self.current_step += 1 # Episode terminate when we reached the goal or the max number of steps - info = {"is_success": done} - done = done or self.current_step >= self.max_steps - return obs, reward, done, info + info = {"is_success": terminated} + truncated = self.current_step >= self.max_steps + return obs, reward, terminated, truncated, info def compute_reward( self, achieved_goal: Union[int, np.ndarray], desired_goal: Union[int, np.ndarray], _info: Optional[Dict[str, Any]] diff --git a/stable_baselines3/common/envs/identity_env.py b/stable_baselines3/common/envs/identity_env.py index a8bed175a..90e1fdb14 100644 --- a/stable_baselines3/common/envs/identity_env.py +++ b/stable_baselines3/common/envs/identity_env.py @@ -4,7 +4,7 @@ import numpy as np from gym import spaces -from stable_baselines3.common.type_aliases import GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn T = TypeVar("T", int, np.ndarray) @@ -34,18 +34,21 @@ def __init__(self, dim: Optional[int] = None, space: Optional[spaces.Space] = No self.num_resets = -1 # Becomes 0 after __init__ exits. self.reset() - def reset(self) -> T: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[T, Dict]: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 self.num_resets += 1 self._choose_next_state() - return self.state + return self.state, {} - def step(self, action: T) -> Tuple[T, float, bool, Dict[str, Any]]: + def step(self, action: T) -> Tuple[T, float, bool, bool, Dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + terminated = False + truncated = self.current_step >= self.ep_length + return self.state, reward, terminated, truncated, {} def _choose_next_state(self) -> None: self.state = self.action_space.sample() @@ -71,12 +74,13 @@ def __init__(self, low: float = -1.0, high: float = 1.0, eps: float = 0.05, ep_l super().__init__(ep_length=ep_length, space=space) self.eps = eps - def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]: + def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict[str, Any]]: reward = self._get_reward(action) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + terminated = False + truncated = self.current_step >= self.ep_length + return self.state, reward, terminated, truncated, {} def _get_reward(self, action: np.ndarray) -> float: return 1.0 if (self.state - self.eps) <= action <= (self.state + self.eps) else 0.0 @@ -138,15 +142,18 @@ def __init__( self.ep_length = 10 self.current_step = 0 - def reset(self) -> np.ndarray: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[np.ndarray, Dict]: + if seed is not None: + super().reset(seed=seed) self.current_step = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: reward = 0.0 
self.current_step += 1 - done = self.current_step >= self.ep_length - return self.observation_space.sample(), reward, done, {} + terminated = False + truncated = self.current_step >= self.ep_length + return self.observation_space.sample(), reward, terminated, truncated, {} def render(self, mode: str = "human") -> None: pass diff --git a/stable_baselines3/common/envs/multi_input_envs.py b/stable_baselines3/common/envs/multi_input_envs.py index 166c6991a..8fc9ac04f 100644 --- a/stable_baselines3/common/envs/multi_input_envs.py +++ b/stable_baselines3/common/envs/multi_input_envs.py @@ -1,10 +1,10 @@ -from typing import Dict, Union +from typing import Dict, Optional, Tuple, Union import gym import numpy as np from gym import spaces -from stable_baselines3.common.type_aliases import GymStepReturn +from stable_baselines3.common.type_aliases import Gym26StepReturn class SimpleMultiObsEnv(gym.Env): @@ -121,7 +121,7 @@ def init_possible_transitions(self) -> None: self.right_possible = [0, 1, 2, 12, 13, 14] self.up_possible = [4, 8, 12, 7, 11, 15] - def step(self, action: Union[float, np.ndarray]) -> GymStepReturn: + def step(self, action: Union[float, np.ndarray]) -> Gym26StepReturn: """ Run one timestep of the environment's dynamics. When end of episode is reached, you are responsible for calling `reset()` @@ -153,11 +153,12 @@ def step(self, action: Union[float, np.ndarray]) -> GymStepReturn: got_to_end = self.state == self.max_state reward = 1 if got_to_end else reward - done = self.count > self.max_count or got_to_end + truncated = self.count > self.max_count + terminated = got_to_end self.log = f"Went {self.action2str[action]} in state {prev_state}, got to state {self.state}" - return self.get_state_mapping(), reward, done, {"got_to_end": got_to_end} + return self.get_state_mapping(), reward, terminated, truncated, {"got_to_end": got_to_end} def render(self, mode: str = "human") -> None: """ @@ -167,15 +168,18 @@ def render(self, mode: str = "human") -> None: """ print(self.log) - def reset(self) -> Dict[str, np.ndarray]: + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None) -> Tuple[Dict[str, np.ndarray], Dict]: """ Resets the environment state and step count and returns reset observation. 
+ :param seed: :return: observation dict {'vec': ..., 'img': ...} """ + if seed is not None: + super().reset(seed=seed) self.count = 0 if not self.random_start: self.state = 0 else: self.state = np.random.randint(0, self.max_state) - return self.state_mapping[self.state] + return self.state_mapping[self.state], {} diff --git a/stable_baselines3/common/evaluation.py b/stable_baselines3/common/evaluation.py index b65edf840..593b407d8 100644 --- a/stable_baselines3/common/evaluation.py +++ b/stable_baselines3/common/evaluation.py @@ -59,7 +59,7 @@ def evaluate_policy( from stable_baselines3.common.monitor import Monitor if not isinstance(env, VecEnv): - env = DummyVecEnv([lambda: env]) + env = DummyVecEnv([lambda: env]) # type: ignore[list-item, return-value] is_monitor_wrapped = is_vecenv_wrapped(env, VecMonitor) or env.env_is_wrapped(Monitor)[0] @@ -85,7 +85,12 @@ def evaluate_policy( states = None episode_starts = np.ones((env.num_envs,), dtype=bool) while (episode_counts < episode_count_targets).any(): - actions, states = model.predict(observations, state=states, episode_start=episode_starts, deterministic=deterministic) + actions, states = model.predict( + observations, # type: ignore[arg-type] + state=states, + episode_start=episode_starts, + deterministic=deterministic, + ) observations, rewards, dones, infos = env.step(actions) current_rewards += rewards current_lengths += 1 diff --git a/stable_baselines3/common/monitor.py b/stable_baselines3/common/monitor.py index b8ebc2bac..f478441eb 100644 --- a/stable_baselines3/common/monitor.py +++ b/stable_baselines3/common/monitor.py @@ -11,7 +11,7 @@ import numpy as np import pandas -from stable_baselines3.common.type_aliases import GymObs, GymStepReturn +from stable_baselines3.common.type_aliases import Gym26ResetReturn, Gym26StepReturn class Monitor(gym.Wrapper): @@ -43,9 +43,10 @@ def __init__( self.t_start = time.time() self.results_writer = None if filename is not None: + env_id = env.spec.id if env.spec is not None else None self.results_writer = ResultsWriter( filename, - header={"t_start": self.t_start, "env_id": env.spec and env.spec.id}, + header={"t_start": self.t_start, "env_id": env_id}, extra_keys=reset_keywords + info_keywords, override_existing=override_existing, ) @@ -62,7 +63,7 @@ def __init__( # extra info about the current episode, that was passed in during reset() self.current_reset_info: Dict[str, Any] = {} - def reset(self, **kwargs) -> GymObs: + def reset(self, **kwargs) -> Gym26ResetReturn: """ Calls the Gym environment reset. 
Can only be called if the environment is over, or if allow_early_resets is True @@ -83,18 +84,18 @@ def reset(self, **kwargs) -> GymObs: self.current_reset_info[key] = value return self.env.reset(**kwargs) - def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: + def step(self, action: Union[np.ndarray, int]) -> Gym26StepReturn: """ Step the environment with the given action :param action: the action - :return: observation, reward, done, information + :return: observation, reward, terminated, truncated, information """ if self.needs_reset: raise RuntimeError("Tried to step environment that needs reset") - observation, reward, done, info = self.env.step(action) + observation, reward, terminated, truncated, info = self.env.step(action) self.rewards.append(reward) - if done: + if terminated or truncated: self.needs_reset = True ep_rew = sum(self.rewards) ep_len = len(self.rewards) @@ -109,7 +110,7 @@ def step(self, action: Union[np.ndarray, int]) -> GymStepReturn: self.results_writer.write_row(ep_info) info["episode"] = ep_info self.total_steps += 1 - return observation, reward, done, info + return observation, reward, terminated, truncated, info def close(self) -> None: """ diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py index e280ed731..79f3fbeec 100644 --- a/stable_baselines3/common/preprocessing.py +++ b/stable_baselines3/common/preprocessing.py @@ -158,10 +158,7 @@ def get_obs_shape( return (int(len(observation_space.nvec)),) elif isinstance(observation_space, spaces.MultiBinary): # Number of binary features - if type(observation_space.n) in [tuple, list, np.ndarray]: - return tuple(observation_space.n) - else: - return (int(observation_space.n),) + return observation_space.shape elif isinstance(observation_space, spaces.Dict): return {key: get_obs_shape(subspace) for (key, subspace) in observation_space.spaces.items()} # type: ignore[misc] @@ -205,18 +202,20 @@ def get_action_dim(action_space: spaces.Space) -> int: return int(len(action_space.nvec)) elif isinstance(action_space, spaces.MultiBinary): # Number of binary actions + assert isinstance(action_space.n, int), ( + "Multi-dimensional MultiBinary action space is not supported. " "You can flatten it instead." + ) return int(action_space.n) else: raise NotImplementedError(f"{action_space} action space is not supported") -def check_for_nested_spaces(obs_space: spaces.Space): +def check_for_nested_spaces(obs_space: spaces.Space) -> None: """ Make sure the observation space does not have nested spaces (Dicts/Tuples inside Dicts/Tuples). If so, raise an Exception informing that there is no support for this. 
:param obs_space: an observation space - :return: """ if isinstance(obs_space, (spaces.Dict, spaces.Tuple)): sub_spaces = obs_space.spaces.values() if isinstance(obs_space, spaces.Dict) else obs_space.spaces diff --git a/stable_baselines3/common/torch_layers.py b/stable_baselines3/common/torch_layers.py index 44714d6fe..9e7277467 100644 --- a/stable_baselines3/common/torch_layers.py +++ b/stable_baselines3/common/torch_layers.py @@ -63,10 +63,14 @@ class NatureCNN(BaseFeaturesExtractor): def __init__( self, - observation_space: spaces.Box, + observation_space: gym.Space, features_dim: int = 512, normalized_image: bool = False, ) -> None: + assert isinstance(observation_space, spaces.Box), ( + "NatureCNN must be used with a gym.spaces.Box ", + f"observation space, not {observation_space}", + ) super().__init__(observation_space, features_dim) # We assume CxHxW images (channels first) # Re-ordering will be done by pre-preprocessing or wrapper diff --git a/stable_baselines3/common/type_aliases.py b/stable_baselines3/common/type_aliases.py index 7227667a1..037c0e58c 100644 --- a/stable_baselines3/common/type_aliases.py +++ b/stable_baselines3/common/type_aliases.py @@ -17,7 +17,9 @@ GymEnv = Union[gym.Env, vec_env.VecEnv] GymObs = Union[Tuple, Dict[str, Any], np.ndarray, int] +Gym26ResetReturn = Tuple[GymObs, Dict] GymStepReturn = Tuple[GymObs, float, bool, Dict] +Gym26StepReturn = Tuple[GymObs, float, bool, bool, Dict] TensorDict = Dict[Union[str, int], th.Tensor] OptimizerStateDict = Dict[str, Any] MaybeCallback = Union[None, Callable, List[callbacks.BaseCallback], callbacks.BaseCallback] diff --git a/stable_baselines3/common/utils.py b/stable_baselines3/common/utils.py index 1234fba79..3d6eea5ca 100644 --- a/stable_baselines3/common/utils.py +++ b/stable_baselines3/common/utils.py @@ -4,6 +4,7 @@ import random import re from collections import deque +from inspect import signature from itertools import zip_longest from typing import Dict, Iterable, List, Optional, Tuple, Union @@ -541,3 +542,18 @@ def get_system_info(print_info: bool = True) -> Tuple[Dict[str, str], str]: if print_info: print(env_info_str) return env_info, env_info_str + + +def compat_gym_seed(env: GymEnv, seed: int) -> None: + """ + Compatibility helper to seed Gym envs. + + :param env: The Gym environment. 
+    :param seed: The seed for the pseudo random generator
+    """
+    if isinstance(env, gym.Env) and "seed" in signature(env.unwrapped.reset).parameters:
+        # gym >= 0.23.1
+        env.reset(seed=seed)
+    else:
+        # VecEnv and backward compatibility
+        env.seed(seed)
diff --git a/stable_baselines3/common/vec_env/base_vec_env.py b/stable_baselines3/common/vec_env/base_vec_env.py
index 0b3e1b40e..572a7a132 100644
--- a/stable_baselines3/common/vec_env/base_vec_env.py
+++ b/stable_baselines3/common/vec_env/base_vec_env.py
@@ -54,12 +54,20 @@ class VecEnv(ABC):
     :param action_space: Action space
     """
-    metadata = {"render.modes": ["human", "rgb_array"]}
+    metadata = {"render_modes": ["human", "rgb_array"]}
-    def __init__(self, num_envs: int, observation_space: spaces.Space, action_space: spaces.Space):
+    def __init__(
+        self,
+        num_envs: int,
+        observation_space: spaces.Space,
+        action_space: spaces.Space,
+        render_mode: Optional[str] = None,
+    ):
         self.num_envs = num_envs
         self.observation_space = observation_space
         self.action_space = action_space
+        self.render_mode = render_mode
+        self.reset_infos = [{} for _ in range(num_envs)]  # store info returned by the reset method
     @abstractmethod
     def reset(self) -> VecEnvObs:
@@ -162,35 +170,72 @@ def step(self, actions: np.ndarray) -> VecEnvStepReturn:
         self.step_async(actions)
         return self.step_wait()
-    def get_images(self) -> Sequence[np.ndarray]:
+    def get_images(self) -> Sequence[Optional[np.ndarray]]:
         """
-        Return RGB images from each environment
+        Return RGB images from each environment when available
         """
         raise NotImplementedError
-    def render(self, mode: str = "human") -> Optional[np.ndarray]:
+    def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
         """
         Gym environment rendering
         :param mode: the rendering type
         """
-        try:
-            imgs = self.get_images()
-        except NotImplementedError:
-            warnings.warn(f"Render not defined for {self}")
+
+        if mode == "human" and self.render_mode != mode:
+            # Special case, if the render_mode="rgb_array"
+            # we can still display that image using opencv
+            if self.render_mode != "rgb_array":
+                warnings.warn(
+                    f"You tried to render a VecEnv with mode='{mode}' "
+                    "but the render mode defined when initializing the environment must be "
+                    f"'human' or 'rgb_array', not '{self.render_mode}'."
+                )
+                return
+
+        elif mode and self.render_mode != mode:
+            warnings.warn(
+                f"""Starting from gym v0.26, render modes are determined during the initialization of the environment.
+                We allow to pass a mode argument to maintain a backwards compatible VecEnv API, but the mode ({mode})
+                has to be the same as the environment render mode ({self.render_mode}) which is not the case."""
+            )
+            return
+
+        mode = mode or self.render_mode
+
+        if mode is None:
+            warnings.warn("You tried to call render() but no `render_mode` was passed to the env constructor.")
+            return
+
+        # mode == self.render_mode == "human"
+        # In that case, we try to call `self.env.render()` but it might
+        # crash for subprocesses
+        if self.render_mode == "human":
+            self.env_method("render")
             return
-        # Create a big image by tiling images from subprocesses
-        bigimg = tile_images(imgs)
-        if mode == "human":
-            import cv2  # pytype:disable=import-error
+        if mode == "rgb_array" or mode == "human":
+            # call the render method of the environments
+            images = self.get_images()
+            # Create a big image by tiling images from subprocesses
+            bigimg = tile_images(images)
+
+            if mode == "human":
+                # Display it using OpenCV
+                import cv2  # pytype:disable=import-error
+
+                cv2.imshow("vecenv", bigimg[:, :, ::-1])
+                cv2.waitKey(1)
+            else:
+                return bigimg
-            cv2.imshow("vecenv", bigimg[:, :, ::-1])
-            cv2.waitKey(1)
-        elif mode == "rgb_array":
-            return bigimg
         else:
-            raise NotImplementedError(f"Render mode {mode} is not supported by VecEnvs")
+            # Other render modes:
+            # In that case, we try to call `self.env.render()` but it might
+            # crash for subprocesses
+            # and we don't return the values
+            self.env_method("render")
     @abstractmethod
     def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
@@ -251,6 +296,7 @@ def __init__(
         venv: VecEnv,
         observation_space: Optional[spaces.Space] = None,
         action_space: Optional[spaces.Space] = None,
+        render_mode: Optional[str] = None,
     ):
         self.venv = venv
         VecEnv.__init__(
@@ -258,6 +304,7 @@ def __init__(
             num_envs=venv.num_envs,
             observation_space=observation_space or venv.observation_space,
             action_space=action_space or venv.action_space,
+            render_mode=render_mode,
         )
         self.class_attributes = dict(inspect.getmembers(self.__class__))
@@ -278,10 +325,10 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
     def close(self) -> None:
         return self.venv.close()
-    def render(self, mode: str = "human") -> Optional[np.ndarray]:
+    def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
         return self.venv.render(mode=mode)
-    def get_images(self) -> Sequence[np.ndarray]:
+    def get_images(self) -> Sequence[Optional[np.ndarray]]:
         return self.venv.get_images()
     def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]:
diff --git a/stable_baselines3/common/vec_env/dummy_vec_env.py b/stable_baselines3/common/vec_env/dummy_vec_env.py
index 5b9fc8b40..01b1e8dd1 100644
--- a/stable_baselines3/common/vec_env/dummy_vec_env.py
+++ b/stable_baselines3/common/vec_env/dummy_vec_env.py
@@ -1,3 +1,4 @@
+import warnings
 from collections import OrderedDict
 from copy import deepcopy
 from typing import Any, Callable, List, Optional, Sequence, Type, Union
@@ -35,7 +36,7 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]]):
                 "Please read https://github.com/DLR-RM/stable-baselines3/issues/1151 for more information."
             )
         env = self.envs[0]
-        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space)
+        VecEnv.__init__(self, len(env_fns), env.observation_space, env.action_space, env.render_mode)
         obs_space = env.observation_space
         self.keys, shapes, dtypes = obs_space_info(obs_space)
@@ -50,28 +51,38 @@ def step_async(self, actions: np.ndarray) -> None:
         self.actions = actions
     def step_wait(self) -> VecEnvStepReturn:
+        # Avoid circular imports
         for env_idx in range(self.num_envs):
-            obs, self.buf_rews[env_idx], self.buf_dones[env_idx], self.buf_infos[env_idx] = self.envs[env_idx].step(
+            obs, self.buf_rews[env_idx], terminated, truncated, self.buf_infos[env_idx] = self.envs[env_idx].step(
                 self.actions[env_idx]
             )
+            # convert to SB3 VecEnv api
+            self.buf_dones[env_idx] = terminated or truncated
+            # See https://github.com/openai/gym/issues/3102
+            # Gym 0.26 introduces a breaking change
+            self.buf_infos[env_idx]["TimeLimit.truncated"] = truncated and not terminated
+
             if self.buf_dones[env_idx]:
                 # save final observation where user can get it, then reset
                 self.buf_infos[env_idx]["terminal_observation"] = obs
-                obs = self.envs[env_idx].reset()
+                obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
             self._save_obs(env_idx, obs)
         return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones), deepcopy(self.buf_infos))
     def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]:
+        # Avoid circular import
+        from stable_baselines3.common.utils import compat_gym_seed
+
         if seed is None:
             seed = np.random.randint(0, 2**32 - 1)
         seeds = []
         for idx, env in enumerate(self.envs):
-            seeds.append(env.seed(seed + idx))
+            seeds.append(compat_gym_seed(env, seed=seed + idx))
         return seeds
     def reset(self) -> VecEnvObs:
         for env_idx in range(self.num_envs):
-            obs = self.envs[env_idx].reset()
+            obs, self.reset_infos[env_idx] = self.envs[env_idx].reset()
             self._save_obs(env_idx, obs)
         return self._obs_from_buf()
@@ -79,25 +90,22 @@ def close(self) -> None:
         for env in self.envs:
             env.close()
-    def get_images(self) -> Sequence[np.ndarray]:
-        return [env.render(mode="rgb_array") for env in self.envs]
+    def get_images(self) -> Sequence[Optional[np.ndarray]]:
+        if self.render_mode != "rgb_array":
+            warnings.warn(
+                f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images."
+            )
+            return [None for _ in self.envs]
+        return [env.render() for env in self.envs]
-    def render(self, mode: str = "human") -> Optional[np.ndarray]:
+    def render(self, mode: Optional[str] = None) -> Optional[np.ndarray]:
         """
         Gym environment rendering. If there are multiple environments then
         they are tiled together in one image via ``BaseVecEnv.render()``.
-        Otherwise (if ``self.num_envs == 1``), we pass the render call directly to the
-        underlying environment.
-
-        Therefore, some arguments such as ``mode`` will have values that are valid
-        only when ``num_envs == 1``.
         :param mode: The rendering type.
""" - if self.num_envs == 1: - return self.envs[0].render(mode=mode) - else: - return super().render(mode=mode) + return super().render(mode=mode) def _save_obs(self, env_idx: int, obs: VecEnvObs) -> None: for key in self.keys: diff --git a/stable_baselines3/common/vec_env/stacked_observations.py b/stable_baselines3/common/vec_env/stacked_observations.py index ae7aebc54..9fac9735f 100644 --- a/stable_baselines3/common/vec_env/stacked_observations.py +++ b/stable_baselines3/common/vec_env/stacked_observations.py @@ -40,12 +40,12 @@ def __init__( if not isinstance(channels_order, Mapping): channels_order = {key: channels_order for key in observation_space.spaces.keys()} self.sub_stacked_observations = { - key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) + key: StackedObservations(num_envs, n_stack, subspace, channels_order[key]) # type: ignore[arg-type] for key, subspace in observation_space.spaces.items() } self.stacked_observation_space = spaces.Dict( {key: substack_obs.stacked_observation_space for key, substack_obs in self.sub_stacked_observations.items()} - ) # type: spaces.Dict # make mypy happy + ) # type: Union[spaces.Dict, spaces.Box] # make mypy happy elif isinstance(observation_space, spaces.Box): if isinstance(channels_order, Mapping): raise TypeError("When the observation space is Box, channels_order can't be a dict.") @@ -55,7 +55,11 @@ def __init__( ) low = np.repeat(observation_space.low, n_stack, axis=self.repeat_axis) high = np.repeat(observation_space.high, n_stack, axis=self.repeat_axis) - self.stacked_observation_space = spaces.Box(low=low, high=high, dtype=observation_space.dtype) + self.stacked_observation_space = spaces.Box( + low=low, + high=high, + dtype=observation_space.dtype, # type: ignore[arg-type] + ) self.stacked_obs = np.zeros((num_envs, *self.stacked_shape), dtype=observation_space.dtype) else: raise TypeError( @@ -125,7 +129,7 @@ def stack_observation_space(self, observation_space: Union[spaces.Box, spaces.Di ) low = np.repeat(observation_space.low, self.n_stack, axis=self.repeat_axis) high = np.repeat(observation_space.high, self.n_stack, axis=self.repeat_axis) - return spaces.Box(low=low, high=high, dtype=observation_space.dtype) + return spaces.Box(low=low, high=high, dtype=observation_space.dtype) # type: ignore[arg-type] def reset(self, observation: TObs) -> TObs: """ diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py index 7ff579d30..73d65106f 100644 --- a/stable_baselines3/common/vec_env/subproc_vec_env.py +++ b/stable_baselines3/common/vec_env/subproc_vec_env.py @@ -1,4 +1,5 @@ import multiprocessing as mp +import warnings from collections import OrderedDict from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union @@ -16,30 +17,37 @@ def _worker( - remote: mp.connection.Connection, parent_remote: mp.connection.Connection, env_fn_wrapper: CloudpickleWrapper + remote: mp.connection.Connection, + parent_remote: mp.connection.Connection, + env_fn_wrapper: CloudpickleWrapper, ) -> None: # Import here to avoid a circular import from stable_baselines3.common.env_util import is_wrapped + from stable_baselines3.common.utils import compat_gym_seed parent_remote.close() env = env_fn_wrapper.var() + reset_info = {} while True: try: cmd, data = remote.recv() if cmd == "step": - observation, reward, done, info = env.step(data) + observation, reward, terminated, truncated, info = env.step(data) + # convert to SB3 VecEnv api + done = terminated or 
truncated + info["TimeLimit.truncated"] = truncated and not terminated if done: # save final observation where user can get it, then reset info["terminal_observation"] = observation - observation = env.reset() - remote.send((observation, reward, done, info)) + observation, reset_info = env.reset() + remote.send((observation, reward, done, info, reset_info)) elif cmd == "seed": - remote.send(env.seed(data)) + remote.send(compat_gym_seed(env, seed=data)) elif cmd == "reset": - observation = env.reset() - remote.send(observation) + observation, reset_info = env.reset() + remote.send((observation, reset_info)) elif cmd == "render": - remote.send(env.render(data)) + remote.send(env.render()) elif cmd == "close": env.close() remote.close() @@ -110,7 +118,10 @@ def __init__(self, env_fns: List[Callable[[], gym.Env]], start_method: Optional[ self.remotes[0].send(("get_spaces", None)) observation_space, action_space = self.remotes[0].recv() - VecEnv.__init__(self, len(env_fns), observation_space, action_space) + + self.remotes[0].send(("get_attr", "render_mode")) + render_mode = self.remotes[0].recv() + VecEnv.__init__(self, len(env_fns), observation_space, action_space, render_mode) def step_async(self, actions: np.ndarray) -> None: for remote, action in zip(self.remotes, actions): @@ -120,7 +131,7 @@ def step_async(self, actions: np.ndarray) -> None: def step_wait(self) -> VecEnvStepReturn: results = [remote.recv() for remote in self.remotes] self.waiting = False - obs, rews, dones, infos = zip(*results) + obs, rews, dones, infos, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space), np.stack(rews), np.stack(dones), infos def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: @@ -133,7 +144,8 @@ def seed(self, seed: Optional[int] = None) -> List[Union[None, int]]: def reset(self) -> VecEnvObs: for remote in self.remotes: remote.send(("reset", None)) - obs = [remote.recv() for remote in self.remotes] + results = [remote.recv() for remote in self.remotes] + obs, self.reset_infos = zip(*results) return _flatten_obs(obs, self.observation_space) def close(self) -> None: @@ -148,13 +160,17 @@ def close(self) -> None: process.join() self.closed = True - def get_images(self) -> Sequence[np.ndarray]: + def get_images(self) -> Sequence[Optional[np.ndarray]]: + if self.render_mode != "rgb_array": + warnings.warn( + f"The render mode is {self.render_mode}, but this method assumes it is `rgb_array` to obtain images." 
+ ) + return [None for _ in self.remotes] for pipe in self.remotes: - # gather images from subprocesses - # `mode` will be taken into account later - pipe.send(("render", "rgb_array")) - imgs = [pipe.recv() for pipe in self.remotes] - return imgs + # gather render return from subprocesses + pipe.send(("render", None)) + outputs = [pipe.recv() for pipe in self.remotes] + return outputs def get_attr(self, attr_name: str, indices: VecEnvIndices = None) -> List[Any]: """Return attribute from vectorized environment (see base class).""" diff --git a/stable_baselines3/common/vec_env/vec_frame_stack.py b/stable_baselines3/common/vec_env/vec_frame_stack.py index 8a020ddd6..75c80e9e8 100644 --- a/stable_baselines3/common/vec_env/vec_frame_stack.py +++ b/stable_baselines3/common/vec_env/vec_frame_stack.py @@ -35,6 +35,9 @@ def step_wait( return observations, rewards, dones, infos def reset(self) -> Union[np.ndarray, Dict[str, np.ndarray]]: + """ + Reset all environments + """ observation = self.venv.reset() # pytype:disable=annotation-type-mismatch observation = self.stacked_obs.reset(observation) return observation diff --git a/stable_baselines3/common/vec_env/vec_video_recorder.py b/stable_baselines3/common/vec_env/vec_video_recorder.py index 83d058abc..db6999400 100644 --- a/stable_baselines3/common/vec_env/vec_video_recorder.py +++ b/stable_baselines3/common/vec_env/vec_video_recorder.py @@ -47,6 +47,7 @@ def __init__( metadata = temp_env.metadata self.env.metadata = metadata + assert self.env.render_mode == "rgb_array", f"The render_mode must be 'rgb_array', not {self.env.render_mode}" self.record_video_trigger = record_video_trigger self.video_recorder = None @@ -109,4 +110,4 @@ def close(self) -> None: self.close_video_recorder() def __del__(self): - self.close() + self.close_video_recorder() diff --git a/stable_baselines3/version.txt b/stable_baselines3/version.txt index 15159beba..35a785a76 100644 --- a/stable_baselines3/version.txt +++ b/stable_baselines3/version.txt @@ -1 +1 @@ -1.8.0a12 +2.0.0a0 diff --git a/tests/test_buffers.py b/tests/test_buffers.py index 9dc294c6a..8150e2fa1 100644 --- a/tests/test_buffers.py +++ b/tests/test_buffers.py @@ -5,6 +5,7 @@ from gym import spaces from stable_baselines3.common.buffers import DictReplayBuffer, DictRolloutBuffer, ReplayBuffer, RolloutBuffer +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.type_aliases import DictReplayBufferSamples, ReplayBufferSamples from stable_baselines3.common.utils import get_device @@ -19,7 +20,7 @@ class DummyEnv(gym.Env): def __init__(self): self.action_space = spaces.Box(1, 5, (1,)) self.observation_space = spaces.Box(1, 5, (1,)) - self._observations = [1, 2, 3, 4, 5] + self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 self._ep_length = 100 @@ -27,15 +28,15 @@ def __init__(self): def reset(self): self._t = 0 obs = self._observations[0] - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = self._observations[index] - done = self._t >= self._ep_length + terminated = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, terminated, truncated, {} class DummyDictEnv(gym.Env): @@ -48,7 +49,7 @@ def __init__(self): self.action_space = spaces.Box(1, 5, shape=(10, 7)) space = spaces.Box(1, 5, (1,)) self.observation_space 
= spaces.Dict({"observation": space, "achieved_goal": space, "desired_goal": space}) - self._observations = [1, 2, 3, 4, 5] + self._observations = np.array([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=np.float32) self._rewards = [1, 2, 3, 4, 5] self._t = 0 self._ep_length = 100 @@ -56,15 +57,22 @@ def __init__(self): def reset(self): self._t = 0 obs = {key: self._observations[0] for key in self.observation_space.spaces.keys()} - return obs + return obs, {} def step(self, action): self._t += 1 index = self._t % len(self._observations) obs = {key: self._observations[index] for key in self.observation_space.spaces.keys()} - done = self._t >= self._ep_length + terminated = truncated = self._t >= self._ep_length reward = self._rewards[index] - return obs, reward, done, {} + return obs, reward, terminated, truncated, {} + + +@pytest.mark.parametrize("env_cls", [DummyEnv, DummyDictEnv]) +def test_env(env_cls): + # Check the env used for testing + # Do not warn for assymetric space + check_env(env_cls(), warn=False, skip_render_check=True) @pytest.mark.parametrize("replay_buffer_cls", [ReplayBuffer, DictReplayBuffer]) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 55c55ca9c..1c59d6994 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -45,7 +45,7 @@ def test_cnn(tmp_path, model_class, share_features_extractor): # FakeImageEnv is channel last by default and should be wrapped assert is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() # Test stochastic predict with channel last input if model_class == DQN: @@ -248,7 +248,7 @@ def test_channel_first_env(tmp_path): assert not is_vecenv_wrapped(model.get_env(), VecTransposeImage) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs, deterministic=True) diff --git a/tests/test_dict_env.py b/tests/test_dict_env.py index 2c114f613..ebcda8bbf 100644 --- a/tests/test_dict_env.py +++ b/tests/test_dict_env.py @@ -1,9 +1,12 @@ +from typing import Dict, Optional + import gym import numpy as np import pytest from gym import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.envs import BitFlippingEnv, SimpleMultiObsEnv from stable_baselines3.common.evaluation import evaluate_policy @@ -13,7 +16,7 @@ class DummyDictEnv(gym.Env): """Custom Environment for testing purposes only""" - metadata = {"render.modes": ["human"]} + metadata = {"render_modes": ["human"]} def __init__( self, @@ -66,19 +69,31 @@ def seed(self, seed=None): def step(self, action): reward = 0.0 - done = False - return self.observation_space.sample(), reward, done, {} - - def compute_reward(self, achieved_goal, desired_goal, info): - return np.zeros((len(achieved_goal),)) + terminated = truncated = False + return self.observation_space.sample(), reward, terminated, truncated, {} - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + self.observation_space.seed(seed) + return self.observation_space.sample(), {} - def render(self, mode="human"): + def render(self): pass +@pytest.mark.parametrize("use_discrete_actions", [True, False]) +@pytest.mark.parametrize("channel_last", [True, False]) +@pytest.mark.parametrize("nested_dict_obs", [True, False]) +@pytest.mark.parametrize("vec_only", [True, False]) +def test_env(use_discrete_actions, channel_last, 
nested_dict_obs, vec_only): + # Check the env used for testing + if nested_dict_obs: + with pytest.warns(UserWarning, match="Nested observation spaces are not supported"): + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + else: + check_env(DummyDictEnv(use_discrete_actions, channel_last, nested_dict_obs, vec_only)) + + @pytest.mark.parametrize("policy", ["MlpPolicy", "CnnPolicy"]) def test_policy_hint(policy): # Common mistake: using the wrong policy @@ -105,7 +120,7 @@ def test_consistency(model_class): dict_env = gym.wrappers.TimeLimit(dict_env, 100) env = gym.wrappers.FlattenObservation(dict_env) dict_env.seed(10) - obs = dict_env.reset() + obs, _ = dict_env.reset() kwargs = {} n_steps = 256 diff --git a/tests/test_env_checker.py b/tests/test_env_checker.py index 94aeb3c97..9da4fd6d5 100644 --- a/tests/test_env_checker.py +++ b/tests/test_env_checker.py @@ -1,3 +1,5 @@ +from typing import Dict, Optional + import gym import numpy as np import pytest @@ -7,20 +9,24 @@ class ActionDictTestEnv(gym.Env): + metadata = {"render_modes": ["human"]} + render_mode = None + action_space = spaces.Dict({"position": spaces.Discrete(1), "velocity": spaces.Discrete(1)}) observation_space = spaces.Box(low=-1.0, high=2.0, shape=(3,), dtype=np.float32) def step(self, action): observation = np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) reward = 1 - done = True + terminated = True + truncated = False info = {} - return observation, reward, done, info + return observation, reward, terminated, truncated, info def reset(self): - return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype) + return np.array([1.0, 1.5, 0.5], dtype=self.observation_space.dtype), {} - def render(self, mode="human"): + def render(self): pass @@ -94,12 +100,12 @@ def test_check_env_detailed_error(obs_tuple, method): class TestEnv(gym.Env): action_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32) - def reset(self): - return wrong_obs if method == "reset" else good_obs + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + return wrong_obs if method == "reset" else good_obs, {} def step(self, action): obs = wrong_obs if method == "step" else good_obs - return obs, 0.0, True, {} + return obs, 0.0, True, False, {} TestEnv.observation_space = observation_space diff --git a/tests/test_envs.py b/tests/test_envs.py index 1281bb45f..82bd6a6c0 100644 --- a/tests/test_envs.py +++ b/tests/test_envs.py @@ -75,6 +75,17 @@ def test_bit_flipping(kwargs): # No warnings for custom envs assert len(record) == 0 + # Remove a key, must throw an error + obs_space = env.observation_space.spaces["observation"] + del env.observation_space.spaces["observation"] + with pytest.raises(AssertionError): + check_env(env) + + # Rename a key, must throw an error + env.observation_space.spaces["obs"] = obs_space + with pytest.raises(AssertionError): + check_env(env) + def test_high_dimension_action_space(): """ @@ -87,7 +98,7 @@ def test_high_dimension_action_space(): # Patch to avoid error def patched_step(_action): - return env.observation_space.sample(), 0.0, False, {} + return env.observation_space.sample(), 0.0, False, False, {} env.step = patched_step check_env(env) @@ -110,16 +121,20 @@ def patched_step(_action): spaces.Dict({"position": spaces.Dict({"abs": spaces.Discrete(5), "rel": spaces.Discrete(2)})}), # Small image inside a dict spaces.Dict({"img": spaces.Box(low=0, high=255, shape=(32, 32, 3), dtype=np.uint8)}), + # Non zero start index + 
spaces.Discrete(3, start=-1), + # Non zero start index inside a Dict + spaces.Dict({"obs": spaces.Discrete(3, start=1)}), ], ) def test_non_default_spaces(new_obs_space): env = FakeImageEnv() env.observation_space = new_obs_space # Patch methods to avoid errors - env.reset = new_obs_space.sample + env.reset = lambda: (new_obs_space.sample(), {}) def patched_step(_action): - return new_obs_space.sample(), 0.0, False, {} + return new_obs_space.sample(), 0.0, False, False, {} env.step = patched_step with pytest.warns(UserWarning): @@ -145,6 +160,8 @@ def patched_step(_action): spaces.Box(low=-np.inf, high=1, shape=(2,), dtype=np.float32), # Almost good, except for one dim spaces.Box(low=np.array([-1, -1, -1]), high=np.array([1, 1, 0.99]), dtype=np.float32), + # Non zero start index + spaces.Discrete(3, start=-1), ], ) def test_non_default_action_spaces(new_action_space): @@ -155,14 +172,26 @@ def test_non_default_action_spaces(new_action_space): # No warnings for custom envs assert len(record) == 0 + # Change the action space env.action_space = new_action_space + # Discrete action space + if isinstance(new_action_space, spaces.Discrete): + with pytest.warns(UserWarning): + check_env(env) + return + + low, high = new_action_space.low[0], new_action_space.high[0] # Unbounded action space throws an error, # the rest only warning if not np.all(np.isfinite(env.action_space.low)): with pytest.raises(AssertionError), pytest.warns(UserWarning): check_env(env) + # numpy >= 1.21 raises a ValueError + elif int(np.__version__.split(".")[1]) >= 21 and (low > high): + with pytest.raises(ValueError), pytest.warns(UserWarning): + check_env(env) else: with pytest.warns(UserWarning): check_env(env) @@ -176,7 +205,7 @@ def check_reset_assert_error(env, new_reset_return): """ def wrong_reset(): - return new_reset_return + return new_reset_return, {} # Patch the reset method with a wrong one env.reset = wrong_reset @@ -194,6 +223,11 @@ def test_common_failures_reset(): # The observation is not a numpy array check_reset_assert_error(env, 1) + # Return only obs (gym < 0.26) + env.reset = env.observation_space.sample + with pytest.raises(AssertionError): + check_env(env) + # Return not only the observation check_reset_assert_error(env, (env.observation_space.sample(), False)) @@ -206,10 +240,10 @@ def test_common_failures_reset(): wrong_obs = {**env.observation_space.sample(), "extra_key": None} check_reset_assert_error(env, wrong_obs) - obs = env.reset() + obs, _ = env.reset() def wrong_reset(self): - return {"img": obs["img"], "vec": obs["img"]} + return {"img": obs["img"], "vec": obs["img"]}, {} env.reset = types.MethodType(wrong_reset, env) with pytest.raises(AssertionError) as excinfo: @@ -242,33 +276,38 @@ def test_common_failures_step(): env = IdentityEnvBox() # Wrong shape for the observation - check_step_assert_error(env, (np.ones((4,)), 1.0, False, {})) + check_step_assert_error(env, (np.ones((4,)), 1.0, False, False, {})) # Obs is not a numpy array - check_step_assert_error(env, (1, 1.0, False, {})) + check_step_assert_error(env, (1, 1.0, False, False, {})) # Return a wrong reward - check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, {})) + check_step_assert_error(env, (env.observation_space.sample(), np.ones(1), False, False, {})) # Info dict is not returned - check_step_assert_error(env, (env.observation_space.sample(), 0.0, False)) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, False)) + + # Truncated is not returned (gym < 0.26) + 
check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, {})) # Done is not a boolean - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, {})) - check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 3.0, False, {})) + check_step_assert_error(env, (env.observation_space.sample(), 0.0, 1, False, {})) + # Truncated is not a boolean + check_step_assert_error(env, (env.observation_space.sample(), 0.0, False, 1.0, {})) env = SimpleMultiObsEnv() # Observation keys and observation space keys must match wrong_obs = env.observation_space.sample() wrong_obs.pop("img") - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) wrong_obs = {**env.observation_space.sample(), "extra_key": None} - check_step_assert_error(env, (wrong_obs, 0.0, False, {})) + check_step_assert_error(env, (wrong_obs, 0.0, False, False, {})) - obs = env.reset() + obs, _ = env.reset() def wrong_step(self, action): - return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, {} + return {"img": obs["vec"], "vec": obs["vec"]}, 0.0, False, False, {} env.step = types.MethodType(wrong_step, env) with pytest.raises(AssertionError) as excinfo: diff --git a/tests/test_gae.py b/tests/test_gae.py index c90470f00..58e3e4158 100644 --- a/tests/test_gae.py +++ b/tests/test_gae.py @@ -1,3 +1,5 @@ +from typing import Dict, Optional + import gym import numpy as np import pytest @@ -6,6 +8,7 @@ from stable_baselines3 import A2C, PPO, SAC from stable_baselines3.common.callbacks import BaseCallback +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.policies import ActorCriticPolicy @@ -20,20 +23,26 @@ def __init__(self, max_steps=8): def seed(self, seed): self.observation_space.seed(seed) - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + self.observation_space.seed(seed) self.n_steps = 0 - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): self.n_steps += 1 - done = False + terminated = truncated = False reward = 0.0 if self.n_steps >= self.max_steps: reward = 1.0 - done = True + terminated = True + # To simplify GAE computation checks, + # we do not consider truncation here. 
+ # Truncations are checked in InfiniteHorizonEnv + truncated = False - return self.observation_space.sample(), reward, done, {} + return self.observation_space.sample(), reward, terminated, truncated, {} class InfiniteHorizonEnv(gym.Env): @@ -44,13 +53,16 @@ def __init__(self, n_states=4): self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) self.current_state = 0 - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + self.current_state = 0 - return self.current_state + return self.current_state, {} def step(self, action): self.current_state = (self.current_state + 1) % self.n_states - return self.current_state, 1.0, False, {} + return self.current_state, 1.0, False, False, {} class CheckGAECallback(BaseCallback): @@ -110,6 +122,12 @@ def forward(self, obs, deterministic=False): return actions, values, log_prob +@pytest.mark.parametrize("env_cls", [CustomEnv, InfiniteHorizonEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) + + @pytest.mark.parametrize("model_class", [A2C, PPO]) @pytest.mark.parametrize("gae_lambda", [1.0, 0.9]) @pytest.mark.parametrize("gamma", [1.0, 0.99]) diff --git a/tests/test_her.py b/tests/test_her.py index f9794d5e2..1db12be41 100644 --- a/tests/test_her.py +++ b/tests/test_her.py @@ -251,7 +251,7 @@ def env_fn(): train_freq=4, buffer_size=int(2e4), policy_kwargs=dict(net_arch=[64]), - seed=1, + seed=0, ) model.learn(200) old_replay_buffer = deepcopy(model.replay_buffer) diff --git a/tests/test_identity.py b/tests/test_identity.py index cc7746bc7..3118a4d7a 100644 --- a/tests/test_identity.py +++ b/tests/test_identity.py @@ -15,21 +15,17 @@ def test_discrete(model_class, env): env_ = DummyVecEnv([lambda: env]) kwargs = {} - n_steps = 3000 + n_steps = 2500 if model_class == DQN: kwargs = dict(learning_starts=0) - n_steps = 4000 # DQN only support discrete actions if isinstance(env, (IdentityEnvMultiDiscrete, IdentityEnvMultiBinary)): return - elif model_class == A2C: - # slightly higher budget - n_steps = 3500 - model = model_class("MlpPolicy", env_, gamma=0.4, seed=1, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env_, gamma=0.4, seed=3, **kwargs).learn(n_steps) evaluate_policy(model, env_, n_eval_episodes=20, reward_threshold=90, warn=False) - obs = env.reset() + obs, _ = env.reset() assert np.shape(model.predict(obs)[0]) == np.shape(obs) @@ -38,16 +34,19 @@ def test_discrete(model_class, env): def test_continuous(model_class): env = IdentityEnvBox(eps=0.5) - n_steps = {A2C: 3500, PPO: 3000, SAC: 700, TD3: 500, DDPG: 500}[model_class] + n_steps = {A2C: 2000, PPO: 2000, SAC: 400, TD3: 400, DDPG: 400}[model_class] kwargs = dict(policy_kwargs=dict(net_arch=[64, 64]), seed=0, gamma=0.95) + if model_class in [TD3]: n_actions = 1 action_noise = NormalActionNoise(mean=np.zeros(n_actions), sigma=0.1 * np.ones(n_actions)) kwargs["action_noise"] = action_noise elif model_class in [A2C]: kwargs["policy_kwargs"]["log_std_init"] = -0.5 + elif model_class == PPO: + kwargs = dict(n_steps=512, n_epochs=5) - model = model_class("MlpPolicy", env, **kwargs).learn(n_steps) + model = model_class("MlpPolicy", env, learning_rate=1e-3, **kwargs).learn(n_steps) evaluate_policy(model, env, n_eval_episodes=20, reward_threshold=90, warn=False) diff --git a/tests/test_logger.py b/tests/test_logger.py index 1bc11e521..8dec3bda5 100644 --- a/tests/test_logger.py +++ b/tests/test_logger.py @@ -14,6 
+14,7 @@ from pandas.errors import EmptyDataError from stable_baselines3 import A2C, DQN +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.logger import ( DEBUG, INFO, @@ -352,12 +353,18 @@ def __init__(self, delay: float = 0.01): self.action_space = spaces.Discrete(2) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): time.sleep(self.delay) obs = self.observation_space.sample() - return obs, 0.0, True, {} + return obs, 0.0, True, False, {} + + +@pytest.mark.parametrize("env_cls", [TimeDelayEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) class InMemoryLogger(Logger): diff --git a/tests/test_monitor.py b/tests/test_monitor.py index 17002f39a..481ef2178 100644 --- a/tests/test_monitor.py +++ b/tests/test_monitor.py @@ -13,7 +13,7 @@ def test_monitor(tmp_path): Test the monitor wrapper """ env = gym.make("CartPole-v1") - env.seed(0) + env.reset(seed=0) monitor_file = os.path.join(str(tmp_path), f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env = Monitor(env, monitor_file) monitor_env.reset() @@ -22,10 +22,10 @@ def test_monitor(tmp_path): ep_lengths = [] ep_len, ep_reward = 0, 0 for _ in range(total_steps): - _, reward, done, _ = monitor_env.step(monitor_env.action_space.sample()) + _, reward, terminated, truncated, _ = monitor_env.step(monitor_env.action_space.sample()) ep_len += 1 ep_reward += reward - if done: + if terminated or truncated: ep_rewards.append(ep_reward) ep_lengths.append(ep_len) monitor_env.reset() @@ -64,7 +64,7 @@ def test_monitor_load_results(tmp_path): """ tmp_path = str(tmp_path) env1 = gym.make("CartPole-v1") - env1.seed(0) + env1.reset(seed=0) monitor_file1 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env1 = Monitor(env1, monitor_file1) @@ -75,8 +75,8 @@ def test_monitor_load_results(tmp_path): monitor_env1.reset() episode_count1 = 0 for _ in range(1000): - _, _, done, _ = monitor_env1.step(monitor_env1.action_space.sample()) - if done: + _, _, terminated, truncated, _ = monitor_env1.step(monitor_env1.action_space.sample()) + if terminated or truncated: episode_count1 += 1 monitor_env1.reset() @@ -84,7 +84,7 @@ def test_monitor_load_results(tmp_path): assert results_size1 == episode_count1 env2 = gym.make("CartPole-v1") - env2.seed(0) + env2.reset(seed=0) monitor_file2 = os.path.join(tmp_path, f"stable_baselines-test-{uuid.uuid4()}.monitor.csv") monitor_env2 = Monitor(env2, monitor_file2) monitor_files = get_monitor_files(tmp_path) @@ -98,8 +98,8 @@ def test_monitor_load_results(tmp_path): monitor_env2 = Monitor(env2, monitor_file2, override_existing=False) monitor_env2.reset() for _ in range(1000): - _, _, done, _ = monitor_env2.step(monitor_env2.action_space.sample()) - if done: + _, _, terminated, truncated, _ = monitor_env2.step(monitor_env2.action_space.sample()) + if terminated or truncated: episode_count2 += 1 monitor_env2.reset() diff --git a/tests/test_predict.py b/tests/test_predict.py index 579abff77..4f8c3d1e9 100644 --- a/tests/test_predict.py +++ b/tests/test_predict.py @@ -5,6 +5,7 @@ from gym import spaces from stable_baselines3 import A2C, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.envs import IdentityEnv from stable_baselines3.common.utils import get_device from stable_baselines3.common.vec_env import DummyVecEnv @@ -30,10 +31,16 @@ def __init__(self): 
self.action_space = SubClassedBox(-1, 1, shape=(2,), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, {} + return self.observation_space.sample(), 0.0, np.random.rand() > 0.5, False, {} + + +@pytest.mark.parametrize("env_cls", [CustomSubClassedSpaceEnv]) +def test_env(env_cls): + # Check the env used for testing + check_env(env_cls(), skip_render_check=True) @pytest.mark.parametrize("model_class", MODEL_LIST) @@ -70,7 +77,7 @@ def test_predict(model_class, env_id, device): env = gym.make(env_id) vec_env = DummyVecEnv([lambda: gym.make(env_id), lambda: gym.make(env_id)]) - obs = env.reset() + obs, _ = env.reset() action, _ = model.predict(obs) assert isinstance(action, np.ndarray) assert action.shape == env.action_space.shape @@ -96,7 +103,7 @@ def test_dqn_epsilon_greedy(): env = IdentityEnv(2) model = DQN("MlpPolicy", env) model.exploration_rate = 1.0 - obs = env.reset() + obs, _ = env.reset() # is vectorized should not crash with discrete obs action, _ = model.predict(obs, deterministic=False) assert env.action_space.contains(action) @@ -107,5 +114,5 @@ def test_subclassed_space_env(model_class): env = CustomSubClassedSpaceEnv() model = model_class("MlpPolicy", env, policy_kwargs=dict(net_arch=[32])) model.learn(300) - obs = env.reset() + obs, _ = env.reset() env.step(model.predict(obs)) diff --git a/tests/test_spaces.py b/tests/test_spaces.py index 6dd6dc419..82fdc636a 100644 --- a/tests/test_spaces.py +++ b/tests/test_spaces.py @@ -1,9 +1,12 @@ +from typing import Dict, Optional + import gym import numpy as np import pytest from gym import spaces from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3 +from stable_baselines3.common.env_checker import check_env from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.evaluation import evaluate_policy @@ -14,11 +17,13 @@ def __init__(self, nvec): self.observation_space = spaces.MultiDiscrete(nvec) self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultiBinary(gym.Env): @@ -27,11 +32,13 @@ def __init__(self, n): self.observation_space = spaces.MultiBinary(n) self.action_space = spaces.Box(low=-1, high=1, shape=(2,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} class DummyMultidimensionalAction(gym.Env): @@ -41,10 +48,16 @@ def __init__(self): self.action_space = spaces.Box(low=-1, high=1, shape=(2, 2), dtype=np.float32) def reset(self): - return self.observation_space.sample() + return self.observation_space.sample(), {} def step(self, action): - return self.observation_space.sample(), 0.0, False, {} + return self.observation_space.sample(), 0.0, False, False, {} + + 
+@pytest.mark.parametrize("env", [DummyMultiDiscreteSpace([4, 3]), DummyMultiBinary(8), DummyMultiBinary((3, 2))]) +def test_env(env): + # Check the env used for testing + check_env(env, skip_render_check=True) @pytest.mark.parametrize("model_class", [SAC, TD3, DQN]) diff --git a/tests/test_utils.py b/tests/test_utils.py index 0bab41e9e..712c53e32 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -229,17 +229,18 @@ def __init__(self, env): self.needs_reset = True def step(self, action): - obs, reward, done, info = self.env.step(action) - self.needs_reset = done + obs, reward, terminated, truncated, info = self.env.step(action) + self.needs_reset = terminated or truncated self.last_obs = obs - return obs, reward, True, info + return obs, reward, True, truncated, info def reset(self, **kwargs): + info = {} if self.needs_reset: - obs = self.env.reset(**kwargs) + obs, info = self.env.reset(**kwargs) self.last_obs = obs self.needs_reset = False - return self.last_obs + return self.last_obs, info @pytest.mark.parametrize("n_envs", [1, 2, 5, 7]) diff --git a/tests/test_vec_check_nan.py b/tests/test_vec_check_nan.py index 962355782..48b203e89 100644 --- a/tests/test_vec_check_nan.py +++ b/tests/test_vec_check_nan.py @@ -24,13 +24,13 @@ def step(action): obs = float("inf") else: obs = 0 - return [obs], 0.0, False, {} + return [obs], 0.0, False, False, {} @staticmethod def reset(): - return [0.0] + return [0.0], {} - def render(self, mode="human", close=False): + def render(self): pass diff --git a/tests/test_vec_envs.py b/tests/test_vec_envs.py index ae05947c5..17d371b64 100644 --- a/tests/test_vec_envs.py +++ b/tests/test_vec_envs.py @@ -2,12 +2,16 @@ import functools import itertools import multiprocessing +import os +import warnings +from typing import Dict, Optional import gym import numpy as np import pytest from gym import spaces +from stable_baselines3.common.env_util import make_vec_env from stable_baselines3.common.monitor import Monitor from stable_baselines3.common.vec_env import DummyVecEnv, SubprocVecEnv, VecFrameStack, VecNormalize @@ -17,7 +21,7 @@ class CustomGymEnv(gym.Env): - def __init__(self, space): + def __init__(self, space, render_mode: str = "rgb_array"): """ Custom gym environment for testing purposes """ @@ -25,24 +29,27 @@ def __init__(self, space): self.observation_space = space self.current_step = 0 self.ep_length = 4 + self.render_mode = render_mode - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + self.seed(seed) self.current_step = 0 self._choose_next_state() - return self.state + return self.state, {} def step(self, action): reward = float(np.random.rand()) self._choose_next_state() self.current_step += 1 - done = self.current_step >= self.ep_length - return self.state, reward, done, {} + terminated = truncated = self.current_step >= self.ep_length + return self.state, reward, terminated, truncated, {} def _choose_next_state(self): self.state = self.observation_space.sample() - def render(self, mode="human"): - if mode == "rgb_array": + def render(self): + if self.render_mode == "rgb_array": return np.zeros((4, 4, 3)) def seed(self, seed=None): @@ -91,9 +98,20 @@ def make_env(): # Test seed method vec_env.seed(0) + # Test render method call - # vec_env.render() # we need a X server to test the "human" mode - vec_env.render(mode="rgb_array") + array_explicit_mode = vec_env.render(mode="rgb_array") + # test render without argument (new gym API style) + array_implicit_mode = 
vec_env.render() + assert np.array_equal(array_implicit_mode, array_explicit_mode) + + # test warning if you try different render mode + with pytest.warns(UserWarning): + vec_env.render(mode="something_else") + + # we need a X server to test the "human" mode (uses OpenCV) + # vec_env.render(mode="human") + env_method_results = vec_env.env_method("custom_method", 1, indices=None, dim_1=2) setattr_results = [] # Set current_step to an arbitrary value @@ -155,13 +173,13 @@ def __init__(self, max_steps): def reset(self): self.current_step = 0 - return np.array([self.current_step], dtype="int") + return np.array([self.current_step], dtype="int"), {} def step(self, action): prev_step = self.current_step self.current_step += 1 - done = self.current_step >= self.max_steps - return np.array([prev_step], dtype="int"), 0.0, done, {} + terminated = truncated = self.current_step >= self.max_steps + return np.array([prev_step], dtype="int"), 0.0, terminated, truncated, {} @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) @@ -455,6 +473,23 @@ def make_monitored_env(): assert vec_env.env_is_wrapped(Monitor) == [False, True] +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_backward_compat_seed(vec_env_class): + def make_env(): + env = CustomGymEnv(gym.spaces.Box(low=np.zeros(2), high=np.ones(2))) + # Patch reset function to remove seed param + env.reset = lambda: (env.observation_space.sample(), {}) + env.seed = env.observation_space.seed + return env + + vec_env = vec_env_class([make_env for _ in range(N_ENVS)]) + vec_env.seed(3) + obs = vec_env.reset() + vec_env.seed(3) + new_obs = vec_env.reset() + assert np.allclose(new_obs, obs) + + @pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) def test_vec_seeding(vec_env_class): def make_env(): @@ -484,3 +519,63 @@ def make_env(): assert not np.allclose(rewards[1], rewards[2]) vec_env.close() + + +@pytest.mark.parametrize("vec_env_class", VEC_ENV_CLASSES) +def test_render(vec_env_class): + # Skip if no X-Server + if not os.environ.get("DISPLAY"): + pytest.skip("No X-Server") + + env_id = "Pendulum-v1" + # DummyVecEnv human render is currently + # buggy because of gym: + # https://github.com/carlosluis/stable-baselines3/pull/3#issuecomment-1356863808 + n_envs = 2 + # Human render + vec_env = make_vec_env( + env_id, + n_envs, + vec_env_cls=vec_env_class, + env_kwargs=dict(render_mode="human"), + ) + + vec_env.reset() + vec_env.render() + + with pytest.warns(UserWarning): + vec_env.render("rgb_array") + + with pytest.warns(UserWarning): + vec_env.render(mode="blah") + + for _ in range(10): + vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)]) + vec_env.render() + + vec_env.close() + # rgb_array render, which allows human_render + # thanks to OpenCV + vec_env = make_vec_env( + env_id, + n_envs, + vec_env_cls=vec_env_class, + env_kwargs=dict(render_mode="rgb_array"), + ) + + vec_env.reset() + with warnings.catch_warnings(record=True) as record: + vec_env.render() + vec_env.render("rgb_array") + vec_env.render(mode="human") + + # No warnings for using human mode + assert len(record) == 0 + + with pytest.warns(UserWarning): + vec_env.render(mode="blah") + + for _ in range(10): + vec_env.step([vec_env.action_space.sample() for _ in range(n_envs)]) + vec_env.render() + vec_env.close() diff --git a/tests/test_vec_extract_dict_obs.py b/tests/test_vec_extract_dict_obs.py index 15074425e..6aa4abdbd 100644 --- a/tests/test_vec_extract_dict_obs.py +++ b/tests/test_vec_extract_dict_obs.py @@ -29,7 +29,7 @@ def 
step_wait(self): def reset(self): return {"rgb": np.zeros((self.num_envs, 86, 86))} - def render(self, mode="human", close=False): + def render(self, close=False): pass diff --git a/tests/test_vec_monitor.py b/tests/test_vec_monitor.py index 0a146a057..ab988b6b4 100644 --- a/tests/test_vec_monitor.py +++ b/tests/test_vec_monitor.py @@ -2,6 +2,7 @@ import json import os import uuid +import warnings import gym import pandas @@ -132,8 +133,9 @@ def test_vec_monitor_ppo(recwarn): """ Test the `VecMonitor` with PPO """ + warnings.filterwarnings(action="ignore", category=DeprecationWarning, module=r".*passive_env_checker") env = DummyVecEnv([lambda: gym.make("CartPole-v1")]) - env.seed(0) + env.seed(seed=0) monitor_env = VecMonitor(env) model = PPO("MlpPolicy", monitor_env, verbose=1, n_steps=64, device="cpu") model.learn(total_timesteps=250) diff --git a/tests/test_vec_normalize.py b/tests/test_vec_normalize.py index 27bba9aad..f799bba4d 100644 --- a/tests/test_vec_normalize.py +++ b/tests/test_vec_normalize.py @@ -1,5 +1,5 @@ import operator -from typing import Any, Dict +from typing import Any, Dict, Optional import gym import numpy as np @@ -35,11 +35,14 @@ def step(self, action): self.t += 1 index = (self.t + self.return_reward_idx) % len(self.returned_rewards) returned_value = self.returned_rewards[index] - return np.array([returned_value]), returned_value, self.t == len(self.returned_rewards), {} + terminated = truncated = self.t == len(self.returned_rewards) + return np.array([returned_value]), returned_value, terminated, truncated, {} - def reset(self): + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) self.t = 0 - return np.array([self.returned_rewards[self.return_reward_idx]]) + return np.array([self.returned_rewards[self.return_reward_idx]]), {} class DummyDictEnv(gym.Env): @@ -58,14 +61,16 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() reward = self.compute_reward(obs["achieved_goal"], obs["desired_goal"], {}) done = np.random.rand() > 0.8 - return obs, reward, done, {} + return obs, reward, done, False, {} def compute_reward(self, achieved_goal: np.ndarray, desired_goal: np.ndarray, _info) -> np.float32: distance = np.linalg.norm(achieved_goal - desired_goal, axis=-1) @@ -88,13 +93,15 @@ def __init__(self): ) self.action_space = spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32) - def reset(self): - return self.observation_space.sample() + def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None): + if seed is not None: + super().reset(seed=seed) + return self.observation_space.sample(), {} def step(self, action): obs = self.observation_space.sample() done = np.random.rand() > 0.8 - return obs, 0.0, done, {} + return obs, 0.0, done, False, {} def allclose(obs_1, obs_2):