Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Jan 22, 2024
1 parent 3f04131 commit f43262a
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions torchrl/envs/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -5151,6 +5151,8 @@ def _reset(
step_count = tensordict.get(step_count_key, default=None)
if step_count is None:
step_count = self.container.observation_spec[step_count_key].zero()
if step_count.device != reset.device:
step_count = step_count.to(reset.device, non_blocking=True)

# zero the step count if reset is needed
step_count = torch.where(~expand_as_right(reset, step_count), step_count, 0)
Expand Down Expand Up @@ -6413,7 +6415,7 @@ def _call(self, tensordict: TensorDictBase) -> TensorDictBase:
raise ValueError(
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
)
action_spec.update_mask(mask)
action_spec.update_mask(mask.to(action_spec.device))
return tensordict

def _reset(
Expand All @@ -6424,7 +6426,10 @@ def _reset(
raise ValueError(
self.SPEC_TYPE_ERROR.format(self.ACCEPTED_SPECS, type(action_spec))
)
action_spec.update_mask(tensordict.get(self.in_keys[1], None))
mask = tensordict.get(self.in_keys[1], None)
if mask is not None:
mask = mask.to(action_spec.device)
action_spec.update_mask(mask)

# TODO: Check that this makes sense
with _set_missing_tolerance(self, True):
Expand Down

0 comments on commit f43262a

Please sign in to comment.