[Feature] Cat LazyStackedTensorDicts #499

Merged (16 commits) on Aug 2, 2023
159 changes: 137 additions & 22 deletions tensordict/tensordict.py
@@ -121,6 +121,7 @@ def __bool__(self):
# distributed frameworks
DIST_SEPARATOR = ".-|-."
TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
LAZY_TD_HANDLED_FUNCTIONS: dict[Callable, Callable] = {}
CompatibleType = Union[
Tensor,
MemmapTensor,
@@ -4644,6 +4645,17 @@ def decorator(func: Callable) -> Callable:
return decorator


def implements_for_lazy_td(torch_function: Callable) -> Callable[[Callable], Callable]:
"""Register a torch function override for TensorDict."""

@functools.wraps(torch_function)
def decorator(func: Callable) -> Callable:
LAZY_TD_HANDLED_FUNCTIONS[torch_function] = func
return func

return decorator

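For reference, a handler registered through this decorator is only recorded in LAZY_TD_HANDLED_FUNCTIONS; the actual interception happens in LazyStackedTensorDict.__torch_function__ further down in this diff. A minimal sketch with a hypothetical override (not part of this PR):

@implements_for_lazy_td(torch.unbind)  # hypothetical registration, illustration only
def _lazy_unbind_example(td, dim=0):
    # illustrative body: defer to the instance method
    return td.unbind(dim)

assert LAZY_TD_HANDLED_FUNCTIONS[torch.unbind] is _lazy_unbind_example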

# @implements_for_td(torch.testing.assert_allclose) TODO
def assert_allclose_td(
actual: TensorDictBase,
@@ -4668,24 +4680,36 @@ def assert_allclose_td(
)
keys = sorted(actual.keys(), key=str)
for key in keys:
-        input1 = actual.get(key)
-        input2 = expected.get(key)
-        if _is_tensor_collection(input1.__class__):
-            assert_allclose_td(input1, input2, rtol=rtol, atol=atol)
-            continue
+        shape1 = actual.get_item_shape(key)
+        shape2 = expected.get_item_shape(key)
+        if -1 in shape1 or -1 in shape2:
+            if not (-1 in shape1 and -1 in shape2):
+                raise KeyError(
+                    f"{key} corresponds to a heterogeneous entry only in one of the tensordicts provided"
+                )
+            for sub_actual, sub_expected in zip(
+                actual.tensordicts, expected.tensordicts
+            ):
+                assert_allclose_td(sub_actual, sub_expected, rtol=rtol, atol=atol)
+        else:
+            input1 = actual.get(key)
+            input2 = expected.get(key)
+            if _is_tensor_collection(input1.__class__):
+                assert_allclose_td(input1, input2, rtol=rtol, atol=atol)
+                continue

-        mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
-        mse = mse.div(input1.numel()).sqrt().item()
-
-        default_msg = f"key {key} does not match, got mse = {mse:4.4f}"
-        msg = "\t".join([default_msg, msg]) if len(msg) else default_msg
-        if isinstance(input1, MemmapTensor):
-            input1 = input1._tensor
-        if isinstance(input2, MemmapTensor):
-            input2 = input2._tensor
-        torch.testing.assert_close(
-            input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg
-        )
+            mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
+            mse = mse.div(input1.numel()).sqrt().item()
+
+            default_msg = f"key {key} does not match, got mse = {mse:4.4f}"
+            msg = "\t".join([default_msg, msg]) if len(msg) else default_msg
+            if isinstance(input1, MemmapTensor):
+                input1 = input1._tensor
+            if isinstance(input2, MemmapTensor):
+                input2 = input2._tensor
+            torch.testing.assert_close(
+                input1, input2, rtol=rtol, atol=atol, equal_nan=equal_nan, msg=msg
+            )
return True


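The new branch covers keys whose shape reports -1 along a heterogeneous dimension of a lazy stack; in that case the comparison recurses into the stacked sub-tensordicts instead of calling get() on the whole key. A rough illustration (the shapes and the -1 convention are assumptions based on get_item_shape, not taken verbatim from the diff):

td_a = TensorDict({"obs": torch.zeros(3, 4)}, batch_size=[3])
td_b = TensorDict({"obs": torch.zeros(3, 5)}, batch_size=[3])
lazy = LazyStackedTensorDict(td_a, td_b, stack_dim=0)  # batch_size [2, 3]
# "obs" differs in its last dim across the stack, so that dim is reported as -1
assert lazy.get_item_shape("obs") == torch.Size([2, 3, -1])
# the comparison then recurses into td_a/td_b pairwise
assert assert_allclose_td(lazy, lazy.clone())
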
@@ -4902,6 +4926,78 @@ def _cat(
return out


@implements_for_lazy_td(torch.cat)
def _lazy_cat(
list_of_tensordicts: Sequence[LazyStackedTensorDict],
dim: int = 0,
out: LazyStackedTensorDict | None = None,
) -> LazyStackedTensorDict:
if not list_of_tensordicts:
raise RuntimeError("list_of_tensordicts cannot be empty")

batch_size = list(list_of_tensordicts[0].batch_size)
if dim < 0:
dim = len(batch_size) + dim
if dim >= len(batch_size):
raise RuntimeError(
f"dim must be in the range 0 <= dim < len(batch_size), got dim"
f"={dim} and batch_size={batch_size}"
)
stack_dim = list_of_tensordicts[0].stack_dim
if any([(td.stack_dim != stack_dim) for td in list_of_tensordicts]):
raise RuntimeError("cat lazy stacked tds must have same stack dim")

batch_size[dim] = sum([td.batch_size[dim] for td in list_of_tensordicts])
batch_size = torch.Size(batch_size)

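# note: the per-index cat below operates on the sub-tensordicts, which do not
# carry the stack dim, so a cat dim located after the stack dim shifts down by one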
new_dim = dim
if dim > stack_dim:
new_dim = dim - 1

# check that all tensordicts match
_check_keys(list_of_tensordicts, strict=True)
if out is None:
out = []
if dim == stack_dim:  # if the cat dim is the stack dim, just concatenate the sub-tensordict lists
for lazy_td in list_of_tensordicts:
out += lazy_td.tensordicts
else:
for i in range(len(list_of_tensordicts[0].tensordicts)):
out.append(
torch.cat(
[lazy_td.tensordicts[i] for lazy_td in list_of_tensordicts],
new_dim,
)
)
return LazyStackedTensorDict(*out, stack_dim=stack_dim)
else:
if out.batch_size != batch_size:
raise RuntimeError(
"out.batch_size and cat batch size must match, "
f"got out.batch_size={out.batch_size} and batch_size"
f"={batch_size}"
)
if out.stack_dim != stack_dim:
raise RuntimeError(
"out.stack_dim and elements stack_sim did not match, "
f"got out.stack_dim={out.stack_dim} and stack_dim={stack_dim}"
)

if dim == stack_dim:
td_list = []
for lazy_td in list_of_tensordicts:
td_list += lazy_td.tensordicts
out.tensordicts = td_list
else:
for i in range(len(out.tensordicts)):
out.tensordicts[i] = torch.cat(
[lazy_td.tensordicts[i] for lazy_td in list_of_tensordicts],
new_dim,
)

return out

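Concretely, the two paths differ in whether the cat dim coincides with the stack dim: in the first case the lists of sub-tensordicts are simply concatenated, in the second the matching sub-tensordicts are cat'ed pairwise along dim - 1. A small usage sketch with illustrative shapes (mirroring the test added below):

a = TensorDict({"x": torch.zeros(2, 3)}, batch_size=[2, 3])
b = TensorDict({"x": torch.ones(2, 3)}, batch_size=[2, 3])
lazy1 = LazyStackedTensorDict(a, b, stack_dim=0)  # batch_size [2, 2, 3]
lazy2 = lazy1.clone()

res = torch.cat([lazy1, lazy2], dim=0)  # dim == stack_dim: lists are concatenated
assert len(res.tensordicts) == 4 and res.batch_size == torch.Size([4, 2, 3])

res = torch.cat([lazy1, lazy2], dim=1)  # dim != stack_dim: pairwise cat on dim - 1
assert len(res.tensordicts) == 2 and res.batch_size == torch.Size([2, 4, 3])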

@implements_for_td(torch.stack)
def _stack(
list_of_tensordicts: Sequence[TensorDictBase],
@@ -5843,6 +5939,25 @@ class LazyStackedTensorDict(TensorDictBase):

"""

@classmethod
def __torch_function__(
cls,
func: Callable,
types: tuple[type, ...],
args: tuple[Any, ...] = (),
kwargs: dict[str, Any] | None = None,
) -> Callable:
if func in LAZY_TD_HANDLED_FUNCTIONS:
if kwargs is None:
kwargs = {}
if func not in LAZY_TD_HANDLED_FUNCTIONS or not all(
issubclass(t, (Tensor, TensorDictBase)) for t in types
):
return NotImplemented
return LAZY_TD_HANDLED_FUNCTIONS[func](*args, **kwargs)
else:
return super().__torch_function__(func, types, args, kwargs)

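Only torch.cat is registered for lazy stacks in this PR; anything else falls through to the regular TensorDictBase handlers via the super() call. For instance:

assert torch.cat in LAZY_TD_HANDLED_FUNCTIONS      # intercepted, handled by _lazy_cat
assert torch.cat in TD_HANDLED_FUNCTIONS           # plain tensordicts still use _cat
assert torch.stack not in LAZY_TD_HANDLED_FUNCTIONS  # falls back to super().__torch_function__
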
def __new__(cls, *args: Any, **kwargs: Any) -> LazyStackedTensorDict:
cls._td_dim_name = None
return super().__new__(cls, *args, _safe=False, _lazy=True, **kwargs)
@@ -6248,7 +6363,7 @@ def _set_at_str(self, key, value, index, *, validated):
is_nd_tensor = split_index.get("is_nd_tensor", False)
if isinteger:
# this will break if the index along the stack dim is [0] or :1 or smth
-for (i, _idx) in converted_idx.items():
+for i, _idx in converted_idx.items():
self.tensordicts[i]._set_at_str(key, value, _idx, validated=validated)
return self
if is_nd_tensor:
@@ -6874,7 +6989,7 @@ def __setitem__(self, index: IndexType, value: TensorDictBase) -> TensorDictBase
is_nd_tensor = split_index.get("is_nd_tensor", False)
if isinteger:
# this will break if the index along the stack dim is [0] or :1 or smth
-for (i, _idx) in converted_idx.items():
+for i, _idx in converted_idx.items():
self.tensordicts[i][_idx] = value
return self
if is_nd_tensor:
@@ -6946,7 +7061,7 @@ def __getitem__(self, index: IndexType) -> TensorDictBase:
out[-1] = out[-1].squeeze(cat_dim)
return torch.stack(out, cat_dim)
else:
-for (i, _idx) in converted_idx.items():
+for i, _idx in converted_idx.items():
self_idx = (slice(None),) * split_index["mask_loc"] + (i,)
out.append(self[self_idx][_idx])
return torch.cat(out, cat_dim)
@@ -6970,7 +7085,7 @@ def __getitem__(self, index: IndexType) -> TensorDictBase:
else:
out = []
new_stack_dim = self.stack_dim - num_single + num_none - num_squash
-for (i, _idx) in converted_idx.items():
+for i, _idx in converted_idx.items():
out.append(self.tensordicts[i][_idx])
out = torch.stack(out, new_stack_dim)
out._td_dim_name = self._td_dim_name
@@ -7377,7 +7492,7 @@ def update_at_(
isinteger = split_index["isinteger"]
if isinteger:
# this will break if the index along the stack dim is [0] or :1 or smth
-for (i, _idx) in converted_idx.items():
+for i, _idx in converted_idx.items():
self.tensordicts[i].update_at_(
input_dict_or_td,
_idx,
34 changes: 34 additions & 0 deletions test/test_tensordict.py
@@ -4064,6 +4064,40 @@ def nested_lazy_het_td(batch_size):
obs = obs.expand(batch_size)
return obs

@pytest.mark.parametrize("batch_size", [(), (2,), (1, 2)])
@pytest.mark.parametrize("cat_dim", [0, 1, 2])
def test_cat_lazy_stack(self, batch_size, cat_dim):
if cat_dim > len(batch_size):
return
td_lazy = self.nested_lazy_het_td(batch_size)["lazy"]

res = torch.cat([td_lazy], dim=cat_dim)
assert assert_allclose_td(res, td_lazy)
assert res is not td_lazy
td_lazy_clone = td_lazy.clone()
res = torch.cat([td_lazy_clone], dim=cat_dim, out=td_lazy)
assert res is td_lazy
assert assert_allclose_td(res, td_lazy_clone)

td_lazy_2 = td_lazy.clone()
td_lazy_2.apply_(lambda x: x + 1)

res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim)
assert res.stack_dim == len(batch_size)
assert res.shape[cat_dim] == td_lazy.shape[cat_dim] + td_lazy_2.shape[cat_dim]
index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
assert assert_allclose_td(res[index], td_lazy)
index = (slice(None),) * cat_dim + (slice(td_lazy.shape[cat_dim], None),)
assert assert_allclose_td(res[index], td_lazy_2)

res = torch.cat([td_lazy, td_lazy_2], dim=cat_dim)
assert res.stack_dim == len(batch_size)
assert res.shape[cat_dim] == td_lazy.shape[cat_dim] + td_lazy_2.shape[cat_dim]
index = (slice(None),) * cat_dim + (slice(0, td_lazy.shape[cat_dim]),)
assert assert_allclose_td(res[index], td_lazy)
index = (slice(None),) * cat_dim + (slice(td_lazy.shape[cat_dim], None),)
assert assert_allclose_td(res[index], td_lazy_2)

def recursively_check_key(self, td, value: int):
if isinstance(td, LazyStackedTensorDict):
for t in td.tensordicts: