diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index f77ae6c50..9b24846f7 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -10,15 +10,20 @@
 import re
 from copy import copy
 from functools import wraps
-from typing import Any, Callable, Sequence
+from typing import Any, Callable, Iterator, Sequence
 
 import torch
 
 from tensordict import TensorDictBase
+from tensordict.nn.utils import Buffer
 from tensordict.tensordict import (
+    _CustomOpTensorDict,
+    _is_tensor_collection,
     CompatibleType,
+    LazyStackedTensorDict,
     lock_blocked,
     NO_DEFAULT,
+    SubTensorDict,
     TD_HANDLED_FUNCTIONS,
     TensorDict,
 )
@@ -27,6 +32,35 @@
 from torch.utils._pytree import tree_map
 
 
+def _apply_leaves(data, fn):
+    if isinstance(data, TensorDict):
+        with data.unlock_():
+            for key, val in list(data.items()):
+                data._set_str(
+                    key, _apply_leaves(val, fn), validated=True, inplace=False
+                )
+        return data
+    elif isinstance(data, LazyStackedTensorDict):
+        # this is currently not implemented as the registration of params will only work
+        # with plain TensorDict. The solution will be using pytree to get each independent
+        # leaf
+        raise RuntimeError(
+            "Using a LazyStackedTensorDict within a TensorDictParams isn't permitted."
+        )
+        # for _data in data.tensordicts:
+        #     _apply_leaves(_data, fn)
+        # return data
+    elif isinstance(data, _CustomOpTensorDict):
+        _apply_leaves(data._source, fn)
+        return data
+    elif isinstance(data, SubTensorDict):
+        raise RuntimeError(
+            "Using a SubTensorDict within a TensorDictParams isn't permitted."
+        )
+    else:
+        return fn(data)
+
+
 def _get_args_dict(func, args, kwargs):
     signature = inspect.signature(func)
     bound_arguments = signature.bind(*args, **kwargs)
@@ -46,6 +80,17 @@ def _maybe_make_param(tensor):
     return tensor
 
 
+def _maybe_make_param_or_buffer(tensor):
+    if (
+        isinstance(tensor, Tensor)
+        and not isinstance(tensor, nn.Parameter)
+        and tensor.dtype in (torch.float, torch.double, torch.half)
+    ):
+        # convert all non-parameters to buffers
+        tensor = Buffer(tensor)
+    return tensor
+
+
 class _unlock_and_set:
     def __new__(cls, *args, **kwargs):
         if len(args) and callable(args[0]):
@@ -69,12 +114,17 @@ def new_func(_self, *args, **kwargs):
                         meth = getattr(_self._param_td, name)
                         out = meth(*args, **kwargs)
                         return out
-            args = tree_map(_maybe_make_param, args)
-            kwargs = tree_map(_maybe_make_param, kwargs)
+            if not _self.no_convert:
+                args = tree_map(_maybe_make_param, args)
+                kwargs = tree_map(_maybe_make_param, kwargs)
+            else:
+                args = tree_map(_maybe_make_param_or_buffer, args)
+                kwargs = tree_map(_maybe_make_param_or_buffer, kwargs)
+
             with _self._param_td.unlock_():
                 meth = getattr(_self._param_td, name)
                 out = meth(*args, **kwargs)
-            _self.__dict__["_parameters"] = _self._param_td.flatten_keys("_").to_dict()
+            _self._reset_params()
             if out is _self._param_td:
                 return _self
             return out
@@ -82,6 +132,15 @@ def new_func(_self, *args, **kwargs):
         return new_func
 
 
+def _get_post_hook(func):
+    @wraps(func)
+    def new_func(self, *args, **kwargs):
+        out = func(self, *args, **kwargs)
+        return self._apply_get_post_hook(out)
+
+    return new_func
+
+
 def _fallback(func):
     name = func.__name__
 
@@ -105,7 +164,10 @@ def new_func(self):
             return self
         return out
 
-    return property(new_func)
+    def setter(self, value):
+        return getattr(type(self._param_td), name).fset(self._param_td, value)
+
+    return property(new_func, setter)
 
 
 def _replace(func):
@@ -128,7 +190,9 @@ def _carry_over(func):
     @wraps(func)
     def new_func(self, *args, **kwargs):
         out = getattr(self._param_td, name)(*args, **kwargs)
-        return TensorDictParams(out, no_convert=True)
+        out = TensorDictParams(out, no_convert=True)
+        out.no_convert = self.no_convert
+        return out
 
     return new_func
 
@@ -153,7 +217,10 @@ class TensorDictParams(TensorDictBase, nn.Module):
     Values will be converted to parameters unless ``no_convert=True``.
 
     Keyword Args:
-        no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will occur.
+        no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will
+            occur at construction or afterwards (unless the ``no_convert``
+            attribute is changed). If ``no_convert`` is ``True`` and
+            non-parameter tensors are present, they will be registered as buffers.
             Defaults to ``False``.
 
     Examples:
@@ -196,15 +263,50 @@ class TensorDictParams(TensorDictBase, nn.Module):
 
     def __init__(self, parameters: TensorDictBase, *, no_convert=False):
         super().__init__()
+        if isinstance(parameters, TensorDictParams):
+            parameters = parameters._param_td
         self._param_td = parameters
+        self.no_convert = no_convert
         if not no_convert:
-            self._param_td = self._param_td.apply(
-                lambda x: _maybe_make_param(x)
-            ).lock_()
-        self._parameters = parameters.flatten_keys("_").to_dict()
+            func = _maybe_make_param
+        else:
+            func = _maybe_make_param_or_buffer
+        self._param_td = _apply_leaves(self._param_td, lambda x: func(x)).lock_()
+        self._reset_params()
         self._is_locked = False
         self._locked_tensordicts = []
         self.__last_op_queue = None
+        self._get_post_hook = []
+
+    def register_get_post_hook(self, hook):
+        """Register a hook to be called after any get operation on leaf tensors."""
+        if not callable(hook):
+            raise ValueError("Hooks must be callables.")
+        self._get_post_hook.append(hook)
+
+    def _apply_get_post_hook(self, val):
+        if not _is_tensor_collection(type(val)):
+            for hook in self._get_post_hook:
+                new_val = hook(self, val)
+                if new_val is not None:
+                    val = new_val
+        return val
+
+    def _reset_params(self):
+        parameters = self._param_td
+        param_keys = []
+        buffer_keys = []
+        for key, value in parameters.items(True, True):
+            if isinstance(value, nn.Parameter):
+                param_keys.append(key)
+            else:
+                buffer_keys.append(key)
+        self.__dict__["_parameters"] = (
+            parameters.select(*param_keys).flatten_keys("_").to_dict()
+        )
+        self.__dict__["_buffers"] = (
+            parameters.select(*buffer_keys).flatten_keys("_").to_dict()
+        )
 
     @classmethod
     def __torch_function__(
@@ -261,12 +363,17 @@ def update(
         clone: bool = False,
         inplace: bool = False,
     ) -> TensorDictBase:
+        if not self.no_convert:
+            func = _maybe_make_param
+        else:
+            func = _maybe_make_param_or_buffer
         if isinstance(input_dict_or_td, TensorDictBase):
-            input_dict_or_td = input_dict_or_td.apply(_maybe_make_param)
+            input_dict_or_td = input_dict_or_td.apply(func)
         else:
-            input_dict_or_td = tree_map(_maybe_make_param, input_dict_or_td)
+            input_dict_or_td = tree_map(func, input_dict_or_td)
         with self._param_td.unlock_():
             TensorDictBase.update(self, input_dict_or_td, clone=clone, inplace=inplace)
+        self._reset_params()
         return self
 
     @lock_blocked
@@ -300,16 +407,20 @@ def apply(
     ) -> TensorDictBase:
         ...
 
+    @_get_post_hook
     @_fallback
     def get(
         self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
    ) -> CompatibleType:
         ...
 
+    @_get_post_hook
     @_fallback
     def __getitem__(self, index: IndexType) -> TensorDictBase:
         ...
 
+    __getitems__ = __getitem__
+
     def to(self, dest: DeviceType | type | torch.Size, **kwargs) -> TensorDictBase:
         params = self._param_td.to(dest)
         if params is self._param_td:
@@ -376,16 +487,26 @@ def _change_batch_size(self, *args, **kwargs):
     def _erase_names(self, *args, **kwargs):
         ...
 
-    # @_unlock_and_set  # we need this as one sub-module could call _get_str, get a td and want to modify it
+    @_get_post_hook
     @_fallback
     def _get_str(self, *args, **kwargs):
         ...
 
-    # @_unlock_and_set
+    @_get_post_hook
     @_fallback
     def _get_tuple(self, *args, **kwargs):
         ...
 
+    @_get_post_hook
+    @_fallback
+    def _get_at_str(self, key, idx, default):
+        ...
+
+    @_get_post_hook
+    @_fallback
+    def _get_at_tuple(self, key, idx, default):
+        ...
+
     @_fallback
     def _has_names(self, *args, **kwargs):
         ...
@@ -567,6 +688,112 @@ def create_nested(self, key):
     def __repr__(self):
         return f"TensorDictParams(params={self._param_td})"
 
+    def values(
+        self, include_nested: bool = False, leaves_only: bool = False
+    ) -> Iterator[CompatibleType]:
+        for v in self._param_td.values(include_nested, leaves_only):
+            if _is_tensor_collection(type(v)):
+                yield v
+                continue
+            yield self._apply_get_post_hook(v)
+
+    def items(
+        self, include_nested: bool = False, leaves_only: bool = False
+    ) -> Iterator[CompatibleType]:
+        for k, v in self._param_td.items(include_nested, leaves_only):
+            if _is_tensor_collection(type(v)):
+                yield k, v
+                continue
+            yield k, self._apply_get_post_hook(v)
+
+    def _apply(self, fn, recurse=True):
+        """Modifies torch.nn.Module._apply to work with Buffer class."""
+        if recurse:
+            for module in self.children():
+                module._apply(fn)
+
+        def compute_should_use_set_data(tensor, tensor_applied):
+            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
+                # If the new tensor has compatible tensor type as the existing tensor,
+                # the current behavior is to change the tensor in-place using `.data =`,
+                # and the future behavior is to overwrite the existing tensor. However,
+                # changing the current behavior is a BC-breaking change, and we want it
+                # to happen in future releases. So for now we introduce the
+                # `torch.__future__.get_overwrite_module_params_on_conversion()`
+                # global flag to let the user control whether they want the future
+                # behavior of overwriting the existing tensor or not.
+                return not torch.__future__.get_overwrite_module_params_on_conversion()
+            else:
+                return False
+
+        for key, param in self._parameters.items():
+            if param is None:
+                continue
+            # Tensors stored in modules are graph leaves, and we don't want to
+            # track autograd history of `param_applied`, so we have to use
+            # `with torch.no_grad():`
+            with torch.no_grad():
+                param_applied = fn(param)
+            should_use_set_data = compute_should_use_set_data(param, param_applied)
+            if should_use_set_data:
+                param.data = param_applied
+                out_param = param
+            else:
+                assert isinstance(param, nn.Parameter)
+                assert param.is_leaf
+                out_param = nn.Parameter(param_applied, param.requires_grad)
+                self._parameters[key] = out_param
+
+            if param.grad is not None:
+                with torch.no_grad():
+                    grad_applied = fn(param.grad)
+                should_use_set_data = compute_should_use_set_data(
+                    param.grad, grad_applied
+                )
+                if should_use_set_data:
+                    assert out_param.grad is not None
+                    out_param.grad.data = grad_applied
+                else:
+                    assert param.grad.is_leaf
+                    out_param.grad = grad_applied.requires_grad_(
+                        param.grad.requires_grad
+                    )
+
+        for key, buffer in self._buffers.items():
+            if buffer is None:
+                continue
+            # Tensors stored in modules are graph leaves, and we don't want to
+            # track autograd history of `buffer_applied`, so we have to use
+            # `with torch.no_grad():`
+            with torch.no_grad():
+                buffer_applied = fn(buffer)
+            should_use_set_data = compute_should_use_set_data(buffer, buffer_applied)
+            if should_use_set_data:
+                buffer.data = buffer_applied
+                out_buffer = buffer
+            else:
+                assert isinstance(buffer, Buffer)
+                assert buffer.is_leaf
+                out_buffer = Buffer(buffer_applied, buffer.requires_grad)
+                self._buffers[key] = out_buffer
+
+            if buffer.grad is not None:
+                with torch.no_grad():
+                    grad_applied = fn(buffer.grad)
+                should_use_set_data = compute_should_use_set_data(
+                    buffer.grad, grad_applied
+                )
+                if should_use_set_data:
+                    assert out_buffer.grad is not None
+                    out_buffer.grad.data = grad_applied
+                else:
+                    assert buffer.grad.is_leaf
+                    out_buffer.grad = grad_applied.requires_grad_(
+                        buffer.grad.requires_grad
+                    )
+
+        return self
+
 
 TDPARAM_HANDLED_FUNCTIONS = copy(TD_HANDLED_FUNCTIONS)
 
diff --git a/tensordict/nn/utils.py b/tensordict/nn/utils.py
index 831f4103d..fb7d5fe1c 100644
--- a/tensordict/nn/utils.py
+++ b/tensordict/nn/utils.py
@@ -7,7 +7,7 @@
 import functools
 import inspect
-from typing import Any, Callable
+from typing import Any, Callable, OrderedDict
 
 import torch
 from torch import nn
 
@@ -16,6 +16,7 @@
 _SKIP_EXISTING = False
 
 from tensordict._contextlib import _DecoratorContextManager
+from torch.nn.parameter import _disabled_torch_function_impl, _ParameterMeta
 
 
 def inv_softplus(bias: float | torch.Tensor) -> float | torch.Tensor:
@@ -273,3 +274,58 @@ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
 def skip_existing():
     """Returns whether or not existing entries in a tensordict should be re-computed by a module."""
     return _SKIP_EXISTING
+
+
+def _rebuild_buffer(data, requires_grad, backward_hooks):
+    buffer = Buffer(data, requires_grad)
+    # NB: This line exists only for backwards compatibility; the
+    # general expectation is that backward_hooks is an empty
+    # OrderedDict. See Note [Don't serialize hooks]
+    buffer._backward_hooks = backward_hooks
+
+    return buffer
+
+
+class Buffer(torch.Tensor, metaclass=_ParameterMeta):
+    r"""A kind of Tensor that is to be considered a module buffer.
+
+    Args:
+        data (Tensor): buffer tensor.
+        requires_grad (bool, optional): if the buffer requires gradient. See
+            :ref:`locally-disable-grad-doc` for more details. Default: `False`
+    """
+
+    def __new__(cls, data=None, requires_grad=False):
+        if data is None:
+            data = torch.empty(0)
+        if type(data) is torch.Tensor or type(data) is Buffer:
+            # For ease of BC maintenance, keep this path for standard Tensor.
+            # Eventually (tm), we should change the behavior for standard Tensor to match.
+            return torch.Tensor._make_subclass(cls, data, requires_grad)
+
+        # Path for custom tensors: set a flag on the instance to indicate buffer-ness.
+        t = data.detach().requires_grad_(requires_grad)
+        t._is_buffer = True
+        return t
+
+    def __deepcopy__(self, memo):
+        if id(self) in memo:
+            return memo[id(self)]
+        else:
+            result = type(self)(
+                self.data.clone(memory_format=torch.preserve_format), self.requires_grad
+            )
+            memo[id(self)] = result
+            return result
+
+    def __repr__(self):
+        return "Buffer containing:\n" + super(Buffer, self).__repr__()
+
+    def __reduce_ex__(self, proto):
+        # See Note [Don't serialize hooks]
+        return (
+            _rebuild_buffer,
+            (self.data, self.requires_grad, OrderedDict()),
+        )
+
+    __torch_function__ = _disabled_torch_function_impl
diff --git a/test/test_nn.py b/test/test_nn.py
index 06e721c25..5e72b9983 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -2792,9 +2792,14 @@ def _get_params(self):
         return params
 
     class CustomModule(nn.Module):
-        def __init__(self, params):
+        def __init__(self, *params):
             super().__init__()
-            self.params = params
+            if len(params) == 1:
+                params = params[0]
+                self.params = params
+            else:
+                for i, p in enumerate(params):
+                    setattr(self, f"params{i}", p)
 
     def test_td_params(self):
         params = self._get_params()
@@ -2832,11 +2837,43 @@ def test_td_params_cast(self):
         params = self._get_params()
         p = TensorDictParams(params)
         m = self.CustomModule(p)
         for dtype in ("half", "double", "float"):
             getattr(m, dtype)()
             for p in params.values(True, True):
                 assert p.dtype == getattr(torch, dtype)
 
+    def test_td_params_tying(self):
+        params = self._get_params()
+        p1 = TensorDictParams(params)
+        p2 = TensorDictParams(params)
+        m = self.CustomModule(p1, p2)
+        for key in dict(m.named_parameters()).keys():
+            assert key.startswith("params0")
+
+    def test_td_params_post_hook(self):
+        hook = lambda self, x: x.data
+        td = TensorDict(
+            {
+                "a": {
+                    "b": {"c": torch.zeros((), requires_grad=True)},
+                    "d": torch.zeros((), requires_grad=True),
+                },
+                "e": torch.zeros((), requires_grad=True),
+            },
+            [],
+        )
+        param_td = TensorDictParams(td)
+        param_td.register_get_post_hook(hook)
+        assert all(p.requires_grad for p in td.values(True, True))
+        assert all(not p.requires_grad for p in param_td.values(True, True))
+        assert {p.data.data_ptr() for p in param_td.values(True, True)} == {
+            p.data.data_ptr() for p in td.values(True, True)
+        }
+        assert not param_td["e"].requires_grad
+        assert not param_td["a", "b", "c"].requires_grad
+        assert not param_td.get("e").requires_grad
+        assert not param_td.get(("a", "b", "c")).requires_grad
+
 
 if __name__ == "__main__":
     args, unknown = argparse.ArgumentParser().parse_known_args()
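
A short usage sketch of the two behaviours this patch introduces (parameter/buffer registration controlled by no_convert, and get-post-hooks). This is illustrative only, not part of the patch: it assumes TensorDictParams is importable from tensordict.nn as in test_nn.py above, and the make_td helper is a hypothetical stand-in for user code.

    # Usage sketch (assumptions: tensordict.nn exports TensorDictParams, as the
    # tests above suggest; make_td is an illustrative helper, not part of the patch).
    import torch
    from tensordict import TensorDict
    from tensordict.nn import TensorDictParams


    def make_td():
        # a small nested tensordict with floating-point leaves
        return TensorDict({"a": {"b": torch.randn(3)}, "c": torch.randn(3)}, [])


    # Default: floating-point leaves become nn.Parameter and are exposed through the
    # usual nn.Module machinery (named_parameters(), .to(), dtype casts, ...).
    params = TensorDictParams(make_td())
    assert all(isinstance(p, torch.nn.Parameter) for p in params.values(True, True))

    # With no_convert=True, leaves are registered as buffers rather than parameters.
    buffers = TensorDictParams(make_td(), no_convert=True)
    assert len(dict(buffers.named_parameters())) == 0

    # A get-post-hook rewrites leaves on access, e.g. returning detached .data
    # (mirrors test_td_params_post_hook above).
    params.register_get_post_hook(lambda module, tensor: tensor.data)
    assert not params["a", "b"].requires_grad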