Skip to content

Commit

Permalink
[Versioning] Gymnasium 1.0 incompatibility errors
Browse files Browse the repository at this point in the history
ghstack-source-id: 458e9762ec95b008667cce28a23268b77e421042
Pull Request resolved: #2484
  • Loading branch information
vmoens committed Oct 10, 2024
1 parent efa5745 commit e127d9a
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 31 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def _main(argv):
"pygame",
],
"dm_control": ["dm_control"],
"gym_continuous": ["gymnasium", "mujoco"],
"gym_continuous": ["gymnasium<1.0", "mujoco"],
"rendering": ["moviepy"],
"tests": ["pytest", "pyyaml", "pytest-instafail", "scipy"],
"utils": [
Expand Down
7 changes: 6 additions & 1 deletion test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _set_gym_environments(): # noqa: F811
_BREAKOUT_VERSIONED = "ALE/Breakout-v5"


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _set_gym_environments(): # noqa: F811
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED

Expand All @@ -132,6 +132,11 @@ def _set_gym_environments(): # noqa: F811
_BREAKOUT_VERSIONED = "ALE/Breakout-v5"


@implement_for("gymnasium", "1.0.0", None)
def _set_gym_environments(): # noqa: F811
raise ImportError


if _has_gym:
_set_gym_environments()

Expand Down
16 changes: 8 additions & 8 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _make_spec( # noqa: F811
shape=batch_size,
)

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _make_spec( # noqa: F811
self, batch_size, cat, cat_shape, multicat, multicat_shape
):
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_gym_spec_cast_tuple_sequential(self, order):

# @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"])
@pytest.mark.parametrize("order", ["tuple_seq"])
@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def test_gym_spec_cast_tuple_sequential(self, order): # noqa: F811
with set_gym_backend("gymnasium"):
if order == "seq_tuple":
Expand Down Expand Up @@ -838,7 +838,7 @@ def info_reader(info, tensordict):
finally:
set_gym_backend(gb).set()

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def test_one_hot_and_categorical(self):
# tests that one-hot and categorical work ok when an integer is expected as action
cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True)
Expand All @@ -857,7 +857,7 @@ def test_one_hot_and_categorical(self): # noqa: F811
# versions.
return

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize(
"envname",
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
Expand All @@ -883,7 +883,7 @@ def test_vecenvs_wrapper(self, envname):
assert env.batch_size == torch.Size([2])
check_env_specs(env)

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
# this env has Dict-based observation which is a nice thing to test
@pytest.mark.parametrize(
"envname",
Expand Down Expand Up @@ -1045,7 +1045,7 @@ def test_gym_output_num(self, wrapper): # noqa: F811
finally:
set_gym_backend(gym).set()

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize("wrapper", [True, False])
def test_gym_output_num(self, wrapper): # noqa: F811
# gym has 5 outputs, with truncation
Expand Down Expand Up @@ -1148,7 +1148,7 @@ def test_vecenvs_nan(self): # noqa: F811
del c
return

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def test_vecenvs_nan(self): # noqa: F811
# new versions of gym must never return nan for next values when there is a done state
torch.manual_seed(0)
Expand Down Expand Up @@ -1319,7 +1319,7 @@ def _make_gym_environment(env_name): # noqa: F811
return gym.make(env_name, render_mode="rgb_array")


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _make_gym_environment(env_name): # noqa: F811
gym = gym_backend()
return gym.make(env_name, render_mode="rgb_array")
Expand Down
22 changes: 13 additions & 9 deletions test/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,8 @@ def test_implement_for_reset():
("0.9.0", "0.1.0", "0.21.0", True),
("0.19.99", "0.19.9", "0.21.0", True),
("0.19.99", None, "0.19.0", False),
("5.61.77", "0.21.0", None, True),
("5.61.77", None, "0.21.0", False),
("0.99.0", "0.21.0", None, True),
("0.99.0", None, "0.21.0", False),
],
)
def test_implement_for_check_versions(
Expand All @@ -189,17 +189,17 @@ def test_implement_for_check_versions(
@pytest.mark.parametrize(
"gymnasium_version, expected_from_version_gymnasium, expected_to_version_gymnasium",
[
("0.27.0", None, None),
("0.27.2", None, None),
("5.1.77", None, None),
("0.27.0", None, "1.0.0"),
("0.27.2", None, "1.0.0"),
("1.0.1", "1.0.0", None),
],
)
@pytest.mark.parametrize(
"gym_version, expected_from_version_gym, expected_to_version_gym",
[
("0.21.0", "0.21.0", None),
("0.22.0", "0.21.0", None),
("5.61.77", "0.21.0", None),
("0.99.0", "0.21.0", None),
("0.9.0", None, "0.21.0"),
("0.20.0", None, "0.21.0"),
("0.19.99", None, "0.21.0"),
Expand Down Expand Up @@ -228,6 +228,8 @@ def test_set_gym_environments(
import gymnasium

# look for the right function that should be called according to gym versions (and same for gymnasium)
expected_fn_gymnasium = None
expected_fn_gym = None
for impfor in implement_for._setters:
if impfor.fn.__name__ == "_set_gym_environments":
if (impfor.module_name, impfor.from_version, impfor.to_version) == (
Expand All @@ -242,20 +244,22 @@ def test_set_gym_environments(
expected_to_version_gymnasium,
):
expected_fn_gymnasium = impfor.fn
if expected_fn_gym is not None and expected_fn_gymnasium is not None:
break

with set_gym_backend(gymnasium):
assert (
_utils_internal._set_gym_environments == expected_fn_gymnasium
_utils_internal._set_gym_environments is expected_fn_gymnasium
), expected_fn_gym

with set_gym_backend(gym):
assert (
_utils_internal._set_gym_environments == expected_fn_gym
_utils_internal._set_gym_environments is expected_fn_gym
), expected_fn_gymnasium

with set_gym_backend(gymnasium):
assert (
_utils_internal._set_gym_environments == expected_fn_gymnasium
_utils_internal._set_gym_environments is expected_fn_gymnasium
), expected_fn_gym


Expand Down
2 changes: 1 addition & 1 deletion torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class implement_for:
... # More recent gym versions will return x + 2
... return x + 2
...
>>> @implement_for("gymnasium")
>>> @implement_for("gymnasium", None, "1.0.0")
>>> def fun(self, x):
... # If gymnasium is to be used instead of gym, x+3 will be returned
... return x + 3
Expand Down
14 changes: 11 additions & 3 deletions torchrl/envs/libs/_gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchrl._utils import implement_for
from torchrl.data import Composite
from torchrl.envs import step_mdp, TransformedEnv
from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform
from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform, GYMNASIUM_1_ERROR

_has_gym = importlib.util.find_spec("gym", None) is not None
_has_gymnasium = importlib.util.find_spec("gymnasium", None) is not None
Expand Down Expand Up @@ -125,7 +125,11 @@ def _action_keys(self):
import gymnasium

class _TorchRLGymnasiumWrapper(gymnasium.Env, _BaseGymWrapper):
@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def step(self, action): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def step(self, action): # noqa: F811
action_keys = self._action_keys
if len(action_keys) == 1:
Expand Down Expand Up @@ -153,7 +157,7 @@ def step(self, action): # noqa: F811
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def reset(self): # noqa: F811
self._tensordict = self.torchrl_env.reset()
observation = self._tensordict
Expand All @@ -167,6 +171,10 @@ def reset(self): # noqa: F811
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gymnasium", "1.0.0")
def reset(self): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

else:

class _TorchRLGymnasiumWrapper:
Expand Down
70 changes: 62 additions & 8 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@
_has_minigrid = importlib.util.find_spec("minigrid") is not None


GYMNASIUM_1_ERROR = """RuntimeError: TorchRL does not support gymnasium 1.0 or later versions due to incompatible
changes in the Gym API.
Using gymnasium 1.0 with TorchRL would require significant modifications to your code and may result in:
* Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed.
* Potential data corruption, as the environment may require/produce garbage data during reset steps.
* Trajectory overlap during data collection.
* Increased computational overhead, as the library would need to handle the additional complexity of auto-resets.
* Manual filtering and boilerplate code to mitigate these issues, which would compromise the modularity and ease of
use of TorchRL.
To maintain the integrity and efficiency of our library, we cannot support this version of gymnasium at this time.
If you need to use gymnasium 1.0 or later, we recommend exploring alternative solutions or waiting for future updates
to TorchRL and gymnasium that may address this compatibility issue.
For more information, please refer to discussion https://github.com/pytorch/rl/discussions/2483 in torchrl.
"""


def _minigrid_lib():
assert _has_minigrid, "minigrid not found"
import minigrid
Expand Down Expand Up @@ -400,27 +416,37 @@ def _box_convert(spec, gym_spaces, shape): # noqa: F811
return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _box_convert(spec, gym_spaces, shape): # noqa: F811
low = spec.low.detach().cpu().numpy()
high = spec.high.detach().cpu().numpy()
return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium", "1.0.0")
def _box_convert(spec, gym_spaces, shape): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gym", "0.21", None)
def _multidiscrete_convert(gym_spaces, spec):
return gym_spaces.multi_discrete.MultiDiscrete(
spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
)


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
return gym_spaces.multi_discrete.MultiDiscrete(
spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
)


@implement_for("gymnasium", "1.0.0")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gym", None, "0.21")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
Expand Down Expand Up @@ -519,12 +545,17 @@ def _get_gym_envs(): # noqa: F811
return gym.envs.registration.registry.keys()


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _get_gym_envs(): # noqa: F811
gym = gym_backend()
return gym.envs.registration.registry.keys()


@implement_for("gymnasium", "1.0.0")
def _get_gym_envs(): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)


def _is_from_pixels(env):
observation_spec = env.observation_space
try:
Expand Down Expand Up @@ -835,7 +866,7 @@ def _get_batch_size(self, env):
batch_size = self.batch_size
return batch_size

@implement_for("gymnasium") # gymnasium wants the unwrapped env
@implement_for("gymnasium", None, "1.0.0") # gymnasium wants the unwrapped env
def _get_batch_size(self, env): # noqa: F811
env_unwrapped = env.unwrapped
if hasattr(env_unwrapped, "num_envs"):
Expand All @@ -844,6 +875,10 @@ def _get_batch_size(self, env): # noqa: F811
batch_size = self.batch_size
return batch_size

@implement_for("gymnasium", "1.0.0")
def _get_batch_size(self, env): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

def _check_kwargs(self, kwargs: Dict):
if "env" not in kwargs:
raise TypeError("Could not find environment key 'env' in kwargs.")
Expand Down Expand Up @@ -920,7 +955,11 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811

return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _build_gym_env(self, env, pixels_only): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _build_gym_env(self, env, pixels_only): # noqa: F811
compatibility = gym_backend("wrappers.compatibility")
pixel_observation = gym_backend("wrappers.pixel_observation")
Expand Down Expand Up @@ -985,7 +1024,11 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
except AttributeError as err2:
raise err from err2

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
try:
self.reset(seed=seed)
Expand All @@ -1003,7 +1046,11 @@ def _reward_space(self, env):
if hasattr(env, "reward_space") and env.reward_space is not None:
return env.reward_space

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _reward_space(self, env): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _reward_space(self, env): # noqa: F811
env = env.unwrapped
if hasattr(env, "reward_space") and env.reward_space is not None:
Expand Down Expand Up @@ -1397,7 +1444,14 @@ def _set_gym_args( # noqa: F811
) -> None:
kwargs.setdefault("disable_env_checker", True)

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _set_gym_args( # noqa: F811
self,
kwargs,
) -> None:
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _set_gym_args( # noqa: F811
self,
kwargs,
Expand Down

0 comments on commit e127d9a

Please sign in to comment.