Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[workspace] Upgrade stable_baselines3_internal to latest release v1.8.0 #19187

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions tools/workspace/stable_baselines3_internal/connection.patch

This file was deleted.

28 changes: 14 additions & 14 deletions tools/workspace/stable_baselines3_internal/no_torch.patch
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
diff --git a/stable_baselines3/__init__.py b/stable_baselines3/__init__.py
index d73f5f0..5797dba 100644
diff --git stable_baselines3/__init__.py stable_baselines3/__init__.py
index 680e254..633ae10 100644
--- stable_baselines3/__init__.py
+++ stable_baselines3/__init__.py
@@ -1,13 +1,19 @@
import os
@@ -2,14 +2,20 @@ import os

import numpy as np

-from stable_baselines3.a2c import A2C
-from stable_baselines3.common.utils import get_system_info
Expand All @@ -28,16 +29,16 @@ index d73f5f0..5797dba 100644
+SAC = make_not_loaded("SAC")
+TD3 = make_not_loaded("TD3")

# Read version from file
version_file = os.path.join(os.path.dirname(__file__), "version.txt")
diff --git a/stable_baselines3/common/env_checker.py b/stable_baselines3/common/env_checker.py
index 3b2c502..392d9df 100644
# Small monkey patch so gym 0.21 is compatible with numpy >= 1.24
# TODO: remove when upgrading to gym 0.26
diff --git stable_baselines3/common/env_checker.py stable_baselines3/common/env_checker.py
index b71454b..af79679 100644
--- stable_baselines3/common/env_checker.py
+++ stable_baselines3/common/env_checker.py
@@ -6,7 +6,13 @@ import numpy as np
from gym import spaces

from stable_baselines3.common.preprocessing import is_image_space_channels_first
from stable_baselines3.common.preprocessing import check_for_nested_spaces, is_image_space_channels_first
-from stable_baselines3.common.vec_env import DummyVecEnv, VecCheckNan
+
+torch_available = False
Expand All @@ -49,7 +50,7 @@ index 3b2c502..392d9df 100644


def _is_numpy_array_space(space: spaces.Space) -> bool:
@@ -87,6 +93,8 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act
@@ -91,6 +97,8 @@ def _check_unsupported_spaces(env: gym.Env, observation_space: spaces.Space, act

def _check_nan(env: gym.Env) -> None:
"""Check for Inf and NaN using the VecWrapper."""
Expand All @@ -58,8 +59,8 @@ index 3b2c502..392d9df 100644
vec_env = VecCheckNan(DummyVecEnv([lambda: env]))
for _ in range(10):
action = np.array([env.action_space.sample()])
diff --git a/stable_baselines3/common/preprocessing.py b/stable_baselines3/common/preprocessing.py
index 01422aa..0603891 100644
diff --git stable_baselines3/common/preprocessing.py stable_baselines3/common/preprocessing.py
index e280ed7..f648b88 100644
--- stable_baselines3/common/preprocessing.py
+++ stable_baselines3/common/preprocessing.py
@@ -1,10 +1,15 @@
Expand All @@ -80,7 +81,7 @@ index 01422aa..0603891 100644


def is_image_space_channels_first(observation_space: spaces.Box) -> bool:
@@ -83,10 +88,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->
@@ -90,10 +95,10 @@ def maybe_transpose(observation: np.ndarray, observation_space: spaces.Space) ->


def preprocess_obs(
Expand All @@ -93,4 +94,3 @@ index 01422aa..0603891 100644
"""
Preprocess observation to be to a neural network.
For images, it normalizes the values by dividing them by 255 (to have values in [0, 1])

5 changes: 2 additions & 3 deletions tools/workspace/stable_baselines3_internal/repository.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ def stable_baselines3_internal_repository(
github_archive(
name = name,
repository = "DLR-RM/stable-baselines3",
commit = "v1.7.0",
sha256 = "f91a4a87f780b55f8c490dac177fb8474f6f18dd76ed488e78b3795f1d1c1bc4", # noqa
commit = "v1.8.0",
sha256 = "2ac876fc53546258008dbb1d249eb5b051bf9f0c8d1aae88c0c75af08c1c180d", # noqa
build_file = ":package.BUILD.bazel",
patches = [
":connection.patch",
":no_torch.patch",
],
mirrors = mirrors,
Expand Down