From f43262a2e9faf562572d4b15203e4fe5897f07e1 Mon Sep 17 00:00:00 2001 From: vmoens Date: Mon, 22 Jan 2024 16:44:21 +0000 Subject: [PATCH] init --- torchrl/envs/transforms/transforms.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/torchrl/envs/transforms/transforms.py b/torchrl/envs/transforms/transforms.py index ab95cd8352b..ed8be751474 100644 --- a/torchrl/envs/transforms/transforms.py +++ b/torchrl/envs/transforms/transforms.py @@ -5151,6 +5151,8 @@ def _reset( step_count = tensordict.get(step_count_key, default=None) if step_count is None: step_count = self.container.observation_spec[step_count_key].zero() + if step_count.device != reset.device: + step_count = step_count.to(reset.device, non_blocking=True) # zero the step count if reset is needed step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0) @@ -6413,7 +6415,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase: raise ValueError( self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec)) ) - action_spec.update_mask(mask) + action_spec.update_mask(mask.to(action_spec.device)) return tensordict def _reset( @@ -6424,7 +6426,10 @@ def _reset( raise ValueError( self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec)) ) - action_spec.update_mask(tensordict.get(self.in_keys[1], None)) + mask = tensordict.get(self.in_keys[1], None) + if mask is not None: + mask = mask.to(action_spec.device) + action_spec.update_mask(mask) # TODO: Check that this makes sense with _set_missing_tolerance(self, True):