Skip to content

Commit

Permalink
[Performance] Faster update_ (#705)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens committed Mar 24, 2024
1 parent 07b9884 commit 059f539
Show file tree
Hide file tree
Showing 5 changed files with 936 additions and 83 deletions.
3 changes: 3 additions & 0 deletions tensordict/_lazy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1413,6 +1413,7 @@ def _apply_nest(
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
is_leaf: Callable | None = None,
**constructor_kwargs,
) -> T | None:
if inplace and any(
Expand All @@ -1438,6 +1439,7 @@ def _apply_nest(
prefix=prefix,
inplace=inplace,
filter_empty=filter_empty,
is_leaf=is_leaf,
**constructor_kwargs,
)

Expand All @@ -1455,6 +1457,7 @@ def _apply_nest(
prefix=prefix, # + (i,),
inplace=inplace,
filter_empty=filter_empty,
is_leaf=is_leaf,
)
for i, (td, *oth) in enumerate(zip(self.tensordicts, *others))
]
Expand Down
7 changes: 6 additions & 1 deletion tensordict/_td.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,6 +661,7 @@ def _apply_nest(
nested_keys: bool = False,
prefix: tuple = (),
filter_empty: bool | None = None,
is_leaf: Callable = None,
**constructor_kwargs,
) -> T | None:
if inplace:
Expand Down Expand Up @@ -696,10 +697,13 @@ def make_result():
is_locked = False

any_set = False
if is_leaf is None:
is_leaf = _default_is_leaf

for key, item in self.items():
if (
not call_on_nested
and _is_tensor_collection(item.__class__)
and not is_leaf(item.__class__)
# and not is_non_tensor(item)
):
if default is not NO_DEFAULT:
Expand All @@ -725,6 +729,7 @@ def make_result():
default=default,
prefix=prefix + (key,),
filter_empty=filter_empty,
is_leaf=is_leaf,
**constructor_kwargs,
)
else:
Expand Down
Loading

0 comments on commit 059f539

Please sign in to comment.