Skip to content

Commit

Permalink
amend
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 19, 2024
1 parent ec3ed57 commit 95e2890
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 19 deletions.
19 changes: 14 additions & 5 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9796,7 +9796,12 @@ def __init__(self, *args, **kwargs):
other_reward=CompositeSpec(shape=self.batch_size),
shape=self.batch_size,
)
self.state_spec = CompositeSpec(state=CompositeSpec(sub=CompositeSpec(shape=self.batch_size),shape=self.batch_size),shape=self.batch_size)
self.state_spec = CompositeSpec(
state=CompositeSpec(
sub=CompositeSpec(shape=self.batch_size), shape=self.batch_size
),
shape=self.batch_size,
)

def _reset(self, tensordict):
return self.observation_spec.rand().update(self.full_done_spec.zero())
Expand Down Expand Up @@ -9830,11 +9835,15 @@ def test_parallel_trans_env_check(self):
env.close()

def test_trans_serial_env_check(self):
with pytest.raises(RuntimeError, match="The environment passed to SerialEnv has empty specs"):
with pytest.raises(
RuntimeError, match="The environment passed to SerialEnv has empty specs"
):
env = TransformedEnv(SerialEnv(2, self.DummyEnv), RemoveEmptySpecs())

def test_trans_parallel_env_check(self):
with pytest.raises(RuntimeError, match="The environment passed to ParallelEnv has empty specs"):
with pytest.raises(
RuntimeError, match="The environment passed to ParallelEnv has empty specs"
):
env = TransformedEnv(ParallelEnv(2, self.DummyEnv), RemoveEmptySpecs())

def test_transform_no_env(self):
Expand Down Expand Up @@ -9879,10 +9888,9 @@ def test_transform_rb(self, rbclass):
rb.extend(td)
td = rb.sample(1)
if "index" in td.keys():
del td['index']
del td["index"]
assert td.is_empty()


def test_transform_inverse(self):
td = TensorDict({"a": {"b": {"c": {}}}}, [])
assert not td.is_empty()
Expand All @@ -9893,6 +9901,7 @@ def test_transform_inverse(self):
td2 = env.transform.inv(TensorDict({}, []))
assert ("state", "sub") in td2.keys(True)


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
1 change: 1 addition & 0 deletions torchrl/data/tensor_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@

NO_DEFAULT = object()


