diff --git a/tools/workspace/stable_baselines3_internal/connection.patch b/tools/workspace/stable_baselines3_internal/connection.patch deleted file mode 100644 index 956e2270eeb1..000000000000 --- a/tools/workspace/stable_baselines3_internal/connection.patch +++ /dev/null @@ -1,11 +0,0 @@ -diff --git a/stable_baselines3/common/vec_env/subproc_vec_env.py b/stable_baselines3/common/vec_env/subproc_vec_env.py -index f723c71..43fa394 100644 ---- stable_baselines3/common/vec_env/subproc_vec_env.py -+++ stable_baselines3/common/vec_env/subproc_vec_env.py -@@ -1,5 +1,6 @@ - import multiprocessing as mp - from collections import OrderedDict -+from multiprocessing import connection - from typing import Any, Callable, List, Optional, Sequence, Tuple, Type, Union - - import gym diff --git a/tools/workspace/stable_baselines3_internal/no_torch.patch b/tools/workspace/stable_baselines3_internal/no_torch.patch index eae5b22bda90..671318774509 100644 --- a/tools/workspace/stable_baselines3_internal/no_torch.patch +++ b/tools/workspace/stable_baselines3_internal/no_torch.patch @@ -1,9 +1,10 @@ -diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py -index d73f5f0..5797dba 100644 +diff --git stable_baselines3/__init__.py stable_baselines3/__init__.py +index 680e254..633ae10 100644 --- stable_baselines3/__init__.py +++ stable_baselines3/__init__.py -@@ -1,13 +1,19 @@ - import os +@@ -2,14 +2,20 @@ import os + + import numpy as np -from stable_baselines3.a2c import A2C -from stable_baselines3.common.utils import get_system_info @@ -28,16 +29,16 @@ index d73f5f0..5797dba 100644 +SAC = make_not_loaded("SAC") +TD3 = make_not_loaded("TD3") - # Read version from file - version_file = os.path.join(os.path.dirname(__file__), "version.txt") -diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py -index 3b2c502..392d9df 100644 + # Small monkey patch so gym 0.21 is compatible with numpy >= 1.24 + # TODO: remove when upgrading to gym 0.26 +diff --git stable_baselines3/common/env_checker.py stable_baselines3/common/env_checker.py +index b71454b..af79679 100644 --- stable_baselines3/common/env_checker.py +++ stable_baselines3/common/env_checker.py @@ -6,7 +6,13 @@ import numpy as np from gym import spaces - from stable_baselines3.common.preprocessing import is_image_space_channels_first + from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first -from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan + +torch_available = False @@ -49,7 +50,7 @@ index 3b2c502..392d9df 100644 def _is_numpy_array_space(space: spaces.Space) -> bool: -@@ -87,6 +93,8 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act +@@ -91,6 +97,8 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act def _check_nan(env: gym.Env) -> None: """Check for Inf and NaN using the VecWrapper.""" @@ -58,8 +59,8 @@ index 3b2c502..392d9df 100644 vec_env = VecCheckNan(DummyVecEnv([lambda: env])) for _ in range(10): action = np.array([env.action_space.sample()]) -diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py -index 01422aa..0603891 100644 +diff --git stable_baselines3/common/preprocessing.py stable_baselines3/common/preprocessing.py +index e280ed7..f648b88 100644 --- stable_baselines3/common/preprocessing.py +++ stable_baselines3/common/preprocessing.py @@ -1,10 +1,15 @@ @@ -80,7 +81,7 @@ index 01422aa..0603891 100644 def is_image_space_channels_first(observation_space: spaces.Box) -> bool: -@@ -83,10 +88,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> +@@ -90,10 +95,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) -> def preprocess_obs( @@ -93,4 +94,3 @@ index 01422aa..0603891 100644 """ Preprocess observation to be to a neural network. For images, it normalizes the values by dividing them by 255 (to have values in [0, 1]) - diff --git a/tools/workspace/stable_baselines3_internal/repository.bzl b/tools/workspace/stable_baselines3_internal/repository.bzl index 9f8cd2ecf1b1..48d7be206d86 100644 --- a/tools/workspace/stable_baselines3_internal/repository.bzl +++ b/tools/workspace/stable_baselines3_internal/repository.bzl @@ -6,11 +6,10 @@ def stable_baselines3_internal_repository( github_archive( name = name, repository = "DLR-RM/stable-baselines3", - commit = "v1.7.0", - sha256 = "f91a4a87f780b55f8c490dac177fb8474f6f18dd76ed488e78b3795f1d1c1bc4", # noqa + commit = "v1.8.0", + sha256 = "2ac876fc53546258008dbb1d249eb5b051bf9f0c8d1aae88c0c75af08c1c180d", # noqa build_file = ":package.BUILD.bazel", patches = [ - ":connection.patch", ":no_torch.patch", ], mirrors = mirrors,