Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] zero_grad and requires_grad_ #901

Merged
merged 2 commits into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1343,6 +1343,23 @@ def grad(self):
"""Returns a tensordict containing the .grad attributes of the leaf tensors."""
return self._grad()

def zero_grad(self, set_to_none: bool = True) -> T:
"""Zeros all the gradients of the TensorDict recursively.

Args:
set_to_none (bool, optional): if ``True``, tensor.grad will be ``None``,
otherwise ``0``.
Defaults to ``True``.

"""
if set_to_none:
for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
val.grad = None
return
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldnt we return self here too ?

for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
val.grad.zero_()
return self

@cache # noqa
def _dtype(self):
dtype = None
Expand Down Expand Up @@ -5082,7 +5099,9 @@ def _items_list(

@cache # noqa: B019
def _grad(self):
result = self._fast_apply(lambda x: x.grad, propagate_lock=True)
result = self._fast_apply(
lambda x: x.grad, propagate_lock=True, filter_empty=True
)
return result

@cache # noqa: B019
Expand Down Expand Up @@ -9207,6 +9226,20 @@ def type(self, dst_type):
def requires_grad(self) -> bool:
return any(v.requires_grad for v in self.values())

def requires_grad_(self, requires_grad=True) -> T:
"""Change if autograd should record operations on this tensor: sets this tensor’s requires_grad attribute in-place.

Returns this tensordict.

Args:
requires_grad (bool, optional): whether or not autograd should record operations on this tensordict.
Defaults to ``True``.

"""
for val in self._values_list(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
val.requires_grad_(requires_grad)
return self

@abc.abstractmethod
def detach_(self) -> T:
"""Detach the tensors in the tensordict in-place.
Expand Down
12 changes: 12 additions & 0 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __subclasscheck__(self, subclass):
"asin_",
"atan",
"atan_",
"auto_batch_size_",
"ceil",
"ceil_",
"clamp_max",
Expand Down Expand Up @@ -227,11 +228,14 @@ def __subclasscheck__(self, subclass):
"masked_fill_",
"maximum",
"maximum_",
"mean",
"minimum",
"minimum_",
"mul",
"mul_",
"named_apply",
"nanmean",
"nansum",
"neg",
"neg_",
"new_empty",
Expand All @@ -243,9 +247,11 @@ def __subclasscheck__(self, subclass):
"permute",
"pow",
"pow_",
"prod",
"reciprocal",
"reciprocal_",
"refine_names",
"requires_grad_",
"rename_", # TODO: must be specialized
"replace",
"reshape",
Expand All @@ -263,8 +269,10 @@ def __subclasscheck__(self, subclass):
"sqrt",
"sqrt_",
"squeeze",
"std",
"sub",
"sub_",
"sum",
"tan",
"tan_",
"tanh",
Expand All @@ -276,9 +284,11 @@ def __subclasscheck__(self, subclass):
"unflatten",
"unlock_",
"unsqueeze",
"var",
"view",
"where",
"zero_",
"zero_grad",
]
assert not any(v in _METHOD_FROM_TD for v in _FALLBACK_METHOD_FROM_TD), set(
_METHOD_FROM_TD
Expand Down Expand Up @@ -520,6 +530,8 @@ def __torch_function__(
cls.batch_size = property(_batch_size, _batch_size_setter)
if not hasattr(cls, "names"):
cls.names = property(_names, _names_setter)
if not hasattr(cls, "names"):
cls.require = property(_names, _names_setter)
if not hasattr(cls, "to_dict"):
cls.to_dict = _to_dict

Expand Down
13 changes: 13 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -6201,6 +6201,19 @@ def test_zero_(self, td_name, device):
for k in td.keys():
assert (td.get(k) == 0).all()

@pytest.mark.parametrize("set_to_none", [True, False])
def test_zero_grad(self, td_name, device, set_to_none):
td = getattr(self, td_name)(device)
tdr = td.float().requires_grad_()
td1 = tdr + 1
sum(td1.sum().values(True, True)).backward()
assert (tdr.grad == 1).all(), tdr.grad.to_dict()
tdr.zero_grad(set_to_none=set_to_none)
if set_to_none:
assert tdr.filter_non_tensor_data().grad is None, (td, tdr, tdr.grad)
else:
assert (tdr.grad == 0).all()


@pytest.mark.parametrize("device", [None, *get_available_devices()])
@pytest.mark.parametrize("dtype", [torch.float32, torch.uint8])
Expand Down
Loading