def _default_dtype_and_device(
dtype: Union[None, torch.dtype],
device: Union[None, str, int, torch.device],
Expand Down
15 changes: 12 additions & 3 deletions torchrl/envs/batched_envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from tensordict.tensordict import LazyStackedTensorDict, TensorDictBase
from torch import multiprocessing as mp
from torchrl._utils import _check_for_faulty_process, _ProcessNoWarn, VERBOSE
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.data.tensor_specs import CompositeSpec
from torchrl.data.utils import CloudpickleWrapper, contains_lazy_spec, DEVICE_TYPING
from torchrl.envs.common import EnvBase
from torchrl.envs.env_creator import get_env_metadata

Expand Down Expand Up @@ -275,9 +275,18 @@ def _get_in_keys_to_exclude(self, tensordict):
def _set_properties(self):

cls = type(self)

def _check_for_empty_spec(specs: CompositeSpec):
for subspec in ("full_state_spec", "full_action_spec", "full_done_spec", "full_reward_spec", "full_observation_spec"):
for key, spec in reversed(list(specs.get(subspec, default=CompositeSpec()).items(True))):
for subspec in (
"full_state_spec",
"full_action_spec",
"full_done_spec",
"full_reward_spec",
"full_observation_spec",
):
for key, spec in reversed(
list(specs.get(subspec, default=CompositeSpec()).items(True))
):
if isinstance(spec, CompositeSpec) and spec.is_empty():
raise RuntimeError(
f"The environment passed to {cls.__name__} has empty specs in {key}. Consider using "
Expand Down
7 changes: 7 additions & 0 deletions torchrl/envs/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -1493,6 +1493,13 @@ def register_gym(
the observation keys.
This arg can be passed during a call to :func:`~gym.make` (see
example below).
.. warning::
It may be the case that using ``info_keys`` makes a spec empty
because the content has been moved to the info dictionary.
Gym does not like empty ``Dict`` in the specs, so this empty
content should be removed with :class:`~torchrl.envs.transforms.RemoveEmptySpecs`.
backend (str, optional): the backend. Can be either `"gym"` or `"gymnasium"`
or any other backend compatible with :class:`~torchrl.envs.libs.gym.set_gym_backend`.
to_numpy (bool, optional): if ``True``, the result of calls to `step` and
Expand Down
48 changes: 37 additions & 11 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6947,7 +6947,9 @@ class RemoveEmptySpecs(Transform):
is_shared=False)
check_env_specs(env)
"""

_has_empty_input = True

@staticmethod
def _sorter(key_val):
key, _ = key_val
Expand All @@ -6960,15 +6962,21 @@ def transform_output_spec(self, output_spec: CompositeSpec) -> CompositeSpec:
full_reward_spec = output_spec["full_reward_spec"]
full_observation_spec = output_spec["full_observation_spec"]
# we reverse things to make sure we delete things from the back
for key, spec in reversed(sorted(full_done_spec.items(True), key=self._sorter)):
for key, spec in sorted(
full_done_spec.items(True), key=self._sorter, reverse=True
):
if isinstance(spec, CompositeSpec) and spec.is_empty():
del full_done_spec[key]

for key, spec in reversed(sorted(full_observation_spec.items(True), key=self._sorter)):
for key, spec in sorted(
full_observation_spec.items(True), key=self._sorter, reverse=True
):
if isinstance(spec, CompositeSpec) and spec.is_empty():
del full_observation_spec[key]

for key, spec in reversed(sorted(full_reward_spec.items(True), key=self._sorter)):
for key, spec in sorted(
full_reward_spec.items(True), key=self._sorter, reverse=True
):
if isinstance(spec, CompositeSpec) and spec.is_empty():
del full_reward_spec[key]
return output_spec
Expand All @@ -6979,38 +6987,56 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
# we reverse things to make sure we delete things from the back

self._has_empty_input = False
for key, spec in reversed(sorted(full_action_spec.items(True), key=self._sorter)):
for key, spec in sorted(
full_action_spec.items(True), key=self._sorter, reverse=True
):
if isinstance(spec, CompositeSpec) and spec.is_empty():
self._has_empty_input = True
del full_action_spec[key]

for key, spec in reversed(sorted(full_state_spec.items(True), key=self._sorter)):
for key, spec in sorted(
full_state_spec.items(True), key=self._sorter, reverse=True
):
if isinstance(spec, CompositeSpec) and spec.is_empty():
self._has_empty_input = True
del full_state_spec[key]
return input_spec

def _inv_call(self, tensordict: TensorDictBase) -> TensorDictBase:
if self._has_empty_input:
input_spec = getattr(self.parent, 'input_spec', None)
input_spec = getattr(self.parent, "input_spec", None)
if input_spec is None:
return tensordict

full_action_spec = input_spec["full_action_spec"]
full_state_spec = input_spec["full_state_spec"]
# we reverse things to make sure we delete things from the back

for key, spec in reversed(sorted(full_action_spec.items(True), key=self._sorter)):
if isinstance(spec, CompositeSpec) and spec.is_empty() and key not in tensordict.keys(True):
for key, spec in sorted(
full_action_spec.items(True), key=self._sorter, reverse=True
):
if (
isinstance(spec, CompositeSpec)
and spec.is_empty()
and key not in tensordict.keys(True)
):
tensordict.create_nested(key)

for key, spec in reversed(sorted(full_state_spec.items(True), key=self._sorter)):
if isinstance(spec, CompositeSpec) and spec.is_empty() and key not in tensordict.keys(True):
for key, spec in sorted(
full_state_spec.items(True), key=self._sorter, reverse=True
):
if (
isinstance(spec, CompositeSpec)
and spec.is_empty()
and key not in tensordict.keys(True)
):
tensordict.create_nested(key)
return tensordict

def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
for key, value in reversed(sorted(tensordict.items(True), key=self._sorter)):
for key, value in sorted(
tensordict.items(True), key=self._sorter, reverse=True
):
if (
is_tensor_collection(value)
and not isinstance(value, NonTensorData)
Expand Down

0 comments on commit 95e2890

Please sign in to comment.