Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Versioning] Gymnasium 1.0 incompatibility errors #2484

Merged
merged 2 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion test/_utils_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def _set_gym_environments(): # noqa: F811
_BREAKOUT_VERSIONED = "ALE/Breakout-v5"


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _set_gym_environments(): # noqa: F811
global _CARTPOLE_VERSIONED, _HALFCHEETAH_VERSIONED, _PENDULUM_VERSIONED, _PONG_VERSIONED, _BREAKOUT_VERSIONED

Expand Down
16 changes: 8 additions & 8 deletions test/test_libs.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def _make_spec( # noqa: F811
shape=batch_size,
)

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _make_spec( # noqa: F811
self, batch_size, cat, cat_shape, multicat, multicat_shape
):
Expand Down Expand Up @@ -322,7 +322,7 @@ def test_gym_spec_cast_tuple_sequential(self, order):

# @pytest.mark.parametrize("order", ["seq_tuple", "tuple_seq"])
@pytest.mark.parametrize("order", ["tuple_seq"])
@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def test_gym_spec_cast_tuple_sequential(self, order): # noqa: F811
with set_gym_backend("gymnasium"):
if order == "seq_tuple":
Expand Down Expand Up @@ -838,7 +838,7 @@ def info_reader(info, tensordict):
finally:
set_gym_backend(gb).set()

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def test_one_hot_and_categorical(self):
# tests that one-hot and categorical work ok when an integer is expected as action
cliff_walking = GymEnv("CliffWalking-v0", categorical_action_encoding=True)
Expand All @@ -857,7 +857,7 @@ def test_one_hot_and_categorical(self): # noqa: F811
# versions.
return

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize(
"envname",
["HalfCheetah-v4", "CartPole-v1", "ALE/Pong-v5"]
Expand All @@ -883,7 +883,7 @@ def test_vecenvs_wrapper(self, envname):
assert env.batch_size == torch.Size([2])
check_env_specs(env)

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
# this env has Dict-based observation which is a nice thing to test
@pytest.mark.parametrize(
"envname",
Expand Down Expand Up @@ -1045,7 +1045,7 @@ def test_gym_output_num(self, wrapper): # noqa: F811
finally:
set_gym_backend(gym).set()

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
@pytest.mark.parametrize("wrapper", [True, False])
def test_gym_output_num(self, wrapper): # noqa: F811
# gym has 5 outputs, with truncation
Expand Down Expand Up @@ -1148,7 +1148,7 @@ def test_vecenvs_nan(self): # noqa: F811
del c
return

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def test_vecenvs_nan(self): # noqa: F811
# new versions of gym must never return nan for next values when there is a done state
torch.manual_seed(0)
Expand Down Expand Up @@ -1319,7 +1319,7 @@ def _make_gym_environment(env_name): # noqa: F811
return gym.make(env_name, render_mode="rgb_array")


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _make_gym_environment(env_name): # noqa: F811
gym = gym_backend()
return gym.make(env_name, render_mode="rgb_array")
Expand Down
2 changes: 1 addition & 1 deletion torchrl/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ class implement_for:
... # More recent gym versions will return x + 2
... return x + 2
...
>>> @implement_for("gymnasium")
>>> @implement_for("gymnasium", None, "1.0.0")
>>> def fun(self, x):
... # If gymnasium is to be used instead of gym, x+3 will be returned
... return x + 3
Expand Down
14 changes: 11 additions & 3 deletions torchrl/envs/libs/_gym_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torchrl._utils import implement_for
from torchrl.data import Composite
from torchrl.envs import step_mdp, TransformedEnv
from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform
from torchrl.envs.libs.gym import _torchrl_to_gym_spec_transform, GYMNASIUM_1_ERROR

_has_gym = importlib.util.find_spec("gym", None) is not None
_has_gymnasium = importlib.util.find_spec("gymnasium", None) is not None
Expand Down Expand Up @@ -125,7 +125,11 @@ def _action_keys(self):
import gymnasium

class _TorchRLGymnasiumWrapper(gymnasium.Env, _BaseGymWrapper):
@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def step(self, action): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def step(self, action): # noqa: F811
action_keys = self._action_keys
if len(action_keys) == 1:
Expand Down Expand Up @@ -153,7 +157,7 @@ def step(self, action): # noqa: F811
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def reset(self): # noqa: F811
self._tensordict = self.torchrl_env.reset()
observation = self._tensordict
Expand All @@ -167,6 +171,10 @@ def reset(self): # noqa: F811
out = tree_map(lambda x: x.detach().cpu().numpy(), out)
return out

@implement_for("gymnasium", "1.0.0")
def reset(self): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

else:

class _TorchRLGymnasiumWrapper:
Expand Down
70 changes: 62 additions & 8 deletions torchrl/envs/libs/gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,22 @@
_has_minigrid = importlib.util.find_spec("minigrid") is not None


GYMNASIUM_1_ERROR = """RuntimeError: TorchRL does not support gymnasium 1.0 or later versions due to incompatible
changes in the Gym API.
Using gymnasium 1.0 with TorchRL would require significant modifications to your code and may result in:
* Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Inaccurate step counting, as the auto-reset feature can cause unpredictable numbers of steps to be executed.
* Inaccurate step counting, the number of step calls to obtain the same amount of data using the autoreset feature will depend on the done frequency of the enviornment.

* Potential data corruption, as the environment may require/produce garbage data during reset steps.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Potential data corruption, as the environment may require/produce garbage data during reset steps.
* Potential data corruption, as the environment may require/produce invalid data during reset steps.

* Trajectory overlap during data collection.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe explain this one too

* Increased computational overhead, as the library would need to handle the additional complexity of auto-resets.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* Increased computational overhead, as the library would need to handle the additional complexity of auto-resets.
* Increased computational overhead, as auto-resets introduce additional computational complexity.

* Manual filtering and boilerplate code to mitigate these issues, which would compromise the modularity and ease of
use of TorchRL.
To maintain the integrity and efficiency of our library, we cannot support this version of gymnasium at this time.
If you need to use gymnasium 1.0 or later, we recommend exploring alternative solutions or waiting for future updates
to TorchRL and gymnasium that may address this compatibility issue.
For more information, please refer to discussion https://github.com/pytorch/rl/discussions/2483 in torchrl.
"""


def _minigrid_lib():
assert _has_minigrid, "minigrid not found"
import minigrid
Expand Down Expand Up @@ -400,27 +416,37 @@ def _box_convert(spec, gym_spaces, shape): # noqa: F811
return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _box_convert(spec, gym_spaces, shape): # noqa: F811
low = spec.low.detach().cpu().numpy()
high = spec.high.detach().cpu().numpy()
return gym_spaces.Box(low=low, high=high, shape=shape)


@implement_for("gymnasium", "1.0.0")
def _box_convert(spec, gym_spaces, shape): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gym", "0.21", None)
def _multidiscrete_convert(gym_spaces, spec):
return gym_spaces.multi_discrete.MultiDiscrete(
spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
)


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
return gym_spaces.multi_discrete.MultiDiscrete(
spec.nvec, dtype=torch_to_numpy_dtype_dict[spec.dtype]
)


@implement_for("gymnasium", "1.0.0")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)


@implement_for("gym", None, "0.21")
def _multidiscrete_convert(gym_spaces, spec): # noqa: F811
return gym_spaces.multi_discrete.MultiDiscrete(spec.nvec)
Expand Down Expand Up @@ -519,12 +545,17 @@ def _get_gym_envs(): # noqa: F811
return gym.envs.registration.registry.keys()


@implement_for("gymnasium")
@implement_for("gymnasium", None, "1.0.0")
def _get_gym_envs(): # noqa: F811
gym = gym_backend()
return gym.envs.registration.registry.keys()


@implement_for("gymnasium", "1.0.0")
def _get_gym_envs(): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)


