From 86cca931b4bb8eb2bfab9ff43abffb5ac9816018 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 May 2024 11:53:55 +0100 Subject: [PATCH 1/4] init --- tensordict/base.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) diff --git a/tensordict/base.py b/tensordict/base.py index 2883a2beb..80f6c3582 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -3840,6 +3840,7 @@ def sorted_keys(self) -> list[NestedKey]: """ return sorted(self.keys()) + @as_decorator() def flatten(self, start_dim=0, end_dim=-1): """Flattens all the tensors of a tensordict. @@ -3871,6 +3872,8 @@ def flatten(self, start_dim=0, end_dim=-1): tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) """ + if start_dim < 0: + start_dim = self.ndim + start_dim if end_dim < 0: end_dim = self.ndim + end_dim if end_dim < 0: @@ -3906,6 +3909,7 @@ def flatten(tensor): out.names = names return out + @as_decorator() def unflatten(self, dim, unflattened_size): """Unflattens a tensordict dim expanding it to a desired shape. @@ -6128,6 +6132,8 @@ def __exit__(self, exc_type, exc_val, exc_tb): _last_op = self._last_op_queue.pop() if _last_op is not None: last_op, (args, kwargs, out) = _last_op + # TODO: transpose, flatten etc. as decorator should lock the content to make sure that no key is + # added or deleted if last_op == self.__class__.lock_.__name__: return self.unlock_() elif last_op == self.__class__.unlock_.__name__: @@ -6135,6 +6141,34 @@ def __exit__(self, exc_type, exc_val, exc_tb): elif last_op == self.__class__.transpose.__name__: dim0, dim1 = args return out.update(self.transpose(dim0, dim1)) + elif last_op == self.__class__.flatten.__name__: + if len(args) == 2: + dim0, dim1 = args + elif len(args) == 1: + dim0 = args[0] + dim1 = kwargs.get("end_dim", -1) + else: + dim0 = kwargs.get("start_dim", 0) + dim1 = kwargs.get("end_dim", -1) + if dim1 < 0: + dim1 = out.ndim + dim1 + if dim0 < 0: + dim0 = out.ndim + dim0 + return out.update(self.unflatten(dim0, out.shape[dim0 : dim1 + 1])) + elif last_op == self.__class__.unflatten.__name__: + if args: + dim0 = args[0] + if len(args) > 1: + unflattened_size = args[1] + else: + unflattened_size = kwargs.get("unflattened_size") + else: + dim0 = kwargs.get("dim") + unflattened_size = kwargs.get("unflattened_size") + if dim0 < 0: + dim0 = out.ndim + dim0 + dim1 = dim0 + len(unflattened_size) - 1 + return out.update(self.flatten(dim0, dim1)) elif last_op == self.__class__.permute.__name__: dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs) dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list] From 0c289c0c4a9d47b33aa0cc2008a07383680fa875 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 May 2024 12:50:08 +0100 Subject: [PATCH 2/4] amend --- test/test_tensordict.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/test/test_tensordict.py b/test/test_tensordict.py index 87264727b..8c0b49b9c 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -3191,6 +3191,35 @@ def test_flatten_unflatten(self, td_name, device): assert (td.to_tensordict() == td_unflat).all() assert td.batch_size == td_unflat.batch_size + @pytest.mark.parametrize("start_dim", [0, 1, -2, -3]) + def test_flatten_unflatten_decorator(self, td_name, device, start_dim): + td = getattr(self, td_name)(device) + with td.unlock_(), td.flatten(start_dim=start_dim, end_dim=3) as td_flat: + assert (td_flat == td.flatten(start_dim, 3)).all() + new_start_dim = -1 if start_dim in (-2, -3) else start_dim + with 
td_flat.unflatten( + dim=new_start_dim, unflattened_size=td.shape[start_dim:] + ) as td_unflat: + assert (td_unflat == td).all() + + with td.unlock_(), td.flatten(start_dim, end_dim=3) as td_flat: + assert (td_flat == td.flatten(start_dim, 3)).all() + new_start_dim = ( + -1 if start_dim == -2 else -1 if start_dim == -3 else start_dim + ) + with td_flat.unflatten( + new_start_dim, unflattened_size=td.shape[start_dim:] + ) as td_unflat: + assert (td_unflat == td).all() + + with td.unlock_(), td.flatten(start_dim, -1) as td_flat: + assert (td_flat == td.flatten(start_dim, -1)).all() + new_start_dim = ( + -1 if start_dim == -2 else -1 if start_dim == -3 else start_dim + ) + with td_flat.unflatten(new_start_dim, td.shape[start_dim:]) as td_unflat: + assert (td_unflat == td).all() + def test_flatten_unflatten_bis(self, td_name, device): td = getattr(self, td_name)(device) shape = td.shape[1:4] From bc26bcd68e8af4e85219aebc8c5d35528c0f36e8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 May 2024 12:57:08 +0100 Subject: [PATCH 3/4] amend --- tensordict/_td.py | 11 ++++++++++- tensordict/base.py | 14 ++++++++------ 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/tensordict/_td.py b/tensordict/_td.py index 169af2192..e66acbc99 100644 --- a/tensordict/_td.py +++ b/tensordict/_td.py @@ -1866,7 +1866,16 @@ def _set_str( dest.update(value, inplace=True, non_blocking=non_blocking) else: if dest is not value: - dest.copy_(value, non_blocking=non_blocking) + try: + dest.copy_(value, non_blocking=non_blocking) + except RuntimeError: + # if we're updating a param and the storages match, nothing needs to be done + if not ( + isinstance(dest, torch.Tensor) + and dest.data.untyped_storage().data_ptr() + == value.data.untyped_storage().data_ptr() + ): + raise except KeyError as err: raise err except Exception as err: diff --git a/tensordict/base.py b/tensordict/base.py index 80f6c3582..9ec87ae67 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6140,7 +6140,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.lock_() elif last_op == self.__class__.transpose.__name__: dim0, dim1 = args - return out.update(self.transpose(dim0, dim1)) + return out.update(self.transpose(dim0, dim1), inplace=True) elif last_op == self.__class__.flatten.__name__: if len(args) == 2: dim0, dim1 = args @@ -6154,7 +6154,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): dim1 = out.ndim + dim1 if dim0 < 0: dim0 = out.ndim + dim0 - return out.update(self.unflatten(dim0, out.shape[dim0 : dim1 + 1])) + return out.update( + self.unflatten(dim0, out.shape[dim0 : dim1 + 1]), inplace=True + ) elif last_op == self.__class__.unflatten.__name__: if args: dim0 = args[0] @@ -6168,7 +6170,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): if dim0 < 0: dim0 = out.ndim + dim0 dim1 = dim0 + len(unflattened_size) - 1 - return out.update(self.flatten(dim0, dim1)) + return out.update(self.flatten(dim0, dim1), inplace=True) elif last_op == self.__class__.permute.__name__: dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs) dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list] @@ -6176,7 +6178,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): inv_dims_list = np.argsort(dims_list) return out.update(self.permute(inv_dims_list)) elif last_op == self.__class__.view.__name__: - return out.update(self.view(out.shape)) + return out.update(self.view(out.shape), inplace=True) elif last_op == self.__class__.unsqueeze.__name__: if args: (dim,) = args @@ -6186,7 +6188,7 @@ def 
__exit__(self, exc_type, exc_val, exc_tb): raise RuntimeError( "Cannot use td.unsqueeze() as a decorator if the dimension is implicit." ) - return out.update(self.squeeze(dim)) + return out.update(self.squeeze(dim), inplace=True) elif last_op == self.__class__.squeeze.__name__: if args: (dim,) = args @@ -6196,7 +6198,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): raise RuntimeError( "Cannot use td.squeeze() as a decorator if the dimension is implicit." ) - return out.update(self.unsqueeze(dim)) + return out.update(self.unsqueeze(dim), inplace=True) elif last_op == self.__class__.to_module.__name__: if is_tensor_collection(out): with out.unlock_(): From a8b1ebcea24af13031c4d96241d57a14adb30980 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 15 May 2024 13:02:52 +0100 Subject: [PATCH 4/4] amend --- tensordict/base.py | 42 +++++++++++++++++++++++++++++++++--------- 1 file changed, 33 insertions(+), 9 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index 9ec87ae67..3d59b6390 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -6140,7 +6140,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): return self.lock_() elif last_op == self.__class__.transpose.__name__: dim0, dim1 = args - return out.update(self.transpose(dim0, dim1), inplace=True) + if not out.is_locked: + return out.update(self.transpose(dim0, dim1), inplace=True) + else: + return out.update_(self.transpose(dim0, dim1)) elif last_op == self.__class__.flatten.__name__: if len(args) == 2: dim0, dim1 = args @@ -6154,9 +6157,14 @@ def __exit__(self, exc_type, exc_val, exc_tb): dim1 = out.ndim + dim1 if dim0 < 0: dim0 = out.ndim + dim0 - return out.update( - self.unflatten(dim0, out.shape[dim0 : dim1 + 1]), inplace=True - ) + + if not out.is_locked: + return out.update( + self.unflatten(dim0, out.shape[dim0 : dim1 + 1]), inplace=True + ) + else: + return out.update_(self.unflatten(dim0, out.shape[dim0 : dim1 + 1])) + elif last_op == self.__class__.unflatten.__name__: if args: dim0 = args[0] @@ -6170,15 +6178,25 @@ def __exit__(self, exc_type, exc_val, exc_tb): if dim0 < 0: dim0 = out.ndim + dim0 dim1 = dim0 + len(unflattened_size) - 1 - return out.update(self.flatten(dim0, dim1), inplace=True) + if not out.is_locked: + return out.update(self.flatten(dim0, dim1), inplace=True) + else: + return out.update_(self.flatten(dim0, dim1)) + elif last_op == self.__class__.permute.__name__: dims_list = _get_shape_from_args(*args, kwarg_name="dims", **kwargs) dims_list = [dim if dim >= 0 else self.ndim + dim for dim in dims_list] # inverse map inv_dims_list = np.argsort(dims_list) - return out.update(self.permute(inv_dims_list)) + if not out.is_locked: + return out.update(self.permute(inv_dims_list), inplace=True) + else: + return out.update_(self.permute(inv_dims_list)) elif last_op == self.__class__.view.__name__: - return out.update(self.view(out.shape), inplace=True) + if not out.is_locked: + return out.update(self.view(out.shape), inplace=True) + else: + return out.update_(self.view(out.shape)) elif last_op == self.__class__.unsqueeze.__name__: if args: (dim,) = args @@ -6188,7 +6206,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): raise RuntimeError( "Cannot use td.unsqueeze() as a decorator if the dimension is implicit." 
) - return out.update(self.squeeze(dim), inplace=True) + if not out.is_locked: + return out.update(self.squeeze(dim), inplace=True) + else: + return out.update_(self.squeeze(dim)) elif last_op == self.__class__.squeeze.__name__: if args: (dim,) = args @@ -6198,7 +6219,10 @@ def __exit__(self, exc_type, exc_val, exc_tb): raise RuntimeError( "Cannot use td.squeeze() as a decorator if the dimension is implicit." ) - return out.update(self.unsqueeze(dim), inplace=True) + if not out.is_locked: + return out.update(self.unsqueeze(dim), inplace=True) + else: + return out.update_(self.unsqueeze(dim)) elif last_op == self.__class__.to_module.__name__: if is_tensor_collection(out): with out.unlock_():
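
A usage note on the flatten/unflatten decorator added in patches 1-2: registering both methods through @as_decorator() lets them serve as context managers, with changes made inside the block written back to the source tensordict on exit. A minimal sketch mirroring the tests above (the key "a" and the shapes are illustrative, not part of the patch; the tensordict is assumed unlocked):

    import torch
    from tensordict import TensorDict

    td = TensorDict({"a": torch.zeros(2, 3, 4, 5)}, batch_size=[2, 3, 4, 5])
    with td.flatten(start_dim=0, end_dim=3) as td_flat:
        # td_flat has batch_size [120]; edits made here are propagated
        # back to td when the block exits
        td_flat["a"] += 1
    assert (td["a"] == 1).all()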
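
In the flatten branch of __exit__, self is the flattened result while out is the source tensordict recorded when the context was entered, so the negative dims are normalized against out.ndim and out.shape[dim0 : dim1 + 1] is exactly the run of sizes that flatten collapsed. The same arithmetic on plain tensors, with a hypothetical helper name:

    import torch

    def undo_flatten(src_shape, flat, start_dim=0, end_dim=-1):
        # normalize negative dims against the *source* shape, as in __exit__
        ndim = len(src_shape)
        start_dim = start_dim if start_dim >= 0 else ndim + start_dim
        end_dim = end_dim if end_dim >= 0 else ndim + end_dim
        # src_shape[start_dim : end_dim + 1] is the run of sizes that
        # flatten collapsed into a single dimension
        return flat.unflatten(start_dim, src_shape[start_dim : end_dim + 1])

    x = torch.arange(24).reshape(2, 3, 4)
    assert undo_flatten(x.shape, x.flatten(1, 2), 1, 2).shape == x.shape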
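
Patch 3's change to _set_str tolerates a RuntimeError from copy_ when the destination and the source already share storage, which is the situation when a parameter is updated with a view of itself: copy_ on a leaf tensor that requires grad raises, but the data is already in place, so nothing is lost by swallowing the error. A standalone sketch of that case (outside tensordict):

    import torch

    p = torch.nn.Parameter(torch.zeros(3))
    src = p.data  # aliases p's storage
    try:
        p.copy_(src)  # in-place op on a leaf that requires grad raises
    except RuntimeError:
        # same check as in _set_str: if the storages alias, the copy was
        # a no-op anyway and the error can be ignored
        assert (
            p.data.untyped_storage().data_ptr()
            == src.untyped_storage().data_ptr()
        )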
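
Patch 4 makes each writeback branch respect the lock state: an unlocked tensordict is refreshed with update(..., inplace=True), while a locked one, whose key set must not change, is written through the strictly in-place update_. A small sketch of the distinction (keys and shapes illustrative):

    import torch
    from tensordict import TensorDict

    dest = TensorDict({"a": torch.zeros(3)}, batch_size=[3]).lock_()
    src = TensorDict({"a": torch.ones(3)}, batch_size=[3])

    dest.update_(src)  # writes into existing tensors; no key creation
    assert (dest["a"] == 1).all()

    src["b"] = torch.ones(3)
    # dest.update(src) would now raise: a locked tensordict cannot acquire
    # the new key "b", whereas update_ only writes into existing entries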