diff --git a/tensordict/base.py b/tensordict/base.py
index 380848ab8..f15fd3737 100644
--- a/tensordict/base.py
+++ b/tensordict/base.py
@@ -1089,6 +1089,7 @@ def from_modules(
         lock: bool = True,
         use_state_dict: bool = False,
         lazy_stack: bool = False,
+        expand_identical: bool = False,
     ):
         """Retrieves the parameters of several modules for ensebmle learning/feature of expects applications through vmap.
 
@@ -1134,6 +1135,9 @@ def from_modules(
                 or :meth:`~torch.optim.Optimizer.zero_grad` will take longer
                 to be executed. In general, ``lazy_stack`` should be reserved
                 to very few use cases.
+            expand_identical (bool, optional): if ``True`` and the same parameter (same
+                identity) is being stacked to itself, an expanded version of this parameter
+                will be returned instead. This argument is ignored when ``lazy_stack=True``.
 
         Examples:
             >>> from torch import nn
@@ -1200,6 +1204,36 @@ def from_modules(
                         "lasy_stack=True is not compatible with lazy modules."
                     )
             params = LazyStackedTensorDict.lazy_stack(param_list)
+        elif expand_identical:
+            from tensordict._torch_func import _stack_uninit_params
+
+            # Check the keys
+            # If not expand_identical, `stack` takes care of that check but
+            # here we use apply which will ignore keys that are in one TD but not another
+            sets = [set(param.keys(True, True)) for param in param_list]
+            for set_ in sets[1:]:
+                if set_ != sets[0]:
+                    raise ValueError(
+                        f"All key sets must match. "
+                        f"Got {set_.symmetric_difference(sets[0])} in one but not another."
+                    )
+
+            def maybe_stack(*params):
+                param = params[0]
+                if isinstance(param, UninitializedTensorMixin):
+                    return _stack_uninit_params(params, 0)
+                if len(set(params)) == 1:
+                    return param.expand((len(params), *param.shape))
+                result = torch.stack(params)
+                if isinstance(param, nn.Parameter):
+                    return nn.Parameter(result.detach(), param.requires_grad)
+                return Buffer(result)
+
+            params = param_list[0]._fast_apply(
+                maybe_stack,
+                *param_list[1:],
+                batch_size=torch.Size([len(param_list), *param_list[0].batch_size]),
+            )
         else:
             with set_lazy_legacy(False), torch.no_grad():
                 params = torch.stack(param_list)
diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index c89cba5dd..7c20cab79 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -112,7 +112,9 @@ def _maybe_make_param_or_buffer(tensor):
         and tensor.dtype in (torch.float, torch.double, torch.half)
     ):
         # convert all non-parameters to buffers
+        # dataptr = tensor.data.data_ptr()
         tensor = Buffer(tensor)
+        # assert tensor.data.data_ptr() == dataptr
     return tensor
diff --git a/tensordict/utils.py b/tensordict/utils.py
index ccd2a764a..d4c35e99d 100644
--- a/tensordict/utils.py
+++ b/tensordict/utils.py
@@ -1853,13 +1853,11 @@ class Buffer(Tensor, metaclass=_ParameterMeta):
     def __new__(cls, data=None, requires_grad=False):
         if data is None:
             data = torch.empty(0)
-        if type(data) is Tensor or type(data) is Buffer:
-            # For ease of BC maintenance, keep this path for standard Tensor.
-            # Eventually (tm), we should change the behavior for standard Tensor to match.
-            return Tensor._make_subclass(cls, data, requires_grad)
-        # Path for custom tensors: set a flag on the instance to indicate parameter-ness.
-        t = data.detach().requires_grad_(requires_grad)
+        if requires_grad:
+            t = data.detach().requires_grad_(requires_grad)
+        else:
+            t = data
         t._is_buffer = True
         return t
diff --git a/test/test_tensordict.py b/test/test_tensordict.py
index 4767778a4..8a782fdb5 100644
--- a/test/test_tensordict.py
+++ b/test/test_tensordict.py
@@ -982,6 +982,39 @@ def get_leaf(leaf):
             assert p.grad is None
         assert all(param.grad is not None for param in params.values(True, True))
 
+    @pytest.mark.parametrize("as_module", [False, True])
+    def test_from_modules_expand(self, as_module):
+        empty_module = nn.Sequential(
+            nn.Linear(3, 3, device="meta"), nn.Linear(3, 4, device="meta")
+        )
+        module0 = nn.Linear(3, 3)
+        modules = [nn.Sequential(module0, nn.Linear(3, 4)) for _ in range(3)]
+        params = TensorDict.from_modules(
+            *modules, as_module=as_module, expand_identical=True
+        )
+        assert not isinstance(params["0", "weight"], nn.Parameter)
+        assert params["0", "weight"].data.data_ptr() == module0.weight.data.data_ptr()
+        assert isinstance(params["1", "weight"], nn.Parameter)
+        assert (
+            params["1", "weight"].data.data_ptr()
+            != modules[0][1].weight.data.data_ptr()
+        )
+
+        def exec_module(params, x):
+            with params.to_module(empty_module):
+                return empty_module(x)
+
+        x = torch.zeros(3)
+        y = torch.vmap(exec_module, (0, None))(params, x)
+        y.sum().backward()
+        for k, p in modules[0].named_parameters():
+            assert p.grad is None if k.startswith("1") else p.grad is not None
+        assert all(
+            param.grad is not None
+            for param in params.values(True, True)
+            if isinstance(param, nn.Parameter)
+        )
+
     @pytest.mark.parametrize("as_module", [False, True])
     @pytest.mark.parametrize("lazy_stack", [False, True])
     @pytest.mark.parametrize("device", get_available_devices())
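
A minimal usage sketch of the new expand_identical flag, distilled from test_from_modules_expand above; module sizes and the shared/empty names are illustrative only, not part of the diff.

import torch
from torch import nn
from tensordict import TensorDict

# Three ensemble members that reuse the *same* first-layer module object.
shared = nn.Linear(3, 3)
modules = [nn.Sequential(shared, nn.Linear(3, 4)) for _ in range(3)]

# With expand_identical=True, the shared first-layer parameters are expanded
# along the ensemble dimension (no copy, not nn.Parameter), while the distinct
# second-layer parameters are stacked into nn.Parameters of leading dim 3.
params = TensorDict.from_modules(*modules, expand_identical=True)

# A meta-device skeleton of the architecture to run the ensemble under vmap.
empty = nn.Sequential(
    nn.Linear(3, 3, device="meta"), nn.Linear(3, 4, device="meta")
)

def exec_module(params, x):
    # Temporarily write the stacked parameters into the skeleton and call it.
    with params.to_module(empty):
        return empty(x)

y = torch.vmap(exec_module, (0, None))(params, torch.zeros(3))
print(y.shape)  # torch.Size([3, 4]): one output per ensemble member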