def _is_from_pixels(env):
observation_spec = env.observation_space
try:
Expand Down Expand Up @@ -835,7 +866,7 @@ def _get_batch_size(self, env):
batch_size = self.batch_size
return batch_size

@implement_for("gymnasium") # gymnasium wants the unwrapped env
@implement_for("gymnasium", None, "1.0.0") # gymnasium wants the unwrapped env
def _get_batch_size(self, env): # noqa: F811
env_unwrapped = env.unwrapped
if hasattr(env_unwrapped, "num_envs"):
Expand All @@ -844,6 +875,10 @@ def _get_batch_size(self, env): # noqa: F811
batch_size = self.batch_size
return batch_size

@implement_for("gymnasium", "1.0.0")
def _get_batch_size(self, env): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

def _check_kwargs(self, kwargs: Dict):
if "env" not in kwargs:
raise TypeError("Could not find environment key 'env' in kwargs.")
Expand Down Expand Up @@ -920,7 +955,11 @@ def _build_gym_env(self, env, pixels_only): # noqa: F811

return LegacyPixelObservationWrapper(env, pixels_only=pixels_only)

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _build_gym_env(self, env, pixels_only): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _build_gym_env(self, env, pixels_only): # noqa: F811
compatibility = gym_backend("wrappers.compatibility")
pixel_observation = gym_backend("wrappers.pixel_observation")
Expand Down Expand Up @@ -985,7 +1024,11 @@ def _set_seed_initial(self, seed: int) -> None: # noqa: F811
except AttributeError as err2:
raise err from err2

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _set_seed_initial(self, seed: int) -> None: # noqa: F811
try:
self.reset(seed=seed)
Expand All @@ -1003,7 +1046,11 @@ def _reward_space(self, env):
if hasattr(env, "reward_space") and env.reward_space is not None:
return env.reward_space

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _reward_space(self, env): # noqa: F811
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _reward_space(self, env): # noqa: F811
env = env.unwrapped
if hasattr(env, "reward_space") and env.reward_space is not None:
Expand Down Expand Up @@ -1397,7 +1444,14 @@ def _set_gym_args( # noqa: F811
) -> None:
kwargs.setdefault("disable_env_checker", True)

@implement_for("gymnasium")
@implement_for("gymnasium", "1.0.0")
def _set_gym_args( # noqa: F811
self,
kwargs,
) -> None:
raise ImportError(GYMNASIUM_1_ERROR)

@implement_for("gymnasium", None, "1.0.0")
def _set_gym_args( # noqa: F811
self,
kwargs,
Expand Down
Loading