From 11a28983834e734ed3763a1247a6292fd8c275bb Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 16 Jan 2024 09:56:54 +0000 Subject: [PATCH 1/3] init --- torchrl/envs/batched_envs.py | 27 ++++++++++++---------- torchrl/envs/utils.py | 45 ++++++++++++++++++++++-------------- 2 files changed, 43 insertions(+), 29 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 03262fcdd1d..4021b062409 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -474,6 +474,9 @@ def _create_td(self) -> None: # safe since the td is locked. self._cache_shared_keys = set(self.shared_tensordict_parent.keys(True, True)) + self._shared_tensordict_parent_next = self.shared_tensordict_parent.get("next") + self._shared_tensordict_parent_root = self.shared_tensordict_parent.exclude("next", *self.reset_keys) + def _start_workers(self) -> None: """Starts the various envs.""" raise NotImplementedError @@ -862,17 +865,17 @@ def step_and_maybe_reset( # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - for key in tensordict.keys(True, True): + for key in self._env_input_keys: + self.shared_tensordict_parent.set_(key, tensordict.get(key)) + next_td = tensordict.get("next", None) + if next_td is not None: # we copy the input keys as well as the keys in the 'next' td, if any # as this mechanism can be used by a policy to set anticipatively the # keys of the next call (eg, with recurrent nets) - if key in self._env_input_keys or ( - isinstance(key, tuple) - and key[0] == "next" - and key in self.shared_tensordict_parent.keys(True, True) - ): - val = tensordict.get(key) - self.shared_tensordict_parent.set_(key, val) + for key in next_td.keys(True, True): + key = unravel_key(("next", key)) + if key in self.shared_tensordict_parent.keys(True, True): + self.shared_tensordict_parent.set_(key, next_td.get(key[1:])) else: self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, "next", strict=False) @@ -887,8 +890,8 @@ def step_and_maybe_reset( # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self.shared_tensordict_parent.get("next") - tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys) + next_td = self._shared_tensordict_parent_next + tensordict_ = self._shared_tensordict_parent_root device = self.device if self.shared_tensordict_parent.device == device: next_td = next_td.clone() @@ -1201,12 +1204,12 @@ def _run_worker_pipe_shared_mem( i = 0 next_shared_tensordict = shared_tensordict.get("next") root_shared_tensordict = shared_tensordict.exclude("next") - shared_tensordict = shared_tensordict.clone(False) - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( "tensordict must be placed in shared memory (share_memory_() or memmap_())" ) + shared_tensordict = shared_tensordict.clone(False) + initialized = True elif cmd == "reset": diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 6605301ed3b..89e5f4b9634 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,6 +4,7 @@ # LICENSE file in the root directory of this source tree. 
from __future__ import annotations +import collections import contextlib import importlib.util @@ -15,7 +16,8 @@ import torch -from tensordict import is_tensor_collection, TensorDictBase, unravel_key +from tensordict import is_tensor_collection, TensorDictBase, unravel_key, \ + TensorDict from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. @@ -187,7 +189,7 @@ def step_mdp( next_tensordicts = next_tensordict.unbind(tensordict.stack_dim) else: next_tensordicts = [None] * len(tensordict.tensordicts) - out = torch.stack( + out = LazyStackedTensorDict.lazy_stack( [ step_mdp( td, @@ -218,24 +220,33 @@ def step_mdp( excluded = set() if exclude_reward: - excluded = excluded.union(reward_keys) + excluded_reward_keys = [unravel_key(reward_key) for reward_key in reward_keys] + excluded_reward_keys += [unravel_key(("next", reward_key)) for reward_key in excluded_reward_keys] + excluded = excluded.union(excluded_reward_keys) if exclude_done: - excluded = excluded.union(done_keys) + excluded_done_keys = [unravel_key(done_key) for done_key in done_keys] + excluded_done_keys += [unravel_key(("next", done_key)) for done_key in excluded_done_keys] + excluded = excluded.union(excluded_done_keys) if exclude_action: + action_keys = map(unravel_key, action_keys) excluded = excluded.union(action_keys) - next_td = tensordict.get("next") - out = next_td.empty() - - total_key = () - if keep_other: - for key in tensordict.keys(): - if key != "next": - _set(tensordict, out, key, total_key, excluded) - elif not exclude_action: - for action_key in action_keys: - _set_single_key(tensordict, out, action_key) - for key in next_td.keys(): - _set(next_td, out, key, total_key, excluded) + keys = set(tensordict.keys(True, True)) + # remove excluded keys + keys = collections.deque(keys - excluded) + # make the map + keys_map = {} + for i in range(len(keys)): + key = keys.popleft() + if isinstance(key, str) or key[0] != "next": + if keep_other or key in action_keys: + keys_map[key] = key + else: + keys.append(key) + while len(keys): + key = keys.popleft() + keys_map[unravel_key(key[1:])] = key + out = tensordict.select(*keys_map.values()) + out.update(out.pop("next")) if next_tensordict is not None: return next_tensordict.update(out) else: From 044de306cd149b95e69504f23db1f2f2d1d972ea Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 16 Jan 2024 10:43:10 +0000 Subject: [PATCH 2/3] amend --- test/test_env.py | 4 ++-- torchrl/envs/utils.py | 28 +++++++++++----------------- 2 files changed, 13 insertions(+), 19 deletions(-) diff --git a/test/test_env.py b/test/test_env.py index fc566749b8c..cfe4f04015c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1242,7 +1242,7 @@ def test_nested( obs_key = "state" if nested_obs: obs_key = nested_key + (obs_key,) - other_key = "beatles" + other_key = "other" if nested_other: other_key = nested_key + (other_key,) @@ -1310,7 +1310,7 @@ def test_nested( else: assert done_key not in td_nested_keys if keep_other: - assert other_key in td_nested_keys + assert other_key in td_nested_keys, other_key assert (td[other_key] == 0).all() else: assert other_key not in td_nested_keys diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 89e5f4b9634..f1b23e1d4cb 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -73,7 +73,6 @@ class _classproperty(property): def __get__(self, cls, 
owner): return classmethod(self.fget).__get__(None, owner)() - def step_mdp( tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, @@ -230,23 +229,18 @@ def step_mdp( if exclude_action: action_keys = map(unravel_key, action_keys) excluded = excluded.union(action_keys) - keys = set(tensordict.keys(True, True)) - # remove excluded keys - keys = collections.deque(keys - excluded) - # make the map - keys_map = {} - for i in range(len(keys)): - key = keys.popleft() - if isinstance(key, str) or key[0] != "next": - if keep_other or key in action_keys: - keys_map[key] = key + if keep_other: + out = tensordict.exclude(*excluded) + for key, val in out.pop("next").items(): + out._set_str(key, val, validated=True, inplace=False) + else: + if exclude_action: + out = tensordict.exclude(*excluded).get("next") else: - keys.append(key) - while len(keys): - key = keys.popleft() - keys_map[unravel_key(key[1:])] = key - out = tensordict.select(*keys_map.values()) - out.update(out.pop("next")) + out = tensordict.select(*action_keys, "next").exclude(*excluded) + for key, val in out.pop("next").items(): + out._set_str(key, val, validated=True, inplace=False) + if next_tensordict is not None: return next_tensordict.update(out) else: From 6ba3c85e49253eac1c5aff092fa44779afef59e7 Mon Sep 17 00:00:00 2001 From: vmoens Date: Tue, 16 Jan 2024 11:08:46 +0000 Subject: [PATCH 3/3] amend --- torchrl/envs/batched_envs.py | 4 +++- torchrl/envs/utils.py | 37 ++++++++++++++++-------------------- 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 4021b062409..dd96e2a7a5c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -475,7 +475,9 @@ def _create_td(self) -> None: self._cache_shared_keys = set(self.shared_tensordict_parent.keys(True, True)) self._shared_tensordict_parent_next = self.shared_tensordict_parent.get("next") - self._shared_tensordict_parent_root = self.shared_tensordict_parent.exclude("next", *self.reset_keys) + self._shared_tensordict_parent_root = self.shared_tensordict_parent.exclude( + "next", *self.reset_keys + ) def _start_workers(self) -> None: """Starts the various envs.""" diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index f1b23e1d4cb..f505def52af 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -4,7 +4,6 @@ # LICENSE file in the root directory of this source tree. from __future__ import annotations -import collections import contextlib import importlib.util @@ -16,8 +15,7 @@ import torch -from tensordict import is_tensor_collection, TensorDictBase, unravel_key, \ - TensorDict +from tensordict import is_tensor_collection, TensorDictBase, unravel_key from tensordict.nn.probabilistic import ( # noqa # Note: the `set_interaction_mode` and their associated arg `default_interaction_mode` are being deprecated! # Please use the `set_/interaction_type` ones above with the InteractionType enum instead. 
@@ -73,6 +71,7 @@ class _classproperty(property): def __get__(self, cls, owner): return classmethod(self.fget).__get__(None, owner)() + def step_mdp( tensordict: TensorDictBase, next_tensordict: TensorDictBase = None, @@ -219,28 +218,24 @@ def step_mdp( excluded = set() if exclude_reward: - excluded_reward_keys = [unravel_key(reward_key) for reward_key in reward_keys] - excluded_reward_keys += [unravel_key(("next", reward_key)) for reward_key in excluded_reward_keys] - excluded = excluded.union(excluded_reward_keys) + excluded = excluded.union(reward_keys) if exclude_done: - excluded_done_keys = [unravel_key(done_key) for done_key in done_keys] - excluded_done_keys += [unravel_key(("next", done_key)) for done_key in excluded_done_keys] - excluded = excluded.union(excluded_done_keys) + excluded = excluded.union(done_keys) if exclude_action: - action_keys = map(unravel_key, action_keys) excluded = excluded.union(action_keys) - if keep_other: - out = tensordict.exclude(*excluded) - for key, val in out.pop("next").items(): - out._set_str(key, val, validated=True, inplace=False) - else: - if exclude_action: - out = tensordict.exclude(*excluded).get("next") - else: - out = tensordict.select(*action_keys, "next").exclude(*excluded) - for key, val in out.pop("next").items(): - out._set_str(key, val, validated=True, inplace=False) + next_td = tensordict.get("next") + out = next_td.empty() + total_key = () + if keep_other: + for key in tensordict.keys(): + if key != "next": + _set(tensordict, out, key, total_key, excluded) + elif not exclude_action: + for action_key in action_keys: + _set_single_key(tensordict, out, action_key) + for key in next_td.keys(): + _set(next_td, out, key, total_key, excluded) if next_tensordict is not None: return next_tensordict.update(out) else:
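
Notes on the series (illustrative sketches, not part of the patches themselves):

The `step_and_maybe_reset` change in patch 1 replaces the scan over every key of
the incoming tensordict with two targeted copies: first the declared input keys,
then any pre-filled ("next", ...) entries that already exist in the shared
buffer. A minimal sketch of that copy path with plain tensordict objects follows;
`shared`, `incoming` and `env_input_keys` are stand-in names for
`self.shared_tensordict_parent`, the caller's tensordict and
`self._env_input_keys`, and the example only assumes `torch` and `tensordict`
are installed.

import torch
from tensordict import TensorDict, unravel_key

shared = TensorDict(
    {"action": torch.zeros(4, 1), "next": {"observation": torch.zeros(4, 3)}},
    batch_size=[4],
)
env_input_keys = {"action"}

incoming = TensorDict(
    {"action": torch.ones(4, 1), "next": {"observation": torch.ones(4, 3)}},
    batch_size=[4],
)

# 1) copy the declared input keys in-place into the shared buffer
for key in env_input_keys:
    shared.set_(key, incoming.get(key))

# 2) copy any pre-filled ("next", ...) entries, e.g. recurrent states written by
#    a policy ahead of the next step; keys absent from the buffer are skipped
next_td = incoming.get("next", None)
if next_td is not None:
    for key in next_td.keys(True, True):
        key = unravel_key(("next", key))
        if key in shared.keys(True, True):
            shared.set_(key, next_td.get(key[1:]))

assert (shared["next", "observation"] == 1).all()

The three `step_mdp` variants touched by the series (the deque-based key map in
patch 1, the exclude/select version in patch 2, and the `_set`/`_set_single_key`
form restored in patch 3) are meant to keep the same observable behaviour:
entries under "next" are promoted to the root, reward and action keys are
dropped according to the flags, and other root keys are kept when
`keep_other=True`. A quick sketch of that contract, assuming a torchrl revision
where `step_mdp` exposes these keyword flags (defaults for the key names may
differ across versions):

import torch
from tensordict import TensorDict
from torchrl.envs.utils import step_mdp

td = TensorDict(
    {
        "observation": torch.zeros(3),
        "action": torch.ones(1),
        "other": torch.ones(2),
        "next": {
            "observation": torch.full((3,), 2.0),
            "reward": torch.ones(1),
            "done": torch.zeros(1, dtype=torch.bool),
        },
    },
    batch_size=[],
)

out = step_mdp(
    td,
    keep_other=True,
    exclude_reward=True,
    exclude_done=False,
    exclude_action=True,
)
assert (out["observation"] == 2).all()  # promoted from ("next", "observation")
assert "other" in out.keys()            # kept because keep_other=True
assert "action" not in out.keys()       # excluded by exclude_action
assert "reward" not in out.keys()       # excluded by exclude_reward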