diff --git a/tensordict/nn/params.py b/tensordict/nn/params.py
index f0a716fe8..6cedbf41d 100644
--- a/tensordict/nn/params.py
+++ b/tensordict/nn/params.py
@@ -234,6 +234,11 @@ class TensorDictParams(TensorDictBase, nn.Module):
             If ``no_convert`` is ``True`` and if non-parameters are present,
             they will be registered as buffers.
             Defaults to ``False``.
+        lock (bool): if ``True``, the tensordict hosted by TensorDictParams
+            will be locked. This can be useful to avoid unwanted modifications,
+            but it also restricts the operations that can be performed on the
+            object (and can have a significant performance impact when
+            ``unlock_()`` is required). Defaults to ``False``.
 
     Examples:
         >>> from torch import nn
@@ -273,7 +278,9 @@ class TensorDictParams(TensorDictBase, nn.Module):
 
     """
 
-    def __init__(self, parameters: TensorDictBase, *, no_convert=False):
+    def __init__(
+        self, parameters: TensorDictBase, *, no_convert=False, lock: bool = False
+    ):
         super().__init__()
         if isinstance(parameters, TensorDictParams):
             parameters = parameters._param_td
@@ -283,7 +290,10 @@ def __init__(self, parameters: TensorDictBase, *, no_convert=False):
             func = _maybe_make_param
         else:
             func = _maybe_make_param_or_buffer
-        self._param_td = _apply_leaves(self._param_td, lambda x: func(x)).lock_()
+        self._param_td = _apply_leaves(self._param_td, lambda x: func(x))
+        self._lock = lock
+        if lock:
+            self._param_td.lock_()
         self._reset_params()
         self._is_locked = False
         self._locked_tensordicts = []
@@ -307,18 +317,21 @@ def _apply_get_post_hook(self, val):
     def _reset_params(self):
         parameters = self._param_td
         param_keys = []
+        params = []
         buffer_keys = []
+        buffers = []
         for key, value in parameters.items(True, True):
+            # flatten nested keys, e.g. ("a", "b") -> "a_b", for nn.Module registration
+            if isinstance(key, tuple):
+                key = "_".join(key)
             if isinstance(value, nn.Parameter):
                 param_keys.append(key)
+                params.append(value)
             else:
                 buffer_keys.append(key)
-        self.__dict__["_parameters"] = (
-            parameters.select(*param_keys).flatten_keys("_").to_dict()
-        )
-        self.__dict__["_buffers"] = (
-            parameters.select(*buffer_keys).flatten_keys("_").to_dict()
-        )
+                buffers.append(value)
+        self.__dict__["_parameters"] = dict(zip(param_keys, params))
+        self.__dict__["_buffers"] = dict(zip(buffer_keys, buffers))
 
     @classmethod
     def __torch_function__(
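
For reviewers, a minimal usage sketch of the new ``lock`` keyword, assuming this patch is applied. The exact exception raised when writing to a locked tensordict is assumed to be ``RuntimeError`` (current tensordict behavior for locked structural changes); treat that detail as an assumption:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

td = TensorDict({"weight": torch.randn(3, 4)}, [])

# new keyword from this patch: lock the hosted tensordict at construction
params = TensorDictParams(td, lock=True)

try:
    # writing a new entry mutates the hosted tensordict, which is locked
    params["bias"] = torch.zeros(3)
except RuntimeError as err:
    print(f"blocked by lock: {err}")

# default (lock=False): the hosted tensordict stays unlocked, which is the
# behavioral change of this patch -- previously .lock_() was always applied
```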
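
And a hedged illustration of the ``_reset_params`` change: nested keys are now flattened inline with ``"_".join(key)`` while iterating, instead of going through ``select(...).flatten_keys("_").to_dict()``, so the dicts are built in a single pass and hold the very same ``nn.Parameter`` objects as the tensordict leaves. The nested layout below is illustrative:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictParams

nested = TensorDict(
    {"encoder": TensorDict({"weight": torch.randn(4, 4)}, [])}, []
)
module = TensorDictParams(nested)  # no_convert=False: leaves become nn.Parameter

# the leaf stored under the nested key ("encoder", "weight") is registered
# in _parameters under the flattened name "encoder_weight"
for name, param in module.named_parameters():
    print(name, tuple(param.shape))  # expected: encoder_weight (4, 4)
```

Avoiding ``select``/``flatten_keys``/``to_dict`` also means no intermediate tensordict copies are created just to populate ``_parameters`` and ``_buffers``.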