Skip to content

Commit

Permalink
minor
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Nov 22, 2023
1 parent 2ea264b commit 51ac5ec
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,6 @@ def to_module(self, module, return_swap: bool = True, swap_dest=None, memo=None)
swap = swap.to(device=local_out.device)

if return_swap:
assert local_out is not None, key
swap._set_str(key, local_out, inplace=False, validated=True)
return swap

Expand Down Expand Up @@ -1242,12 +1241,13 @@ def _set_str(
inplace: bool,
validated: bool,
) -> T:
best_attempt = inplace is BEST_ATTEMPT_INPLACE
inplace = self._convert_inplace(inplace, key)
if inplace is not False:
best_attempt = inplace is BEST_ATTEMPT_INPLACE
inplace = self._convert_inplace(inplace, key)
if not validated:
value = self._validate_value(value, check_shape=True)
if not inplace:
if self.is_locked:
if self._is_locked:
raise RuntimeError(_LOCK_ERROR)
self._tensordict[key] = value
else:
Expand Down Expand Up @@ -1703,14 +1703,13 @@ def contiguous(self) -> T:
def empty(self, recurse=False) -> T:
if not recurse:
return TensorDict(
device=self.device,
batch_size=self.batch_size,
device=self._device,
batch_size=self._batch_size,
source={},
# names=self.names if self._has_names() else None,
names=self._td_dim_names,
_run_checks=False,
_is_memmap=self._is_memmap,
_is_shared=self._is_shared,
_is_memmap=False,
_is_shared=False,
)
return super().empty(recurse=recurse)

Expand Down

0 comments on commit 51ac5ec

Please sign in to comment.