diff --git a/setup.py b/setup.py index fad0597cc02..d37c179600f 100644 --- a/setup.py +++ b/setup.py @@ -203,7 +203,7 @@ def _main(argv): "pygame", ], "dm_control": ["dm_control"], - "gym_continuous": ["gymnasium", "mujoco"], + "gym_continuous": ["gymnasium<1.0", "mujoco"], "rendering": ["moviepy"], "tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"], "utils": [ diff --git a/test/_utils_internal.py b/test/_utils_internal.py index 5c41b4edb99..51535afa606 100644 --- a/test/_utils_internal.py +++ b/test/_utils_internal.py @@ -121,7 +121,7 @@ def _set_gym_environments(): # noqa: F811 _BREAKOUT_VERSIONED = "ALE/Breakout-v5" -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _set_gym_environments(): # noqa: F811 global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED @@ -132,6 +132,11 @@ def _set_gym_environments(): # noqa: F811 _BREAKOUT_VERSIONED = "ALE/Breakout-v5" +@implement_for("gymnasium", "1.0.0", None) +def _set_gym_environments(): # noqa: F811 + raise ImportError + + if _has_gym: _set_gym_environments() diff --git a/test/test_libs.py b/test/test_libs.py index 6fc2979607d..363d111db46 100644 --- a/test/test_libs.py +++ b/test/test_libs.py @@ -277,7 +277,7 @@ def _make_spec( # noqa: F811 shape=batch_size, ) - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def _make_spec( # noqa: F811 self, batch_size, cat, cat_shape, multicat, multicat_shape ): @@ -322,7 +322,7 @@ def test_gym_spec_cast_tuple_sequential(self, order): # @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"]) @pytest.mark.parametrize("order", ["tuple_seq"]) - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def test_gym_spec_cast_tuple_sequential(self, order): # noqa: F811 with set_gym_backend("gymnasium"): if order == "seq_tuple": @@ -838,7 +838,7 @@ def info_reader(info, tensordict): finally: set_gym_backend(gb).set() - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def test_one_hot_and_categorical(self): # tests that one-hot and categorical work ok when an integer is expected as action cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True) @@ -857,7 +857,7 @@ def test_one_hot_and_categorical(self): # noqa: F811 # versions. return - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") @pytest.mark.parametrize( "envname", ["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"] @@ -883,7 +883,7 @@ def test_vecenvs_wrapper(self, envname): assert env.batch_size == torch.Size([2]) check_env_specs(env) - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") # this env has Dict-based observation which is a nice thing to test @pytest.mark.parametrize( "envname", @@ -1045,7 +1045,7 @@ def test_gym_output_num(self, wrapper): # noqa: F811 finally: set_gym_backend(gym).set() - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") @pytest.mark.parametrize("wrapper", [True, False]) def test_gym_output_num(self, wrapper): # noqa: F811 # gym has 5 outputs, with truncation @@ -1148,7 +1148,7 @@ def test_vecenvs_nan(self): # noqa: F811 del c return - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def test_vecenvs_nan(self): # noqa: F811 # new versions of gym must never return nan for next values when there is a done state torch.manual_seed(0) @@ -1319,7 +1319,7 @@ def _make_gym_environment(env_name): # noqa: F811 return gym.make(env_name, render_mode="rgb_array") -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _make_gym_environment(env_name): # noqa: F811 gym = gym_backend() return gym.make(env_name, render_mode="rgb_array") diff --git a/test/test_utils.py b/test/test_utils.py index c2ce2eae6b9..f94b776a31b 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -174,8 +174,8 @@ def test_implement_for_reset(): ("0.9.0", "0.1.0", "0.21.0", True), ("0.19.99", "0.19.9", "0.21.0", True), ("0.19.99", None, "0.19.0", False), - ("5.61.77", "0.21.0", None, True), - ("5.61.77", None, "0.21.0", False), + ("0.99.0", "0.21.0", None, True), + ("0.99.0", None, "0.21.0", False), ], ) def test_implement_for_check_versions( @@ -189,9 +189,9 @@ def test_implement_for_check_versions( @pytest.mark.parametrize( "gymnasium_version, expected_from_version_gymnasium, expected_to_version_gymnasium", [ - ("0.27.0", None, None), - ("0.27.2", None, None), - ("5.1.77", None, None), + ("0.27.0", None, "1.0.0"), + ("0.27.2", None, "1.0.0"), + ("1.0.1", "1.0.0", None), ], ) @pytest.mark.parametrize( @@ -199,7 +199,7 @@ def test_implement_for_check_versions( [ ("0.21.0", "0.21.0", None), ("0.22.0", "0.21.0", None), - ("5.61.77", "0.21.0", None), + ("0.99.0", "0.21.0", None), ("0.9.0", None, "0.21.0"), ("0.20.0", None, "0.21.0"), ("0.19.99", None, "0.21.0"), @@ -228,6 +228,8 @@ def test_set_gym_environments( import gymnasium # look for the right function that should be called according to gym versions (and same for gymnasium) + expected_fn_gymnasium = None + expected_fn_gym = None for impfor in implement_for._setters: if impfor.fn.__name__ == "_set_gym_environments": if (impfor.module_name, impfor.from_version, impfor.to_version) == ( @@ -242,20 +244,22 @@ def test_set_gym_environments( expected_to_version_gymnasium, ): expected_fn_gymnasium = impfor.fn + if expected_fn_gym is not None and expected_fn_gymnasium is not None: + break with set_gym_backend(gymnasium): assert ( - _utils_internal._set_gym_environments == expected_fn_gymnasium + _utils_internal._set_gym_environments is expected_fn_gymnasium ), expected_fn_gym with set_gym_backend(gym): assert ( - _utils_internal._set_gym_environments == expected_fn_gym + _utils_internal._set_gym_environments is expected_fn_gym ), expected_fn_gymnasium with set_gym_backend(gymnasium): assert ( - _utils_internal._set_gym_environments == expected_fn_gymnasium + _utils_internal._set_gym_environments is expected_fn_gymnasium ), expected_fn_gym diff --git a/torchrl/_utils.py b/torchrl/_utils.py index 0bfdc7b07ce..3af44ee0ed7 100644 --- a/torchrl/_utils.py +++ b/torchrl/_utils.py @@ -269,7 +269,7 @@ class implement_for: ... # More recent gym versions will return x + 2 ... return x + 2 ... - >>> @implement_for("gymnasium") + >>> @implement_for("gymnasium", None, "1.0.0") >>> def fun(self, x): ... # If gymnasium is to be used instead of gym, x+3 will be returned ... return x + 3 diff --git a/torchrl/envs/libs/_gym_utils.py b/torchrl/envs/libs/_gym_utils.py index 6200987c5a8..b95bfb335c6 100644 --- a/torchrl/envs/libs/_gym_utils.py +++ b/torchrl/envs/libs/_gym_utils.py @@ -14,7 +14,7 @@ from torchrl._utils import implement_for from torchrl.data import Composite from torchrl.envs import step_mdp, TransformedEnv -from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform +from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform, GYMNASIUM_1_ERROR _has_gym = importlib.util.find_spec("gym", None) is not None _has_gymnasium = importlib.util.find_spec("gymnasium", None) is not None @@ -125,7 +125,11 @@ def _action_keys(self): import gymnasium class _TorchRLGymnasiumWrapper(gymnasium.Env, _BaseGymWrapper): - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def step(self, action): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def step(self, action): # noqa: F811 action_keys = self._action_keys if len(action_keys) == 1: @@ -153,7 +157,7 @@ def step(self, action): # noqa: F811 out = tree_map(lambda x: x.detach().cpu().numpy(), out) return out - @implement_for("gymnasium") + @implement_for("gymnasium", None, "1.0.0") def reset(self): # noqa: F811 self._tensordict = self.torchrl_env.reset() observation = self._tensordict @@ -167,6 +171,10 @@ def reset(self): # noqa: F811 out = tree_map(lambda x: x.detach().cpu().numpy(), out) return out + @implement_for("gymnasium", "1.0.0") + def reset(self): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + else: class _TorchRLGymnasiumWrapper: diff --git a/torchrl/envs/libs/gym.py b/torchrl/envs/libs/gym.py index 61960d1a40d..dfe0db92230 100644 --- a/torchrl/envs/libs/gym.py +++ b/torchrl/envs/libs/gym.py @@ -59,6 +59,22 @@ _has_minigrid = importlib.util.find_spec("minigrid") is not None +GYMNASIUM_1_ERROR = """RuntimeError: TorchRL does not support gymnasium 1.0 or later versions due to incompatible +changes in the Gym API. +Using gymnasium 1.0 with TorchRL would require significant modifications to your code and may result in: +* Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed. +* Potential data corruption, as the environment may require/produce garbage data during reset steps. +* Trajectory overlap during data collection. +* Increased computational overhead, as the library would need to handle the additional complexity of auto-resets. +* Manual filtering and boilerplate code to mitigate these issues, which would compromise the modularity and ease of +use of TorchRL. +To maintain the integrity and efficiency of our library, we cannot support this version of gymnasium at this time. +If you need to use gymnasium 1.0 or later, we recommend exploring alternative solutions or waiting for future updates +to TorchRL and gymnasium that may address this compatibility issue. +For more information, please refer to discussion https://github.com/pytorch/rl/discussions/2483 in torchrl. +""" + + def _minigrid_lib(): assert _has_minigrid, "minigrid not found" import minigrid @@ -400,13 +416,18 @@ def _box_convert(spec, gym_spaces, shape): # noqa: F811 return gym_spaces.Box(low=low, high=high, shape=shape) -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _box_convert(spec, gym_spaces, shape): # noqa: F811 low = spec.low.detach().cpu().numpy() high = spec.high.detach().cpu().numpy() return gym_spaces.Box(low=low, high=high, shape=shape) +@implement_for("gymnasium", "1.0.0") +def _box_convert(spec, gym_spaces, shape): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gym", "0.21", None) def _multidiscrete_convert(gym_spaces, spec): return gym_spaces.multi_discrete.MultiDiscrete( @@ -414,13 +435,18 @@ def _multidiscrete_convert(gym_spaces, spec): ) -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _multidiscrete_convert(gym_spaces, spec): # noqa: F811 return gym_spaces.multi_discrete.MultiDiscrete( spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype] ) +@implement_for("gymnasium", "1.0.0") +def _multidiscrete_convert(gym_spaces, spec): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gym", None, "0.21") def _multidiscrete_convert(gym_spaces, spec): # noqa: F811 return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec) @@ -519,12 +545,17 @@ def _get_gym_envs(): # noqa: F811 return gym.envs.registration.registry.keys() -@implement_for("gymnasium") +@implement_for("gymnasium", None, "1.0.0") def _get_gym_envs(): # noqa: F811 gym = gym_backend() return gym.envs.registration.registry.keys() +@implement_for("gymnasium", "1.0.0") +def _get_gym_envs(): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + def _is_from_pixels(env): observation_spec = env.observation_space try: @@ -835,7 +866,7 @@ def _get_batch_size(self, env): batch_size = self.batch_size return batch_size - @implement_for("gymnasium") # gymnasium wants the unwrapped env + @implement_for("gymnasium", None, "1.0.0") # gymnasium wants the unwrapped env def _get_batch_size(self, env): # noqa: F811 env_unwrapped = env.unwrapped if hasattr(env_unwrapped, "num_envs"): @@ -844,6 +875,10 @@ def _get_batch_size(self, env): # noqa: F811 batch_size = self.batch_size return batch_size + @implement_for("gymnasium", "1.0.0") + def _get_batch_size(self, env): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + def _check_kwargs(self, kwargs: Dict): if "env" not in kwargs: raise TypeError("Could not find environment key 'env' in kwargs.") @@ -920,7 +955,11 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811 return LegacyPixelObservationWrapper(env, pixels_only=pixels_only) - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _build_gym_env(self, env, pixels_only): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _build_gym_env(self, env, pixels_only): # noqa: F811 compatibility = gym_backend("wrappers.compatibility") pixel_observation = gym_backend("wrappers.pixel_observation") @@ -985,7 +1024,11 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811 except AttributeError as err2: raise err from err2 - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _set_seed_initial(self, seed: int) -> None: # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _set_seed_initial(self, seed: int) -> None: # noqa: F811 try: self.reset(seed=seed) @@ -1003,7 +1046,11 @@ def _reward_space(self, env): if hasattr(env, "reward_space") and env.reward_space is not None: return env.reward_space - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _reward_space(self, env): # noqa: F811 + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _reward_space(self, env): # noqa: F811 env = env.unwrapped if hasattr(env, "reward_space") and env.reward_space is not None: @@ -1397,7 +1444,14 @@ def _set_gym_args( # noqa: F811 ) -> None: kwargs.setdefault("disable_env_checker", True) - @implement_for("gymnasium") + @implement_for("gymnasium", "1.0.0") + def _set_gym_args( # noqa: F811 + self, + kwargs, + ) -> None: + raise ImportError(GYMNASIUM_1_ERROR) + + @implement_for("gymnasium", None, "1.0.0") def _set_gym_args( # noqa: F811 self, kwargs,