[Feature] Hooks and Buffers for TensorDictParams #502

Merged (4 commits) on Jul 28, 2023
Changes from all commits
259 changes: 244 additions & 15 deletions tensordict/nn/params.py
@@ -10,15 +10,20 @@
import re
from copy import copy
from functools import wraps
from typing import Any, Callable, Sequence
from typing import Any, Callable, Iterator, Sequence

import torch

from tensordict import TensorDictBase
from tensordict.nn.utils import Buffer
from tensordict.tensordict import (
    _CustomOpTensorDict,
    _is_tensor_collection,
    CompatibleType,
    LazyStackedTensorDict,
    lock_blocked,
    NO_DEFAULT,
    SubTensorDict,
    TD_HANDLED_FUNCTIONS,
    TensorDict,
)
@@ -27,6 +32,35 @@
from torch.utils._pytree import tree_map


def _apply_leaves(data, fn):
    if isinstance(data, TensorDict):
        with data.unlock_():
            for key, val in list(data.items()):
                data._set_str(
                    key, _apply_leaves(val, fn), validated=True, inplace=False
                )
        return data
    elif isinstance(data, LazyStackedTensorDict):
        # this is currently not implemented as the registration of params will only work
        # with plain TensorDict. The solution will be using pytree to get each independent
        # leaf
        raise RuntimeError(
            "Using a LazyStackedTensorDict within a TensorDictParams isn't permitted."
        )
        # for _data in data.tensordicts:
        #     _apply_leaves(_data, fn)
        # return data
    elif isinstance(data, _CustomOpTensorDict):
        _apply_leaves(data._source, fn)
        return data
    elif isinstance(data, SubTensorDict):
        raise RuntimeError(
            "Using a SubTensorDict within a TensorDictParams isn't permitted."
        )
    else:
        return fn(data)
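
A minimal usage sketch of this helper (not part of the diff; the TensorDict contents are illustrative): it rewrites every leaf tensor of a plain, possibly nested TensorDict in place.

import torch
from tensordict import TensorDict

td = TensorDict(
    {"weight": torch.randn(3, 4), "nested": {"bias": torch.zeros(3)}}, batch_size=[]
)
td = _apply_leaves(td, lambda t: t.double())  # e.g. cast every leaf to double precision
assert td["nested", "bias"].dtype == torch.double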


def _get_args_dict(func, args, kwargs):
    signature = inspect.signature(func)
    bound_arguments = signature.bind(*args, **kwargs)
@@ -46,6 +80,17 @@ def _maybe_make_param(tensor):
    return tensor


def _maybe_make_param_or_buffer(tensor):
    if (
        isinstance(tensor, Tensor)
        and not isinstance(tensor, nn.Parameter)
        and tensor.dtype in (torch.float, torch.double, torch.half)
    ):
        # convert all non-parameters to buffers
        tensor = Buffer(tensor)
    return tensor
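
A small illustration of the rule above (not part of the diff): floating-point tensors that are not already parameters become buffers; everything else passes through unchanged.

import torch
from torch import nn
from tensordict.nn.utils import Buffer

assert isinstance(_maybe_make_param_or_buffer(torch.randn(3)), Buffer)
# integer tensors and existing parameters are left untouched
assert not isinstance(_maybe_make_param_or_buffer(torch.arange(3)), Buffer)
assert isinstance(_maybe_make_param_or_buffer(nn.Parameter(torch.randn(3))), nn.Parameter)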


class _unlock_and_set:
    def __new__(cls, *args, **kwargs):
        if len(args) and callable(args[0]):
@@ -69,19 +114,33 @@ def new_func(_self, *args, **kwargs):
                meth = getattr(_self._param_td, name)
                out = meth(*args, **kwargs)
                return out
            args = tree_map(_maybe_make_param, args)
            kwargs = tree_map(_maybe_make_param, kwargs)
            if not _self.no_convert:
                args = tree_map(_maybe_make_param, args)
                kwargs = tree_map(_maybe_make_param, kwargs)
            else:
                args = tree_map(_maybe_make_param_or_buffer, args)
                kwargs = tree_map(_maybe_make_param_or_buffer, kwargs)

            with _self._param_td.unlock_():
                meth = getattr(_self._param_td, name)
                out = meth(*args, **kwargs)
            _self.__dict__["_parameters"] = _self._param_td.flatten_keys("_").to_dict()
            _self._reset_params()
            if out is _self._param_td:
                return _self
            return out

        return new_func


def _get_post_hook(func):
    @wraps(func)
    def new_func(self, *args, **kwargs):
        out = func(self, *args, **kwargs)
        return self._apply_get_post_hook(out)

    return new_func


def _fallback(func):
    name = func.__name__

@@ -105,7 +164,10 @@ def new_func(self):
            return self
        return out

    return property(new_func)
    def setter(self, value):
        return getattr(type(self._param_td), name).fset(self._param_td, value)

    return property(new_func, setter)


def _replace(func):
@@ -128,7 +190,11 @@ def _carry_over(func):
    @wraps(func)
    def new_func(self, *args, **kwargs):
        out = getattr(self._param_td, name)(*args, **kwargs)
        return TensorDictParams(out, no_convert=True)
        out = TensorDictParams(out, no_convert=True)
        out.no_convert = self.no_convert
        return out

    return new_func

@@ -153,7 +219,10 @@ class TensorDictParams(TensorDictBase, nn.Module):
    Values will be converted to parameters unless ``no_convert=True``.

    Keyword Args:
        no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will occur.
        no_convert (bool): if ``True``, no conversion to ``nn.Parameter`` will
            occur at construction time or afterwards (unless the ``no_convert`` attribute is changed).
            If ``no_convert`` is ``True`` and non-parameters are present, they
            will be registered as buffers.
            Defaults to ``False``.

    Examples:
@@ -196,15 +265,50 @@ class TensorDictParams(TensorDictBase, nn.Module):

    def __init__(self, parameters: TensorDictBase, *, no_convert=False):
        super().__init__()
        if isinstance(parameters, TensorDictParams):
            parameters = parameters._param_td
        self._param_td = parameters
        self.no_convert = no_convert
        if not no_convert:
            self._param_td = self._param_td.apply(
                lambda x: _maybe_make_param(x)
            ).lock_()
        self._parameters = parameters.flatten_keys("_").to_dict()
            func = _maybe_make_param
        else:
            func = _maybe_make_param_or_buffer
        self._param_td = _apply_leaves(self._param_td, lambda x: func(x)).lock_()
        self._reset_params()
        self._is_locked = False
        self._locked_tensordicts = []
        self.__last_op_queue = None
        self._get_post_hook = []
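
A usage sketch of the constructor above (not part of the diff; tensor names are illustrative). With ``no_convert=True``, floating-point leaves that are not already ``nn.Parameter`` instances end up registered as buffers rather than parameters:

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn.params import TensorDictParams

data = TensorDict(
    {
        "weight": nn.Parameter(torch.randn(4, 3)),
        "stats": {"running_mean": torch.zeros(4)},  # plain tensor -> buffer
    },
    batch_size=[],
)
params = TensorDictParams(data, no_convert=True)
print([name for name, _ in params.named_parameters()])  # e.g. ['weight']
print([name for name, _ in params.named_buffers()])     # e.g. ['stats_running_mean']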

    def register_get_post_hook(self, hook):
        """Register a hook to be called after any get operation on leaf tensors."""
        if not callable(hook):
            raise ValueError("Hooks must be callables.")
        self._get_post_hook.append(hook)

    def _apply_get_post_hook(self, val):
        if not _is_tensor_collection(type(val)):
            for hook in self._get_post_hook:
                new_val = hook(self, val)
                if new_val is not None:
                    val = new_val
        return val
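
A usage sketch for the hook mechanism above (not part of the diff; the hook and key names are illustrative). A hook receives the ``TensorDictParams`` instance and the retrieved leaf and may return a replacement value; returning ``None`` keeps the original:

import torch
from tensordict import TensorDict
from tensordict.nn.params import TensorDictParams

def detach_hook(module, tensor):
    # detach every leaf returned by get()/__getitem__
    return tensor.detach()

params = TensorDictParams(TensorDict({"weight": torch.randn(3)}, []))
params.register_get_post_hook(detach_hook)
assert not params["weight"].requires_grad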

    def _reset_params(self):
        parameters = self._param_td
        param_keys = []
        buffer_keys = []
        for key, value in parameters.items(True, True):
            if isinstance(value, nn.Parameter):
                param_keys.append(key)
            else:
                buffer_keys.append(key)
        self.__dict__["_parameters"] = (
            parameters.select(*param_keys).flatten_keys("_").to_dict()
        )
        self.__dict__["_buffers"] = (
            parameters.select(*buffer_keys).flatten_keys("_").to_dict()
        )
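
The split above mirrors ``nn.Module`` conventions: parameter leaves are exposed through ``_parameters`` and all remaining leaves through ``_buffers``, with nested keys flattened using ``"_"`` as a separator. A short sketch (not part of the diff; key names are illustrative):

import torch
from tensordict import TensorDict
from tensordict.nn.params import TensorDictParams

params = TensorDictParams(TensorDict({"encoder": {"weight": torch.randn(2, 2)}}, []))
# the nested key ("encoder", "weight") is registered as "encoder_weight"
assert "encoder_weight" in dict(params.named_parameters())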

    @classmethod
    def __torch_function__(
@@ -261,12 +365,17 @@ def update(
        clone: bool = False,
        inplace: bool = False,
    ) -> TensorDictBase:
        if not self.no_convert:
            func = _maybe_make_param
        else:
            func = _maybe_make_param_or_buffer
        if isinstance(input_dict_or_td, TensorDictBase):
            input_dict_or_td = input_dict_or_td.apply(_maybe_make_param)
            input_dict_or_td = input_dict_or_td.apply(func)
        else:
            input_dict_or_td = tree_map(_maybe_make_param, input_dict_or_td)
            input_dict_or_td = tree_map(func, input_dict_or_td)
        with self._param_td.unlock_():
            TensorDictBase.update(self, input_dict_or_td, clone=clone, inplace=inplace)
        self._reset_params()
        return self
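
A sketch of the behaviour above (not part of the diff; keys are illustrative): values passed to ``update`` are converted with the same rule used at construction, and the parameter/buffer registration is refreshed afterwards:

import torch
from tensordict import TensorDict
from tensordict.nn.params import TensorDictParams

params = TensorDictParams(TensorDict({"weight": torch.randn(3)}, []))
params.update(TensorDict({"bias": torch.zeros(3)}, []))
assert "bias" in dict(params.named_parameters())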

    @lock_blocked
@@ -300,16 +409,20 @@ def apply(
    ) -> TensorDictBase:
        ...

    @_get_post_hook
    @_fallback
    def get(
        self, key: NestedKey, default: str | CompatibleType = NO_DEFAULT
    ) -> CompatibleType:
        ...

    @_get_post_hook
    @_fallback
    def __getitem__(self, index: IndexType) -> TensorDictBase:
        ...

    __getitems__ = __getitem__

    def to(self, dest: DeviceType | type | torch.Size, **kwargs) -> TensorDictBase:
        params = self._param_td.to(dest)
        if params is self._param_td:
@@ -376,16 +489,26 @@ def _change_batch_size(self, *args, **kwargs):
    def _erase_names(self, *args, **kwargs):
        ...

    # @_unlock_and_set  # we need this as one sub-module could call _get_str, get a td and want to modify it
    @_get_post_hook
    @_fallback
    def _get_str(self, *args, **kwargs):
        ...

    # @_unlock_and_set
    @_get_post_hook
    @_fallback
    def _get_tuple(self, *args, **kwargs):
        ...

    @_get_post_hook
    @_fallback
    def _get_at_str(self, key, idx, default):
        ...

    @_get_post_hook
    @_fallback
    def _get_at_tuple(self, key, idx, default):
        ...

    @_fallback
    def _has_names(self, *args, **kwargs):
        ...
@@ -567,6 +690,112 @@ def create_nested(self, key):
    def __repr__(self):
        return f"TensorDictParams(params={self._param_td})"

    def values(
        self, include_nested: bool = False, leaves_only: bool = False
    ) -> Iterator[CompatibleType]:
        for v in self._param_td.values(include_nested, leaves_only):
            if _is_tensor_collection(type(v)):
                yield v
                continue
            yield self._apply_get_post_hook(v)

    def items(
        self, include_nested: bool = False, leaves_only: bool = False
    ) -> Iterator[CompatibleType]:
        for k, v in self._param_td.items(include_nested, leaves_only):
            if _is_tensor_collection(type(v)):
                yield k, v
                continue
            yield k, self._apply_get_post_hook(v)

    def _apply(self, fn, recurse=True):
        """Modifies torch.nn.Module._apply to work with Buffer class."""
        if recurse:
            for module in self.children():
                module._apply(fn)

        def compute_should_use_set_data(tensor, tensor_applied):
            if torch._has_compatible_shallow_copy_type(tensor, tensor_applied):
                # If the new tensor has compatible tensor type as the existing tensor,
                # the current behavior is to change the tensor in-place using `.data =`,
                # and the future behavior is to overwrite the existing tensor. However,
                # changing the current behavior is a BC-breaking change, and we want it
                # to happen in future releases. So for now we introduce the
                # `torch.__future__.get_overwrite_module_params_on_conversion()`
                # global flag to let the user control whether they want the future
                # behavior of overwriting the existing tensor or not.
                return not torch.__future__.get_overwrite_module_params_on_conversion()
            else:
                return False

        for key, param in self._parameters.items():
            if param is None:
                continue
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `param_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                param_applied = fn(param)
            should_use_set_data = compute_should_use_set_data(param, param_applied)
            if should_use_set_data:
                param.data = param_applied
                out_param = param
            else:
                assert isinstance(param, nn.Parameter)
                assert param.is_leaf
                out_param = nn.Parameter(param_applied, param.requires_grad)
                self._parameters[key] = out_param

            if param.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(param.grad)
                should_use_set_data = compute_should_use_set_data(
                    param.grad, grad_applied
                )
                if should_use_set_data:
                    assert out_param.grad is not None
                    out_param.grad.data = grad_applied
                else:
                    assert param.grad.is_leaf
                    out_param.grad = grad_applied.requires_grad_(
                        param.grad.requires_grad
                    )

        for key, buffer in self._buffers.items():
            if buffer is None:
                continue
            # Tensors stored in modules are graph leaves, and we don't want to
            # track autograd history of `buffer_applied`, so we have to use
            # `with torch.no_grad():`
            with torch.no_grad():
                buffer_applied = fn(buffer)
            should_use_set_data = compute_should_use_set_data(buffer, buffer_applied)
            if should_use_set_data:
                buffer.data = buffer_applied
                out_buffer = buffer
            else:
                assert isinstance(buffer, Buffer)
                assert buffer.is_leaf
                out_buffer = Buffer(buffer_applied, buffer.requires_grad)
                self._buffers[key] = out_buffer

            if buffer.grad is not None:
                with torch.no_grad():
                    grad_applied = fn(buffer.grad)
                should_use_set_data = compute_should_use_set_data(
                    buffer.grad, grad_applied
                )
                if should_use_set_data:
                    assert out_buffer.grad is not None
                    out_buffer.grad.data = grad_applied
                else:
                    assert buffer.grad.is_leaf
                    out_buffer.grad = grad_applied.requires_grad_(
                        buffer.grad.requires_grad
                    )

        return self
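
This override mirrors ``nn.Module._apply`` so that conversions such as ``.to()``, ``.float()`` or ``.cuda()`` on an enclosing module also traverse the buffers created from the TensorDict. A hedged sketch (not part of the diff; names are illustrative):

import torch
from torch import nn
from tensordict import TensorDict
from tensordict.nn.params import TensorDictParams

outer = nn.Module()
outer.params = TensorDictParams(
    TensorDict({"weight": nn.Parameter(torch.randn(3)), "stat": torch.zeros(3)}, []),
    no_convert=True,
)
outer.double()  # recurses into the _apply override above
assert all(t.dtype == torch.double for _, t in outer.named_parameters())
assert all(t.dtype == torch.double for _, t in outer.named_buffers())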


TDPARAM_HANDLED_FUNCTIONS = copy(TD_HANDLED_FUNCTIONS)
