Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BugFix] Fix inheritance from non-tensor #709

Merged
merged 2 commits into from
Mar 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class TensorDictBase(MutableMapping):
is_meta: bool = False
_is_locked: bool = False
_cache: bool = None
_non_tensor: bool = False
_is_non_tensor: bool = False

def __bool__(self) -> bool:
raise RuntimeError("Converting a tensordict to boolean value is not permitted")
Expand Down Expand Up @@ -5414,7 +5414,7 @@ def _default_is_leaf(cls: Type) -> bool:

def _is_leaf_nontensor(cls: Type) -> bool:
if _is_tensor_collection(cls):
return cls._non_tensor
return cls._is_non_tensor
# if issubclass(cls, KeyedJaggedTensor):
# return False
return issubclass(cls, torch.Tensor)
14 changes: 8 additions & 6 deletions tensordict/tensorclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,13 +167,13 @@ def __torch_function__(
)
return _from_tensordict_with_copy(tensorclass_instance, result)

_non_tensor = getattr(cls, "_non_tensor", False)
_is_non_tensor = getattr(cls, "_is_non_tensor", False)

cls = dataclass(cls)
expected_keys = set(cls.__dataclass_fields__)

for attr in cls.__dataclass_fields__:
if attr in dir(TensorDict) and attr != "_non_tensor":
if attr in dir(TensorDict) and attr != "_is_non_tensor":
raise AttributeError(
f"Attribute name {attr} can't be used with @tensorclass"
)
Expand Down Expand Up @@ -253,7 +253,9 @@ def __torch_function__(

_register_tensor_class(cls)

cls._non_tensor = _non_tensor
# faster than doing instance checks
cls._is_non_tensor = _is_non_tensor
cls._is_tensorclass = True

return cls

Expand Down Expand Up @@ -1341,7 +1343,7 @@ class NonTensorData:
val: NonTensorData(
data=0,
_metadata=None,
_non_tensor=True,
_is_non_tensor=True,
batch_size=torch.Size([10]),
device=None,
is_shared=False)},
Expand Down Expand Up @@ -1446,7 +1448,7 @@ class NonTensorData:
data: Any
_metadata: dict | None = None

_non_tensor: bool = True
_is_non_tensor: bool = True

def __post_init__(self):
if is_non_tensor(self.data):
Expand Down Expand Up @@ -1775,7 +1777,7 @@ class NonTensorStack(LazyStackedTensorDict):

"""

_non_tensor: bool = True
_is_non_tensor: bool = True

def tolist(self):
"""Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list.
Expand Down
11 changes: 3 additions & 8 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import collections
import concurrent.futures
import dataclasses
import inspect
import logging

Expand Down Expand Up @@ -784,12 +783,8 @@ def is_tensorclass(obj: type | Any) -> bool:
return _is_tensorclass(cls)


def _is_tensorclass(cls) -> bool:
return (
dataclasses.is_dataclass(cls)
and "to_tensordict" in cls.__dict__
and "_from_tensordict" in cls.__dict__
)
def _is_tensorclass(cls: type) -> bool:
return getattr(cls, "_is_tensorclass", False)


class implement_for:
Expand Down Expand Up @@ -2173,4 +2168,4 @@ def __call__(self, mod: torch.nn.Module, args, kwargs):

def is_non_tensor(data):
"""Checks if an item is a non-tensor."""
return type(data).__dict__.get("_non_tensor", False)
return getattr(type(data), "_is_non_tensor", False)
34 changes: 33 additions & 1 deletion test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,14 @@
from tensordict.memmap import MemoryMappedTensor

from tensordict.nn import TensorDictParams
from tensordict.tensorclass import NonTensorData
from tensordict.tensorclass import NonTensorData, tensorclass
from tensordict.utils import (
_getitem_batch_size,
_LOCK_ERROR,
assert_allclose_td,
convert_ellipsis_to_idx,
is_non_tensor,
is_tensorclass,
lazy_legacy,
set_lazy_legacy,
)
Expand Down Expand Up @@ -8170,6 +8172,36 @@ def test_shared_stack(self, strategy, update, tmpdir):
assert TensorDict.load_memmap(tmpdir).get("val").tolist() == [0, 3] * 5


class TestSubclassing:
def test_td_inheritance(self):
class SubTD(TensorDict):
...

assert is_tensor_collection(SubTD)

def test_tc_inheritance(self):
@tensorclass
class MyClass:
...

assert is_tensor_collection(MyClass)
assert is_tensorclass(MyClass)

class SubTC(MyClass):
...

assert is_tensor_collection(SubTC)
assert is_tensorclass(SubTC)

def test_nontensor_inheritance(self):
class SubTC(NonTensorData):
...

assert is_tensor_collection(SubTC)
assert is_tensorclass(SubTC)
assert is_non_tensor(SubTC(data=1, batch_size=[]))


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
Loading