diff --git a/test/test_env.py b/test/test_env.py index fc566749b8c..cfe4f04015c 100644 --- a/test/test_env.py +++ b/test/test_env.py @@ -1242,7 +1242,7 @@ def test_nested( obs_key = "state" if nested_obs: obs_key = nested_key + (obs_key,) - other_key = "beatles" + other_key = "other" if nested_other: other_key = nested_key + (other_key,) @@ -1310,7 +1310,7 @@ def test_nested( else: assert done_key not in td_nested_keys if keep_other: - assert other_key in td_nested_keys + assert other_key in td_nested_keys, other_key assert (td[other_key] == 0).all() else: assert other_key not in td_nested_keys diff --git a/torchrl/envs/batched_envs.py b/torchrl/envs/batched_envs.py index 03262fcdd1d..dd96e2a7a5c 100644 --- a/torchrl/envs/batched_envs.py +++ b/torchrl/envs/batched_envs.py @@ -474,6 +474,11 @@ def _create_td(self) -> None: # safe since the td is locked. self._cache_shared_keys = set(self.shared_tensordict_parent.keys(True, True)) + self._shared_tensordict_parent_next = self.shared_tensordict_parent.get("next") + self._shared_tensordict_parent_root = self.shared_tensordict_parent.exclude( + "next", *self.reset_keys + ) + def _start_workers(self) -> None: """Starts the various envs.""" raise NotImplementedError @@ -862,17 +867,17 @@ def step_and_maybe_reset( # and this transform overrides an observation key (eg, CatFrames) # the shape, dtype or device may not necessarily match and writing # the value in-place will fail. - for key in tensordict.keys(True, True): + for key in self._env_input_keys: + self.shared_tensordict_parent.set_(key, tensordict.get(key)) + next_td = tensordict.get("next", None) + if next_td is not None: # we copy the input keys as well as the keys in the 'next' td, if any # as this mechanism can be used by a policy to set anticipatively the # keys of the next call (eg, with recurrent nets) - if key in self._env_input_keys or ( - isinstance(key, tuple) - and key[0] == "next" - and key in self.shared_tensordict_parent.keys(True, True) - ): - val = tensordict.get(key) - self.shared_tensordict_parent.set_(key, val) + for key in next_td.keys(True, True): + key = unravel_key(("next", key)) + if key in self.shared_tensordict_parent.keys(True, True): + self.shared_tensordict_parent.set_(key, next_td.get(key[1:])) else: self.shared_tensordict_parent.update_( tensordict.select(*self._env_input_keys, "next", strict=False) @@ -887,8 +892,8 @@ def step_and_maybe_reset( # We must pass a clone of the tensordict, as the values of this tensordict # will be modified in-place at further steps - next_td = self.shared_tensordict_parent.get("next") - tensordict_ = self.shared_tensordict_parent.exclude("next", *self.reset_keys) + next_td = self._shared_tensordict_parent_next + tensordict_ = self._shared_tensordict_parent_root device = self.device if self.shared_tensordict_parent.device == device: next_td = next_td.clone() @@ -1201,12 +1206,12 @@ def _run_worker_pipe_shared_mem( i = 0 next_shared_tensordict = shared_tensordict.get("next") root_shared_tensordict = shared_tensordict.exclude("next") - shared_tensordict = shared_tensordict.clone(False) - if not (shared_tensordict.is_shared() or shared_tensordict.is_memmap()): raise RuntimeError( "tensordict must be placed in shared memory (share_memory_() or memmap_())" ) + shared_tensordict = shared_tensordict.clone(False) + initialized = True elif cmd == "reset": diff --git a/torchrl/envs/utils.py b/torchrl/envs/utils.py index 6605301ed3b..f505def52af 100644 --- a/torchrl/envs/utils.py +++ b/torchrl/envs/utils.py @@ -187,7 +187,7 @@ def step_mdp( next_tensordicts = next_tensordict.unbind(tensordict.stack_dim) else: next_tensordicts = [None] * len(tensordict.tensordicts) - out = torch.stack( + out = LazyStackedTensorDict.lazy_stack( [ step_mdp( td,