[BugFix,Feature] Optional non_blocking in set, to_module and update (#…
vmoens authored Mar 22, 2024
1 parent b4c91e8 commit 2dc0285
Showing 9 changed files with 467 additions and 123 deletions.
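
In short, the change threads an optional non_blocking keyword through tensordict's write paths (set, to_module, update and their in-place variants) so that cross-device copies can be issued asynchronously, mirroring torch.Tensor.copy_(src, non_blocking=True). A minimal usage sketch, assuming a CUDA machine and the update_ API this commit extends (keys and shapes are made up for illustration):

    import torch
    from tensordict import TensorDict

    # Stage a batch in pinned CPU memory so async host-to-device copies are possible.
    src = TensorDict(
        {"obs": torch.randn(32, 84), "reward": torch.randn(32, 1)},
        batch_size=[32],
    ).pin_memory()

    # Pre-allocated destination living on the GPU.
    dest = src.to("cuda")

    # With non_blocking=True the underlying copies may overlap with host work;
    # the caller is responsible for synchronizing before reading dest.
    dest.update_(src, non_blocking=True)
    torch.cuda.synchronize()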
119 changes: 98 additions & 21 deletions tensordict/_lazy.py
@@ -429,6 +429,7 @@ def _set_str(
inplace: bool,
validated: bool,
ignore_lock: bool = False,
+ non_blocking: bool = False,
) -> T:
try:
inplace = self._convert_inplace(inplace, key)
@@ -446,7 +447,12 @@ def _set_str(
values = value.unbind(self.stack_dim)
for tensordict, item in zip(self.tensordicts, values):
tensordict._set_str(
- key, item, inplace=inplace, validated=validated, ignore_lock=ignore_lock
+ key,
+ item,
+ inplace=inplace,
+ validated=validated,
+ ignore_lock=ignore_lock,
+ non_blocking=non_blocking,
)
return self

@@ -457,9 +463,16 @@ def _set_tuple(
*,
inplace: bool,
validated: bool,
+ non_blocking: bool = False,
) -> T:
if len(key) == 1:
- return self._set_str(key[0], value, inplace=inplace, validated=validated)
+ return self._set_str(
+ key[0],
+ value,
+ inplace=inplace,
+ validated=validated,
+ non_blocking=non_blocking,
+ )
# if inplace is not False: # inplace could be None
# # we don't want to end up in the situation where one tensordict has
# # inplace=True and another one inplace=False because inplace was loose.
@@ -482,7 +495,13 @@ def _set_tuple(
value = self.hook_in(value)
values = value.unbind(self.stack_dim)
for tensordict, item in zip(self.tensordicts, values):
- tensordict._set_tuple(key, item, inplace=inplace, validated=validated)
+ tensordict._set_tuple(
+ key,
+ item,
+ inplace=inplace,
+ validated=validated,
+ non_blocking=non_blocking,
+ )
return self

def _split_index(self, index):
@@ -669,7 +688,7 @@ def index_tuple_index(i, convert=False):
"num_squash": num_squash,
}

- def _set_at_str(self, key, value, index, *, validated):
+ def _set_at_str(self, key, value, index, *, validated, non_blocking: bool):
if not validated:
value = self._validate_value(value, check_shape=False)
validated = True
@@ -686,7 +705,9 @@ def _set_at_str(self, key, value, index, *, validated):
if isinteger:
# this will break if the index along the stack dim is [0] or :1 or smth
for i, _idx in converted_idx.items():
- self.tensordicts[i]._set_at_str(key, value, _idx, validated=validated)
+ self.tensordicts[i]._set_at_str(
+ key, value, _idx, validated=validated, non_blocking=non_blocking
+ )
return self
if is_nd_tensor:
unbind_dim = self.stack_dim - num_single + num_none - num_squash
@@ -700,7 +721,11 @@ def set_at_str(converted_idx):
_value = value_unbind[i]
stack_idx, idx = item
self.tensordicts[stack_idx]._set_at_str(
- key, _value, idx, validated=validated
+ key,
+ _value,
+ idx,
+ validated=validated,
+ non_blocking=non_blocking,
)

set_at_str(converted_idx)
@@ -712,7 +737,9 @@ def set_at_str(converted_idx):
converted_idx.items(),
value_unbind,
):
- self.tensordicts[i]._set_at_str(key, _value, _idx, validated=validated)
+ self.tensordicts[i]._set_at_str(
+ key, _value, _idx, validated=validated, non_blocking=non_blocking
+ )
else:
# we must split, not unbind
mask_unbind = split_index["individual_masks"]
@@ -728,16 +755,28 @@ def set_at_str(converted_idx):
):
if mask.any():
self.tensordicts[i]._set_at_str(
- key, _value, _idx, validated=validated
+ key,
+ _value,
+ _idx,
+ validated=validated,
+ non_blocking=non_blocking,
)
else:
for (i, _idx), _value in zip(converted_idx.items(), value_unbind):
self_idx = (slice(None),) * split_index["mask_loc"] + (i,)
- self[self_idx]._set_at_str(key, _value, _idx, validated=validated)
+ self[self_idx]._set_at_str(
+ key,
+ _value,
+ _idx,
+ validated=validated,
+ non_blocking=non_blocking,
+ )

- def _set_at_tuple(self, key, value, idx, *, validated):
+ def _set_at_tuple(self, key, value, idx, *, validated, non_blocking: bool):
if len(key) == 1:
- return self._set_at_str(key[0], value, idx, validated=validated)
+ return self._set_at_str(
+ key[0], value, idx, validated=validated, non_blocking=non_blocking
+ )
# get the "last" tds
tds = []
for td in self.tensordicts:
@@ -756,7 +795,7 @@ def _set_at_tuple(self, key, value, idx, *, validated):
value = self.hook_in(value)
item = td._get_str(key, NO_DEFAULT)
item[idx] = value
- td._set_str(key, item, inplace=True, validated=True)
+ td._set_str(key, item, inplace=True, validated=True, non_blocking=non_blocking)
return self

def _legacy_unsqueeze(self, dim: int) -> T:
@@ -1522,6 +1561,7 @@ def __setitem__(self, index: IndexType, value: T) -> T:
if isinstance(self, _SubTensorDict)
else False,
validated=False,
+ non_blocking=False,
)
return

@@ -2089,6 +2129,7 @@ def update(
clone: bool = False,
*,
keys_to_update: Sequence[NestedKey] | None = None,
+ non_blocking: bool = False,
**kwargs: Any,
) -> T:
# This implementation of update is compatible with exclusive keys
@@ -2120,7 +2161,11 @@ def update(
self.tensordicts, input_dict_or_td.tensordicts
):
td_dest.update(
- td_source, clone=clone, keys_to_update=keys_to_update, **kwargs
+ td_source,
+ clone=clone,
+ keys_to_update=keys_to_update,
+ non_blocking=non_blocking,
+ **kwargs,
)
return self

Expand Down Expand Up @@ -2161,6 +2206,8 @@ def update_(
self,
input_dict_or_td: dict[str, CompatibleType] | TensorDictBase,
clone: bool = False,
+ *,
+ non_blocking: bool = False,
**kwargs: Any,
) -> T:
if input_dict_or_td is self:
@@ -2179,14 +2226,16 @@ def update_(
for td_dest, td_source in zip(
self.tensordicts, input_dict_or_td.unbind(self.stack_dim)
):
- td_dest.update_(td_source, clone=clone, **kwargs)
+ td_dest.update_(td_source, clone=clone, non_blocking=non_blocking, **kwargs)
return self

def update_at_(
self,
input_dict_or_td: dict[str, CompatibleType] | TensorDictBase,
index: IndexType,
clone: bool = False,
+ *,
+ non_blocking: bool = False,
) -> T:
if not _is_tensor_collection(type(input_dict_or_td)):
input_dict_or_td = TensorDict.from_dict(
@@ -2202,6 +2251,7 @@ def update_at_(
self.tensordicts[i].update_at_(
input_dict_or_td,
_idx,
+ non_blocking=non_blocking,
)
return self
unbind_dim = self.stack_dim - num_single
@@ -2212,6 +2262,7 @@ def update_at_(
self.tensordicts[i].update_at_(
_value,
_idx,
+ non_blocking=non_blocking,
)
return self

@@ -2694,20 +2745,40 @@ def _transform_value(self, item):
return getattr(item, self.custom_op)(**self._update_custom_op_kwargs(item))

def _set_str(
- self, key, value, *, inplace: bool, validated: bool, ignore_lock: bool = False
+ self,
+ key,
+ value,
+ *,
+ inplace: bool,
+ validated: bool,
+ ignore_lock: bool = False,
+ non_blocking: bool = False,
):
if not validated:
value = self._validate_value(value, check_shape=True)
validated = True
value = getattr(value, self.inv_op)(**self._update_inv_op_kwargs(value))
self._source._set_str(
- key, value, inplace=inplace, validated=validated, ignore_lock=ignore_lock
+ key,
+ value,
+ inplace=inplace,
+ validated=validated,
+ ignore_lock=ignore_lock,
+ non_blocking=non_blocking,
)
return self

- def _set_tuple(self, key, value, *, inplace: bool, validated: bool):
+ def _set_tuple(
+ self, key, value, *, inplace: bool, validated: bool, non_blocking: bool
+ ):
if len(key) == 1:
- return self._set_str(key[0], value, inplace=inplace, validated=validated)
+ return self._set_str(
+ key[0],
+ value,
+ inplace=inplace,
+ validated=validated,
+ non_blocking=non_blocking,
+ )
source = self._source._get_str(key[0], None)
if source is None:
source = self._source._create_nested_str(key[0])
@@ -2718,10 +2789,16 @@ def _set_tuple(self, key, value, *, inplace: bool, validated: bool):
custom_op_kwargs=self._update_custom_op_kwargs(source),
inv_op_kwargs=self._update_inv_op_kwargs(source),
)
- nested._set_tuple(key[1:], value, inplace=inplace, validated=validated)
+ nested._set_tuple(
+ key[1:],
+ value,
+ inplace=inplace,
+ validated=validated,
+ non_blocking=non_blocking,
+ )
return self

- def _set_at_str(self, key, value, idx, *, validated):
+ def _set_at_str(self, key, value, idx, *, validated, non_blocking: bool):
transformed_tensor, original_tensor = self._get_str(
key, NO_DEFAULT
), self._source._get_str(key, NO_DEFAULT)
Expand All @@ -2735,7 +2812,7 @@ def _set_at_str(self, key, value, idx, *, validated):
transformed_tensor[idx] = value
return self

- def _set_at_tuple(self, key, value, idx, *, validated):
+ def _set_at_tuple(self, key, value, idx, *, validated, non_blocking: bool):
transformed_tensor, original_tensor = self._get_tuple(
key, NO_DEFAULT
), self._source._get_tuple(key, NO_DEFAULT)
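
The pattern repeated across the hunks above: a LazyStackedTensorDict holds a list of sub-tensordicts and fans every write out to them, so each write method must forward non_blocking explicitly or the flag would be silently dropped on the nested path. A stripped-down sketch of that fan-out (MiniLazyStack is a hypothetical stand-in, not the library class):

    import torch
    from tensordict import TensorDict

    class MiniLazyStack:
        """Toy model of LazyStackedTensorDict's write fan-out."""

        def __init__(self, tensordicts, stack_dim=0):
            self.tensordicts = tensordicts
            self.stack_dim = stack_dim

        def update_(self, source, *, non_blocking=False):
            # Unbind the source along the stack dim and forward the flag to
            # each sub-tensordict, exactly as the diff does for update_.
            for td_dest, td_source in zip(
                self.tensordicts, source.unbind(self.stack_dim)
            ):
                td_dest.update_(td_source, non_blocking=non_blocking)
            return self

    # Two sub-tensordicts standing in for a lazy stack of batch size 2.
    tds = [TensorDict({"x": torch.zeros(3)}, batch_size=[]) for _ in range(2)]
    stack = MiniLazyStack(tds)
    stack.update_(TensorDict({"x": torch.ones(2, 3)}, batch_size=[2]))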