From bab6ea3d9fad979e727ec365c5933936ca1e51b8 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Tue, 10 Sep 2024 15:49:57 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- tensordict/tensorclass.py | 40 ++++++++++++++++++++++++---------- test/test_tensorclass.py | 46 +++++++++++++++++++++++++++++++++++++++ 2 files changed, 75 insertions(+), 11 deletions(-) diff --git a/tensordict/tensorclass.py b/tensordict/tensorclass.py index fa157a831..975c1b1bf 100644 --- a/tensordict/tensorclass.py +++ b/tensordict/tensorclass.py @@ -338,30 +338,35 @@ def is_non_tensor(obj): class _tensorclass_dec: - def __new__(cls, autocast: bool = False): + def __new__(cls, autocast: bool = False, frozen: bool = False): if not isinstance(autocast, bool): clz = autocast self = super().__new__(cls) - self.__init__(autocast=False) + self.__init__(autocast=False, frozen=False) return self.__call__(clz) return super().__new__(cls) - def __init__(self, autocast: bool): + def __init__(self, autocast: bool = False, frozen: bool = False): self.autocast = autocast + self.frozen = frozen @dataclass_transform() def __call__(self, cls): - clz = _tensorclass(cls) + clz = _tensorclass(cls, frozen=self.frozen) clz.autocast = self.autocast return clz @overload -def tensorclass(autocast: bool = False) -> _tensorclass_dec: ... +def tensorclass(autocast: bool = False, frozen: bool = False) -> _tensorclass_dec: ... + + +@overload +def tensorclass(cls: T) -> T: ... @dataclass_transform() -def tensorclass(cls: T) -> T: +def tensorclass(*args, **kwargs): """A decorator to create :obj:`tensorclass` classes. ``tensorclass`` classes are specialized :func:`dataclasses.dataclass` instances that @@ -372,6 +377,9 @@ def tensorclass(cls: T) -> T: Args: autocast (bool, optional): if ``True``, the types indicated will be enforced when an argument is set. Defaults to ``False``. + frozen (bool, optional): if ``True``, the content of the tensorclass cannot be modified. This argument is + provided to dataclass-compatibility, a similar behavior can be obtained through the `lock` argument in + the class constructor. Defaults to ``False``. tensorclass can be used with or without arguments: Examples: @@ -439,11 +447,11 @@ def tensorclass(cls: T) -> T: """ - return _tensorclass_dec(cls) + return _tensorclass_dec(*args, **kwargs) @dataclass_transform() -def _tensorclass(cls: T) -> T: +def _tensorclass(cls: T, *, frozen) -> T: def __torch_function__( cls, func: Callable, @@ -479,7 +487,7 @@ def __torch_function__( _is_non_tensor = getattr(cls, "_is_non_tensor", False) - cls = dataclass(cls) + cls = dataclass(cls, frozen=frozen) expected_keys = cls.__expected_keys__ = set(cls.__dataclass_fields__) for attr in expected_keys: @@ -494,7 +502,7 @@ def __torch_function__( delattr(cls, field.name) _get_type_hints(cls) - cls.__init__ = _init_wrapper(cls.__init__) + cls.__init__ = _init_wrapper(cls.__init__, frozen) cls._from_tensordict = classmethod(_from_tensordict) cls.from_tensordict = cls._from_tensordict if not hasattr(cls, "__torch_function__"): @@ -657,7 +665,7 @@ def _from_tensordict_with_none(tc, tensordict): ) -def _init_wrapper(__init__: Callable) -> Callable: +def _init_wrapper(__init__: Callable, frozen) -> Callable: init_sig = inspect.signature(__init__) params = list(init_sig.parameters.values()) # drop first entry of params which corresponds to self and isn't passed by the user @@ -670,8 +678,11 @@ def wrapper( batch_size: Sequence[int] | torch.Size | int = None, device: DeviceType | None = None, names: List[str] | None = None, + lock: bool | None = None, **kwargs, ): + if lock is None: + lock = frozen if not is_dynamo_compiling(): # zip not supported by dynamo @@ -726,6 +737,13 @@ def wrapper( for key, value in kwargs.items() } __init__(self, **kwargs) + if frozen: + local_setattr = _setattr_wrapper(self.__setattr__, self.__expected_keys__) + for key, val in kwargs.items(): + local_setattr(self, key, val) + del self.__dict__[key] + if lock: + self._tensordict.lock_() new_params = [ inspect.Parameter("batch_size", inspect.Parameter.KEYWORD_ONLY), diff --git a/test/test_tensorclass.py b/test/test_tensorclass.py index 7bbd69342..2e6dce038 100644 --- a/test/test_tensorclass.py +++ b/test/test_tensorclass.py @@ -559,6 +559,52 @@ class MyDataNested: assert (full_like_tc.y.X == 9).all() assert full_like_tc.z == data.z == z + def test_frozen(self): + + @tensorclass(frozen=True, autocast=True) + class X: + y: torch.Tensor + + x = X(y=1) + assert isinstance(x.y, torch.Tensor) + _ = {x: 0} + assert x.is_locked + with pytest.raises(RuntimeError, match="locked"): + x.y = 0 + + @tensorclass(frozen=False, autocast=True) + class X: + y: torch.Tensor + + x = X(y=1) + assert isinstance(x.y, torch.Tensor) + with pytest.raises(TypeError, match="unhashable"): + _ = {x: 0} + assert not x.is_locked + x.y = 0 + + @tensorclass(frozen=True, autocast=False) + class X: + y: torch.Tensor + + x = X(y="a string!") + assert isinstance(x.y, str) + _ = {x: 0} + assert x.is_locked + with pytest.raises(RuntimeError, match="locked"): + x.y = 0 + + @tensorclass(frozen=False, autocast=False) + class X: + y: torch.Tensor + + x = X(y="a string!") + assert isinstance(x.y, str) + with pytest.raises(TypeError, match="unhashable"): + _ = {x: 0} + assert not x.is_locked + x.y = 0 + @pytest.mark.parametrize("from_torch", [True, False]) def test_gather(self, from_torch): @tensorclass