[BugFix] Fix inheritance from non-tensor (#709)
vmoens committed Mar 24, 2024
1 parent 059f539 commit 3b895f6
Showing 4 changed files with 237 additions and 21 deletions.
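
A minimal sketch of the behavior this commit targets, mirroring the tests added in test/test_tensordict.py below (import paths follow those tests; ``is_tensor_collection`` is assumed to be importable from the top-level ``tensordict`` package): subclasses of a tensorclass, including ``NonTensorData``, are now recognized by the collection checks.

>>> from tensordict import is_tensor_collection
>>> from tensordict.tensorclass import NonTensorData
>>> from tensordict.utils import is_non_tensor, is_tensorclass
>>> class SubTC(NonTensorData):
...     ...
>>> is_tensor_collection(SubTC)  # such subclasses were previously missed
True
>>> is_tensorclass(SubTC)
True
>>> is_non_tensor(SubTC(data=1, batch_size=[]))
True
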
4 changes: 2 additions & 2 deletions tensordict/base.py
@@ -112,7 +112,7 @@ class TensorDictBase(MutableMapping):
is_meta: bool = False
_is_locked: bool = False
_cache: bool = None
_non_tensor: bool = False
_is_non_tensor: bool = False

def __bool__(self) -> bool:
raise RuntimeError("Converting a tensordict to boolean value is not permitted")
@@ -5414,7 +5414,7 @@ def _default_is_leaf(cls: Type) -> bool:

def _is_leaf_nontensor(cls: Type) -> bool:
if _is_tensor_collection(cls):
return cls._non_tensor
return cls._is_non_tensor
# if issubclass(cls, KeyedJaggedTensor):
# return False
return issubclass(cls, torch.Tensor)
14 changes: 8 additions & 6 deletions tensordict/tensorclass.py
@@ -167,13 +167,13 @@ def __torch_function__(
)
return _from_tensordict_with_copy(tensorclass_instance, result)

_non_tensor = getattr(cls, "_non_tensor", False)
_is_non_tensor = getattr(cls, "_is_non_tensor", False)

cls = dataclass(cls)
expected_keys = set(cls.__dataclass_fields__)

for attr in cls.__dataclass_fields__:
if attr in dir(TensorDict) and attr != "_non_tensor":
if attr in dir(TensorDict) and attr != "_is_non_tensor":
raise AttributeError(
f"Attribute name {attr} can't be used with @tensorclass"
)
@@ -253,7 +253,9 @@ def __torch_function__(

_register_tensor_class(cls)

cls._non_tensor = _non_tensor
# faster than doing instance checks
cls._is_non_tensor = _is_non_tensor
cls._is_tensorclass = True

return cls

@@ -1341,7 +1343,7 @@ class NonTensorData:
val: NonTensorData(
data=0,
_metadata=None,
_non_tensor=True,
_is_non_tensor=True,
batch_size=torch.Size([10]),
device=None,
is_shared=False)},
@@ -1446,7 +1448,7 @@ class NonTensorData:
data: Any
_metadata: dict | None = None

_non_tensor: bool = True
_is_non_tensor: bool = True

def __post_init__(self):
if is_non_tensor(self.data):
@@ -1775,7 +1777,7 @@ class NonTensorStack(LazyStackedTensorDict):
"""

_non_tensor: bool = True
_is_non_tensor: bool = True

def tolist(self):
"""Extracts the content of a :class:`tensordict.tensorclass.NonTensorStack` in a nested list.
206 changes: 194 additions & 12 deletions tensordict/utils.py
@@ -6,7 +6,6 @@

import collections
import concurrent.futures
import dataclasses
import inspect
import logging

@@ -49,6 +48,7 @@
unravel_key_list,
unravel_keys,
)

from torch import Tensor
from torch._C import _disabled_torch_function_impl
from torch.nn.parameter import (
@@ -59,7 +59,6 @@
)
from torch.utils.data._utils.worker import _generate_state


if TYPE_CHECKING:
from tensordict.memmap_deprec import MemmapTensor as _MemmapTensor
from tensordict.tensordict import TensorDictBase
@@ -659,6 +658,21 @@ def _set_item(tensor: Tensor, index: IndexType, value: Tensor, *, validated) ->
elif isinstance(tensor, KeyedJaggedTensor):
tensor = setitem_keyedjaggedtensor(tensor, index, value)
return tensor
from tensordict.tensorclass import NonTensorData, NonTensorStack

if is_non_tensor(tensor):
if (
isinstance(value, NonTensorData)
and isinstance(tensor, NonTensorData)
and tensor.data == value.data
):
return tensor
elif isinstance(tensor, NonTensorData):
tensor = NonTensorStack.from_nontensordata(tensor)
if tensor.stack_dim != 0:
tensor = NonTensorStack(*tensor.unbind(0), stack_dim=0)
tensor[index] = value
return tensor
else:
tensor[index] = value
return tensor
@@ -769,12 +783,8 @@ def is_tensorclass(obj: type | Any) -> bool:
return _is_tensorclass(cls)


def _is_tensorclass(cls) -> bool:
return (
dataclasses.is_dataclass(cls)
and "to_tensordict" in cls.__dict__
and "_from_tensordict" in cls.__dict__
)
def _is_tensorclass(cls: type) -> bool:
return getattr(cls, "_is_tensorclass", False)


class implement_for:
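
Why the flag-based check fixes inheritance: the previous implementation looked up ``to_tensordict`` and ``_from_tensordict`` in ``cls.__dict__``, which only contains attributes defined on the class itself, so subclasses of a tensorclass were not detected. ``getattr`` follows the MRO and therefore also sees the ``_is_tensorclass`` flag set by ``@tensorclass`` on a parent class. A standalone illustration in plain Python (no tensordict required):

>>> class Parent:
...     _is_tensorclass = True  # stands in for the flag set by @tensorclass
...     def to_tensordict(self): ...
>>> class Child(Parent):
...     ...
>>> "to_tensordict" in Child.__dict__  # the old check misses subclasses
False
>>> getattr(Child, "_is_tensorclass", False)  # the new check follows inheritance
True
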
@@ -1506,9 +1516,7 @@ def _expand_to_match_shape(

def _set_max_batch_size(source: T, batch_dims=None):
"""Updates a tensordict with its maximium batch size."""
from tensordict import NonTensorData

tensor_data = [val for val in source.values() if not isinstance(val, NonTensorData)]
tensor_data = [val for val in source.values() if not is_non_tensor(val)]

for val in tensor_data:
from tensordict.base import _is_tensor_collection
@@ -1587,7 +1595,7 @@ def wrapper(*args, **kwargs):
def _broadcast_tensors(index):
# tensors and range need to be broadcast
tensors = {
i: tensor if isinstance(tensor, Tensor) else torch.tensor(tensor)
i: torch.as_tensor(tensor)
for i, tensor in enumerate(index)
if isinstance(tensor, (range, list, np.ndarray, Tensor))
}
@@ -1919,6 +1927,85 @@ def format_size(size):
logging.info(indent + os.path.basename(path))


def isin(
input: TensorDictBase,
reference: TensorDictBase,
key: NestedKey,
dim: int = 0,
) -> Tensor:
"""Tests if each element of ``key`` in input ``dim`` is also present in the reference.
This function returns a boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
the entry ``key`` that are also present in the ``reference``. This function assumes that both ``input`` and
``reference`` have the same number of batch dimensions and contain the specified entry; otherwise an error will be raised.
Args:
input (TensorDictBase): Input TensorDict.
reference (TensorDictBase): Target TensorDict against which to test.
key (NestedKey): The key to test.
dim (int, optional): The dimension along which to test. Defaults to ``0``.
Returns:
out (Tensor): A boolean tensor of length ``input.batch_size[dim]`` that is ``True`` for elements in
the ``input`` ``key`` tensor that are also present in the ``reference``.
Examples:
>>> td = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
... },
... batch_size=[4],
... )
>>> td_ref = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [10, 11, 12]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
... },
... batch_size=[3],
... )
>>> in_reference = isin(td, td_ref, key="tensor1")
>>> expected_in_reference = torch.tensor([True, True, True, False])
>>> torch.testing.assert_close(in_reference, expected_in_reference)
"""
# Get the data
reference_tensor = reference.get(key, default=None)
target_tensor = input.get(key, default=None)

# Check key is present in both tensordict and reference_tensordict
if not isinstance(target_tensor, torch.Tensor):
raise KeyError(f"Key '{key}' not found in input or not a tensor.")
if not isinstance(reference_tensor, torch.Tensor):
raise KeyError(f"Key '{key}' not found in reference or not a tensor.")

# Check that both TensorDicts have the same number of dimensions
if len(input.batch_size) != len(reference.batch_size):
raise ValueError(
"The number of dimensions in the batch size of the input and reference must be the same."
)

# Check dim is valid
batch_dims = input.ndim
if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
raise ValueError(
f"The specified dimension '{dim}' is invalid for an input TensorDict with batch size '{input.batch_size}'."
)

# Convert negative dimension to its positive equivalent
if dim < 0:
dim = batch_dims + dim

# Find the common indices
N = reference_tensor.shape[dim]
cat_data = torch.cat([reference_tensor, target_tensor], dim=dim)
_, unique_indices = torch.unique(
cat_data, dim=dim, sorted=True, return_inverse=True
)
out = torch.isin(unique_indices[N:], unique_indices[:N], assume_unique=True)

return out
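
For reference, a standalone trace of the row-matching trick used above: concatenate reference and target, label every row with the inverse indices returned by ``torch.unique``, then compare labels with ``torch.isin``. The numbers are illustrative only.

>>> import torch
>>> ref = torch.tensor([[1, 2], [3, 4]])
>>> tgt = torch.tensor([[3, 4], [5, 6]])
>>> cat = torch.cat([ref, tgt], dim=0)
>>> _, inv = torch.unique(cat, dim=0, sorted=True, return_inverse=True)
>>> inv  # equal rows get the same label, wherever they came from
tensor([0, 1, 1, 2])
>>> N = ref.shape[0]
>>> torch.isin(inv[N:], inv[:N])
tensor([ True, False])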


def _index_preserve_data_ptr(index):
if isinstance(index, tuple):
return all(_index_preserve_data_ptr(idx) for idx in index)
@@ -1932,6 +2019,96 @@ def _index_preserve_data_ptr(index):
return False


def remove_duplicates(
input: TensorDictBase,
key: NestedKey,
dim: int = 0,
*,
return_indices: bool = False,
) -> TensorDictBase:
"""Removes indices duplicated in `key` along the specified dimension.
This method detects duplicate elements in the tensor associated with the specified `key` along the specified
`dim` and removes the elements at the same indices in all other tensors within the TensorDict. It is expected for
`dim` to be one of the dimensions within the batch size of the input TensorDict to ensure consistency in all
tensors. Otherwise, an error will be raised.
Args:
input (TensorDictBase): The TensorDict containing potentially duplicate elements.
key (NestedKey): The key of the tensor along which duplicate elements should be identified and removed. It
must be one of the leaf keys within the TensorDict, pointing to a tensor and not to another TensorDict.
dim (int, optional): The dimension along which duplicate elements should be identified and removed. It must be one of
the dimensions within the batch size of the input TensorDict. Defaults to ``0``.
return_indices (bool, optional): If ``True``, the indices of the unique elements in the input tensor will be
returned as well. Defaults to ``False``.
Returns:
output (TensorDictBase): input tensordict with the indices corresponding to duplicated elements
in tensor `key` along dimension `dim` removed.
unique_indices (torch.Tensor, optional): The indices of the first occurrences of the unique elements in the
input tensordict for the specified `key` along the specified `dim`. Only provided if ``return_indices`` is ``True``.
Example:
>>> td = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [1, 2, 3], [7, 8, 9]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [40, 50], [50, 60]]),
... },
... batch_size=[4],
... )
>>> output_tensordict = remove_duplicates(td, key="tensor1", dim=0)
>>> expected_output = TensorDict(
... {
... "tensor1": torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
... "tensor2": torch.tensor([[10, 20], [30, 40], [50, 60]]),
... },
... batch_size=[3],
... )
>>> assert (output_tensordict == expected_output).all()
"""
tensor = input.get(key, default=None)

# Check if the key is a TensorDict
if tensor is None:
raise KeyError(f"The key '{key}' does not exist in the TensorDict.")

# Check that the key points to a tensor
if not isinstance(tensor, torch.Tensor):
raise KeyError(f"The key '{key}' does not point to a tensor in the TensorDict.")

# Check dim is valid
batch_dims = input.ndim
if dim >= batch_dims or dim < -batch_dims or batch_dims == 0:
raise ValueError(
f"The specified dimension '{dim}' is invalid for a TensorDict with batch size '{input.batch_size}'."
)

# Convert negative dimension to its positive equivalent
if dim < 0:
dim = batch_dims + dim

# Get indices of unique elements (e.g. [0, 1, 0, 2])
_, unique_indices, counts = torch.unique(
tensor, dim=dim, sorted=True, return_inverse=True, return_counts=True
)

# Find first occurrence of each index (e.g. [0, 1, 3])
_, unique_indices_sorted = torch.sort(unique_indices, stable=True)
cum_sum = counts.cumsum(0, dtype=torch.long)
cum_sum = torch.cat(
(torch.zeros(1, device=input.device, dtype=torch.long), cum_sum[:-1])
)
first_indices = unique_indices_sorted[cum_sum]

# Remove duplicate elements in the TensorDict
output = input[(slice(None),) * dim + (first_indices,)]

if return_indices:
return output, unique_indices

return output
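
A standalone trace of the first-occurrence computation above: label each row with the inverse indices from ``torch.unique``, group the original positions with a stable sort, then use the cumulative counts to pick the first member of each group. Values are illustrative only.

>>> import torch
>>> x = torch.tensor([[1, 2], [3, 4], [1, 2], [5, 6]])
>>> _, inv, counts = torch.unique(x, dim=0, sorted=True, return_inverse=True, return_counts=True)
>>> inv, counts
(tensor([0, 1, 0, 2]), tensor([2, 1, 1]))
>>> _, order = torch.sort(inv, stable=True)  # original indices grouped by label
>>> starts = torch.cat((torch.zeros(1, dtype=torch.long), counts.cumsum(0)[:-1]))
>>> order[starts]  # first original index of each unique row
tensor([0, 1, 3])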


class _CloudpickleWrapper(object):
def __init__(self, fn):
self.fn = fn
@@ -1987,3 +2164,8 @@ def __call__(self, mod: torch.nn.Module, args, kwargs):
return
else:
raise RuntimeError("did not find pre-hook")


def is_non_tensor(data):
"""Checks if an item is a non-tensor."""
return getattr(type(data), "_is_non_tensor", False)
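
A quick check of the new helper, mirroring the test added below (imports as used in test/test_tensordict.py; the literal values are illustrative):

>>> import torch
>>> from tensordict.tensorclass import NonTensorData
>>> from tensordict.utils import is_non_tensor
>>> is_non_tensor(NonTensorData(data="hello", batch_size=[]))
True
>>> is_non_tensor(torch.zeros(3))
False
>>> is_non_tensor("hello")  # plain Python objects do not carry the flag
False
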
34 changes: 33 additions & 1 deletion test/test_tensordict.py
@@ -41,12 +41,14 @@
from tensordict.memmap import MemoryMappedTensor

from tensordict.nn import TensorDictParams
from tensordict.tensorclass import NonTensorData
from tensordict.tensorclass import NonTensorData, tensorclass
from tensordict.utils import (
_getitem_batch_size,
_LOCK_ERROR,
assert_allclose_td,
convert_ellipsis_to_idx,
is_non_tensor,
is_tensorclass,
lazy_legacy,
set_lazy_legacy,
)
@@ -8170,6 +8172,36 @@ def test_shared_stack(self, strategy, update, tmpdir):
assert TensorDict.load_memmap(tmpdir).get("val").tolist() == [0, 3] * 5


class TestSubclassing:
def test_td_inheritance(self):
class SubTD(TensorDict):
...

assert is_tensor_collection(SubTD)

def test_tc_inheritance(self):
@tensorclass
class MyClass:
...

assert is_tensor_collection(MyClass)
assert is_tensorclass(MyClass)

class SubTC(MyClass):
...

assert is_tensor_collection(SubTC)
assert is_tensorclass(SubTC)

def test_nontensor_inheritance(self):
class SubTC(NonTensorData):
...

assert is_tensor_collection(SubTC)
assert is_tensorclass(SubTC)
assert is_non_tensor(SubTC(data=1, batch_size=[]))


if __name__ == "__main__":
args, unknown = argparse.ArgumentParser().parse_known_args()
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
