[Feature] Improve info_dict reader (#1809)
vmoens authored Jan 18, 2024
1 parent 1bd28da commit eb61100
Showing 4 changed files with 335 additions and 62 deletions.
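In short, this commit lets `default_info_dict_reader` accept several spec formats (a dict of specs, a `CompositeSpec`, a list of specs, or `None`) and adds `GymLikeEnv.auto_register_info_dict()`, which registers the info-dict spec in the env specs so that the env passes `check_env_specs` and can be shared across processes in a `ParallelEnv`. A minimal usage sketch, assuming a gym backend with `HalfCheetah-v4` installed (the details are in the diffs below):

# Sketch of the workflow this commit enables; assumes gymnasium's
# HalfCheetah-v4 is available. Mirrors the tests and docstring below.
import gymnasium as gym
from torchrl.envs import ParallelEnv, check_env_specs
from torchrl.envs.libs.gym import GymWrapper

env = GymWrapper(gym.make("HalfCheetah-v4"))
# Register the info-dict entries (e.g. "x_position") in the env specs so
# that reset-time and step-time data match the specs.
env = env.auto_register_info_dict()
check_env_specs(env)  # should pass, so the env is safe to use in ParallelEnv

penv = ParallelEnv(
    2,
    lambda: GymWrapper(gym.make("HalfCheetah-v4")).auto_register_info_dict(),
)
rollout = penv.rollout(10)  # info keys such as "x_position" appear in the rollout
penv.close()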
117 changes: 88 additions & 29 deletions test/test_env.py
@@ -58,7 +58,6 @@
 from torchrl.data.tensor_specs import (
     CompositeSpec,
     DiscreteTensorSpec,
-    OneHotDiscreteTensorSpec,
     UnboundedContinuousTensorSpec,
 )
 from torchrl.envs import CatTensors, DoubleToFloat, EnvCreator, ParallelEnv, SerialEnv
@@ -1608,42 +1607,102 @@ def test_batch_unlocked_with_batch_size(device):
     env.step(td_expanded)


-@pytest.mark.skipif(not _has_gym, reason="no gym")
-@pytest.mark.skipif(
-    gym_version is None or gym_version < version.parse("0.20.0"),
-    reason="older versions of half-cheetah do not have 'x_position' info key.",
-)
-@pytest.mark.parametrize("device", get_default_devices())
-def test_info_dict_reader(device, seed=0):
-    try:
-        import gymnasium as gym
-    except ModuleNotFoundError:
-        import gym
-
-    env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device)
-    env.set_info_dict_reader(default_info_dict_reader(["x_position"]))
-
-    assert "x_position" in env.observation_spec.keys()
-    assert isinstance(env.observation_spec["x_position"], UnboundedContinuousTensorSpec)
-
-    tensordict = env.reset()
-    tensordict = env.rand_step(tensordict)
-
-    assert env.observation_spec["x_position"].is_in(tensordict[("next", "x_position")])
-
-    env2 = GymWrapper(gym.make("HalfCheetah-v4"))
-    env2.set_info_dict_reader(
-        default_info_dict_reader(
-            ["x_position"], spec={"x_position": OneHotDiscreteTensorSpec(5)}
-        )
-    )
-
-    tensordict2 = env2.reset()
-    tensordict2 = env2.rand_step(tensordict2)
-
-    assert not env2.observation_spec["x_position"].is_in(
-        tensordict2[("next", "x_position")]
-    )
+class TestInfoDict:
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.skipif(
+        gym_version is None or gym_version < version.parse("0.20.0"),
+        reason="older versions of half-cheetah do not have 'x_position' info key.",
+    )
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_info_dict_reader(self, device, seed=0):
+        try:
+            import gymnasium as gym
+        except ModuleNotFoundError:
+            import gym
+
+        env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device)
+        env.set_info_dict_reader(default_info_dict_reader(["x_position"]))
+
+        assert "x_position" in env.observation_spec.keys()
+        assert isinstance(
+            env.observation_spec["x_position"], UnboundedContinuousTensorSpec
+        )
+
+        tensordict = env.reset()
+        tensordict = env.rand_step(tensordict)
+
+        assert env.observation_spec["x_position"].is_in(
+            tensordict[("next", "x_position")]
+        )
+
+        for spec in (
+            {"x_position": UnboundedContinuousTensorSpec(10)},
+            None,
+            CompositeSpec(x_position=UnboundedContinuousTensorSpec(10), shape=[]),
+            [UnboundedContinuousTensorSpec(10)],
+        ):
+            env2 = GymWrapper(gym.make("HalfCheetah-v4"))
+            env2.set_info_dict_reader(
+                default_info_dict_reader(["x_position"], spec=spec)
+            )
+
+            tensordict2 = env2.reset()
+            tensordict2 = env2.rand_step(tensordict2)
+
+            assert env2.observation_spec["x_position"].is_in(
+                tensordict2[("next", "x_position")]
+            )
+
+    @pytest.mark.skipif(not _has_gym, reason="no gym")
+    @pytest.mark.skipif(
+        gym_version is None or gym_version < version.parse("0.20.0"),
+        reason="older versions of half-cheetah do not have 'x_position' info key.",
+    )
+    @pytest.mark.parametrize("device", get_default_devices())
+    def test_auto_register(self, device):
+        try:
+            import gymnasium as gym
+        except ModuleNotFoundError:
+            import gym
+
+        env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device)
+        check_env_specs(env)
+        env.set_info_dict_reader()
+        with pytest.raises(
+            AssertionError, match="The keys of the specs and data do not match"
+        ):
+            check_env_specs(env)
+
+        env = GymWrapper(gym.make(HALFCHEETAH_VERSIONED), device=device)
+        env = env.auto_register_info_dict()
+        check_env_specs(env)
+
+        # check that the env can be executed in parallel
+        penv = ParallelEnv(
+            2,
+            lambda: GymWrapper(
+                gym.make(HALFCHEETAH_VERSIONED), device=device
+            ).auto_register_info_dict(),
+        )
+        senv = SerialEnv(
+            2,
+            lambda: GymWrapper(
+                gym.make(HALFCHEETAH_VERSIONED), device=device
+            ).auto_register_info_dict(),
+        )
+        try:
+            torch.manual_seed(0)
+            penv.set_seed(0)
+            rolp = penv.rollout(10)
+            torch.manual_seed(0)
+            senv.set_seed(0)
+            rols = senv.rollout(10)
+            assert_allclose_td(rolp, rols)
+        finally:
+            penv.close()
+            del penv
+            senv.close()
+            del senv


 def test_make_spec_from_td():
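For reference, the `for spec in (...)` loop in `test_info_dict_reader` above exercises the spec formats that `default_info_dict_reader` now accepts. A condensed sketch under the same assumptions as the test (the shape `(10)` is only illustrative):

# The accepted spec formats for a single "x_position" info key,
# condensed from the test above.
from torchrl.data.tensor_specs import CompositeSpec, UnboundedContinuousTensorSpec
from torchrl.envs.gym_like import default_info_dict_reader

reader = default_info_dict_reader(
    ["x_position"], spec={"x_position": UnboundedContinuousTensorSpec(10)}
)  # a dict mapping info keys to specs
reader = default_info_dict_reader(["x_position"], spec=None)  # default behavior
reader = default_info_dict_reader(
    ["x_position"],
    spec=CompositeSpec(x_position=UnboundedContinuousTensorSpec(10), shape=[]),
)  # a CompositeSpec
reader = default_info_dict_reader(
    ["x_position"], spec=[UnboundedContinuousTensorSpec(10)]
)  # a list of specs, matched to keys by position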
106 changes: 105 additions & 1 deletion torchrl/envs/batched_envs.py
@@ -744,7 +744,111 @@ class ParallelEnv(_BatchedEnv):
 __doc__ += _BatchedEnv.__doc__
 __doc__ += """

-    .. note::
+    .. warning::
+        TorchRL's ParallelEnv is quite stringent when it comes to env specs, since
+        these are used to build shared memory buffers for inter-process communication.
+        As such, we encourage users to first run a check of the env specs with
+        :func:`~torchrl.envs.utils.check_env_specs`:
+
+            >>> from torchrl.envs import check_env_specs
+            >>> env = make_env()
+            >>> check_env_specs(env)  # if this passes without error you're good to go!
+            >>> penv = ParallelEnv(2, make_env)
+
+        In particular, gym-like envs with info-dict readers may be difficult to
+        share across processes if the spec is not properly set, which is hard to
+        do automatically. Check :meth:`~torchrl.envs.GymLikeEnv.set_info_dict_reader`
+        for more information. Here is a short example:
+
+            >>> from torchrl.envs import GymEnv, set_gym_backend, check_env_specs, TransformedEnv, TensorDictPrimer
+            >>> import torch
+            >>> env = GymEnv("HalfCheetah-v4")
+            >>> env.rollout(10)  # no info registered: this env passes check_env_specs
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                            reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                        batch_size=torch.Size([10]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                    terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([10]),
+                device=cpu,
+                is_shared=False)
+            >>> check_env_specs(env)  # succeeds!
+            >>> env.set_info_dict_reader()  # sets the default info_dict reader
+            >>> env.rollout(10)  # because the info_dict is empty at reset time, we're missing the root infos!
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                            reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                            reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                            terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                            x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
+                        batch_size=torch.Size([10]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                    terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
+                batch_size=torch.Size([10]),
+                device=cpu,
+                is_shared=False)
+            >>> check_env_specs(env)  # This check now fails! We should not use an env constructed like this in a parallel env.
+            >>> # This ad-hoc fix registers the info-spec for reset. It is wrapped inside `env.auto_register_info_dict()`
+            >>> env_fixed = TransformedEnv(env, TensorDictPrimer(env.info_dict_reader[0].info_spec))
+            >>> env_fixed.rollout(10)
+            TensorDict(
+                fields={
+                    action: Tensor(shape=torch.Size([10, 6]), device=cpu, dtype=torch.float32, is_shared=False),
+                    done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    next: TensorDict(
+                        fields={
+                            done: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                            reward: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.float32, is_shared=False),
+                            reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                            reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                            terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                            x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                            x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
+                        batch_size=torch.Size([10]),
+                        device=cpu,
+                        is_shared=False),
+                    observation: Tensor(shape=torch.Size([10, 17]), device=cpu, dtype=torch.float64, is_shared=False),
+                    reward_ctrl: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                    reward_run: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                    terminated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    truncated: Tensor(shape=torch.Size([10, 1]), device=cpu, dtype=torch.bool, is_shared=False),
+                    x_position: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False),
+                    x_velocity: Tensor(shape=torch.Size([10]), device=cpu, dtype=torch.float64, is_shared=False)},
+                batch_size=torch.Size([10]),
+                device=cpu,
+                is_shared=False)
+            >>> check_env_specs(env_fixed)  # succeeds! This env can be used within a parallel env!
+
+        Related classes and methods: :meth:`~torchrl.envs.GymLikeEnv.auto_register_info_dict`
+        and :class:`~torchrl.envs.gym_like.default_info_dict_reader`.
+
+    .. warning::
       The choice of the devices where ParallelEnv needs to be executed can
       drastically influence its performance. The rule of thumbs is:
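As the docstring comment above notes, `auto_register_info_dict()` wraps the ad-hoc `TensorDictPrimer` fix into a single call. A sketch of that equivalence, grounded only in what the docstring shows:

# Manual route from the docstring: register the default reader, then prime
# reset-time tensordicts with the reader's info spec.
from torchrl.envs import GymEnv, TransformedEnv, TensorDictPrimer

env = GymEnv("HalfCheetah-v4")
env.set_info_dict_reader()  # default info_dict reader
env_fixed = TransformedEnv(env, TensorDictPrimer(env.info_dict_reader[0].info_spec))
# env_fixed is roughly what GymEnv("HalfCheetah-v4").auto_register_info_dict()
# returns in one call.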
