From 05fbddd82128db4a2860cca4a7cf8f4667d421d4 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 18 Jul 2024 15:28:29 +0100
Subject: [PATCH 1/2] init

---
 tensordict/base.py        | 35 ++++++++++++++++++++++++++++++++++-
 tensordict/tensorclass.py | 11 +++++++++++
 test/test_tensordict.py   | 13 +++++++++++++
 3 files changed, 58 insertions(+), 1 deletion(-)

diff --git a/tensordict/base.py b/tensordict/base.py
index b522d43fe..d8ae78294 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -1343,6 +1343,23 @@ def grad(self):
         """Returns a tensordict containing the .grad attributes of the leaf tensors."""
         return self._grad()
 
+    def zero_grad(self, set_to_none: bool = True) -> T:
+        """Zeros all the gradients of the TensorDict recursively.
+
+        Args:
+            set_to_none (bool, optional): if ``True``, ``tensor.grad`` will be ``None``,
+                otherwise ``0``.
+                Defaults to ``True``.
+
+        """
+        if set_to_none:
+            for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+                val.grad = None
+            return self
+        for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+            val.grad.zero_()
+        return self
+
     @cache  # noqa
     def _dtype(self):
         dtype = None
@@ -5082,7 +5099,9 @@ def _items_list(
 
     @cache  # noqa: B019
     def _grad(self):
-        result = self._fast_apply(lambda x: x.grad, propagate_lock=True)
+        result = self._fast_apply(
+            lambda x: x.grad, propagate_lock=True, filter_empty=True
+        )
         return result
 
     @cache  # noqa: B019
@@ -9207,6 +9226,20 @@ def type(self, dst_type):
     def requires_grad(self) -> bool:
         return any(v.requires_grad for v in self.values())
 
+    def requires_grad_(self, requires_grad: bool = True) -> T:
+        """Change if autograd should record operations on this tensordict: sets the requires_grad attribute of the leaf tensors in-place.
+
+        Returns this tensordict.
+
+        Args:
+            requires_grad (bool, optional): whether or not autograd should record operations on this tensordict.
+                Defaults to ``True``.
+
+        """
+        for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
+            val.requires_grad_(requires_grad)
+        return self
+
     @abc.abstractmethod
     def detach_(self) -> T:
         """Detach the tensors in the tensordict in-place.
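For reference, a minimal usage sketch of the two methods added to tensordict/base.py above, assuming a plain TensorDict of float leaf tensors (the key names and shapes are illustrative only, not taken from the patch):

    import torch
    from tensordict import TensorDict

    td = TensorDict(
        {"a": torch.randn(3), "nested": {"b": torch.randn(3, 4)}},
        batch_size=[3],
    )
    # Enable autograd on every leaf tensor in-place; the call returns the
    # tensordict itself, so it can be chained.
    td.requires_grad_()
    # Populate .grad on the leaves with a dummy backward pass.
    (td["a"].sum() + td["nested", "b"].sum()).backward()
    assert (td["a"].grad == 1).all()
    # set_to_none=False zeroes the gradients in-place instead of dropping them;
    # the default (set_to_none=True) sets each leaf's .grad to None.
    td.zero_grad(set_to_none=False)
    assert (td["a"].grad == 0).all()
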
diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 1f5299cf9..0c9e5f7bb 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -169,6 +169,7 @@ def __subclasscheck__(self, subclass):
     "asin_",
     "atan",
     "atan_",
+    "auto_batch_size_",
     "ceil",
     "ceil_",
     "clamp_max",
@@ -227,11 +228,14 @@ def __subclasscheck__(self, subclass):
     "masked_fill_",
     "maximum",
     "maximum_",
+    "mean",
     "minimum",
     "minimum_",
     "mul",
     "mul_",
     "named_apply",
+    "nanmean",
+    "nansum",
     "neg",
     "neg_",
     "new_empty",
@@ -243,9 +247,12 @@ def __subclasscheck__(self, subclass):
     "permute",
     "pow",
     "pow_",
+    "prod",
     "reciprocal",
     "reciprocal_",
     "refine_names",
+    "requires_grad",
+    "requires_grad_",
     "rename_",  # TODO: must be specialized
     "replace",
     "reshape",
@@ -263,8 +270,10 @@ def __subclasscheck__(self, subclass):
     "sqrt",
     "sqrt_",
     "squeeze",
+    "std",
     "sub",
     "sub_",
+    "sum",
     "tan",
     "tan_",
     "tanh",
@@ -276,9 +285,11 @@ def __subclasscheck__(self, subclass):
     "unflatten",
     "unlock_",
     "unsqueeze",
+    "var",
     "view",
     "where",
     "zero_",
+    "zero_grad",
 ]
 assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set(
     _METHOD_FROM_TD
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 511996bc2..95cdf57c3 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -6201,6 +6201,19 @@ def test_zero_(self, td_name, device):
         for k in td.keys():
             assert (td.get(k) == 0).all()
 
+    @pytest.mark.parametrize("set_to_none", [True, False])
+    def test_zero_grad(self, td_name, device, set_to_none):
+        td = getattr(self, td_name)(device)
+        tdr = td.float().requires_grad_()
+        td1 = tdr + 1
+        sum(td1.sum().values(True, True)).backward()
+        assert (tdr.grad == 1).all(), tdr.grad.to_dict()
+        tdr.zero_grad(set_to_none=set_to_none)
+        if set_to_none:
+            assert tdr.filter_non_tensor_data().grad is None, (td, tdr, tdr.grad)
+        else:
+            assert (tdr.grad == 0).all()
+
 
 @pytest.mark.parametrize("device", [None, *get_available_devices()])
 @pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])

From 138ef965cbaa30510b9d01e4bc7cc55c6082fda5 Mon Sep 17 00:00:00 2001
From: Vincent Moens
Date: Thu, 18 Jul 2024 21:04:29 +0100
Subject: [PATCH 2/2] amend

---
 tensordict/tensorclass.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py
index 0c9e5f7bb..1c1a63cf1 100644
--- a/tensordict/tensorclass.py
+++ b/tensordict/tensorclass.py
@@ -251,7 +251,6 @@ def __subclasscheck__(self, subclass):
     "reciprocal",
     "reciprocal_",
     "refine_names",
-    "requires_grad",
     "requires_grad_",
     "rename_",  # TODO: must be specialized
     "replace",
@@ -531,6 +530,8 @@ def __torch_function__(
         cls.batch_size = property(_batch_size, _batch_size_setter)
     if not hasattr(cls, "names"):
         cls.names = property(_names, _names_setter)
+    if not hasattr(cls, "requires_grad"):
+        cls.requires_grad = property(lambda self: self._tensordict.requires_grad)
     if not hasattr(cls, "to_dict"):
         cls.to_dict = _to_dict
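The entries added to _FALLBACK_METHOD_FROM_TD make these methods resolve to the underlying tensordict implementation when called on a tensorclass. A minimal sketch of the resulting user-facing behaviour, assuming the patches above are applied; the Data class and its fields are invented for the example:

    import torch
    from tensordict import tensorclass

    @tensorclass
    class Data:
        obs: torch.Tensor
        reward: torch.Tensor

    data = Data(
        obs=torch.randn(4, 8),
        reward=torch.randn(4, 1),
        batch_size=[4],
    )
    data.requires_grad_()  # falls back to TensorDictBase.requires_grad_
    loss = data.obs.sum() + data.reward.sum()
    loss.backward()
    data.zero_grad()       # newly listed fallback; leaf .grad set to None
    assert data.obs.grad is None

Note the design choice in the second commit: requires_grad is dropped from the fallback list because it is a property on TensorDictBase, not a method, so routing it through the method-fallback machinery would turn attribute access into a call; instead it is exposed as a property delegating to the wrapped tensordict.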