[Feature] Stacking tensors of different shape #135

Merged 4 commits on Dec 31, 2022
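This PR lets a lazily stacked TensorDict hold entries whose shapes differ across the stacked tensordicts: `get` now raises a more informative error for such entries, and a new `get_nestedtensor` method returns them as a nested tensor. A minimal usage sketch, mirroring the tests added below (key names and shapes are illustrative, not taken from the PR):

```python
import torch
from tensordict import TensorDict

# Two tensordicts that agree on the batch size but not on the trailing
# shape of the "obs" entry (4 x 2 vs 4 x 3).
td1 = TensorDict({"obs": torch.randn(4, 2)}, batch_size=[4])
td2 = TensorDict({"obs": torch.randn(4, 3)}, batch_size=[4])

td_stack = torch.stack([td1, td2], 0)  # lazy stack, no copy

# Dense access raises the new, more explicit error:
# td_stack.get("obs")  -> RuntimeError: Found more than one unique shape ...

# The new accessor returns the entries as a single nested tensor instead.
nt = td_stack.get_nestedtensor("obs")
print(nt.is_nested)  # True
```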
49 changes: 33 additions & 16 deletions tensordict/metatensor.py
@@ -96,7 +96,7 @@ def __new__(cls, *args, **kwargs):
@property
def shape(self):
_shape = self._shape
if _shape is None:
if _shape is None and self._tensor is not None:
_shape = self._shape = _shape_fn(self._tensor)
return _shape

@@ -114,15 +114,16 @@ def device(self):
@property
def dtype(self):
_dtype = self._dtype
if _dtype is None and not self.is_tensordict():
if _dtype is None and not self.is_tensordict() and self._tensor is not None:
_dtype = self._dtype = _dtype_fn(self._tensor)
return _dtype

def is_tensordict(self):
_is_tensordict = self._is_tensordict
if _is_tensordict is None:
_is_tensordict = self._is_tensordict = (
not isinstance(self._tensor, torch.Tensor)
self._tensor is not None
and not isinstance(self._tensor, torch.Tensor)
and not self.is_memmap()
and not self.is_kjt()
)
@@ -165,15 +166,20 @@ def __init__(
_is_kjt: Optional[bool] = None,
_repr_tensordict: Optional[str] = None,
):
tensor = None
if len(shape) == 1 and not isinstance(shape[0], (Number,)):
if (
len(shape) == 1
and not isinstance(shape[0], (Number,))
and shape[0] is not None
):
tensor = shape[0]
self._tensor = tensor
return

if type(shape) is not torch.Size:
shape = torch.Size(shape)
self.shape = shape
elif len(shape) == 1 and shape[0] is None:
self.shape = None
else:
if type(shape) is not torch.Size:
shape = torch.Size(shape)
self.shape = shape
self._device = device
self._dtype = dtype if dtype is not None else torch.get_default_dtype()
self._ndim = len(shape)
@@ -196,7 +202,7 @@ def class_name(self):
name = "MemmapTensor"
elif self._is_kjt:
name = "KeyedJaggedTensor"
elif self.is_shared() and self.device.type != "cuda":
elif self.is_shared() and self.device and self.device.type != "cuda":
name = "SharedTensor"
else:
name = "Tensor"
@@ -207,7 +213,10 @@ def get_repr(self):
if self.is_tensordict():
return repr(self._tensor)
else:
return f"{self.class_name}({self.shape}, dtype={self.dtype})"
shape = self.shape
if shape is None:
shape = "*"
return f"{self.class_name}({shape}, dtype={self.dtype})"

def memmap_(self) -> MetaTensor:
"""Changes the storage of the MetaTensor to memmap.
@@ -242,9 +251,9 @@ def share_memory_(self) -> MetaTensor:
return self

def is_shared(self) -> bool:
if self._is_shared is None:
if self._is_shared is None and self._tensor is not None:
self._is_shared = self._tensor.is_shared()
return self._is_shared
return bool(self._is_shared)

def numel(self) -> int:
if self._numel is None:
@@ -428,9 +437,17 @@ def _stack_meta(
f"Stacking meta tensors of different dtype is not "
f"allowed, got shapes {dtype} and {tensor.dtype}"
)

shape = list(shape)
shape.insert(dim, len(list_of_meta_tensors))
for tensor in list_of_meta_tensors:
if tensor.shape != shape:
shape = (None,)
break
else:
shape = list(shape)
shape.insert(dim, len(list_of_meta_tensors))
if dtype is None:
dtype = list_of_meta_tensors[0].dtype
if device is None:
device = list_of_meta_tensors[0].device

return MetaTensor(
*shape,
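The core change to `_stack_meta` above: instead of requiring identical shapes, the stacked `MetaTensor` now falls back to the unknown-shape sentinel `(None,)` (rendered as `*` by `get_repr`) when the inputs disagree. A standalone sketch of that resolution rule, assuming only a list of `torch.Size` objects:

```python
import torch
from typing import List, Tuple, Union


def resolve_stacked_shape(
    shapes: List[torch.Size], dim: int
) -> Union[torch.Size, Tuple[None]]:
    """Mirror of the rule above: equal shapes gain a new stack dimension,
    heterogeneous shapes collapse to the unknown-shape sentinel (None,)."""
    first = shapes[0]
    if any(s != first for s in shapes):
        return (None,)
    out = list(first)
    out.insert(dim, len(shapes))
    return torch.Size(out)


print(resolve_stacked_shape([torch.Size([4, 2]), torch.Size([4, 2])], 0))
# torch.Size([2, 4, 2])
print(resolve_stacked_shape([torch.Size([4, 2]), torch.Size([4, 3])], 0))
# (None,)
```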
68 changes: 56 additions & 12 deletions tensordict/tensordict.py
@@ -2128,6 +2128,7 @@ def lock(self):
for key, item in self.items_meta():
if item.is_tensordict():
self.get(key).lock()
return self

def unlock(self):
self._is_locked = False
@@ -2136,6 +2137,7 @@ def unlock(self):
for key, item in self.items_meta():
if item.is_tensordict():
self.get(key).unlock()
return self


class TensorDict(TensorDictBase):
@@ -3131,7 +3133,14 @@ def _cat(
key, "Attempted to concatenate tensors on different devices at key"
):
out[key] = torch.cat([td.get(key) for td in list_of_tensordicts], dim)

if device is None:
device = list_of_tensordicts[0].device
for td in list_of_tensordicts[1:]:
if device == td.device:
continue
else:
device = None
break
return TensorDict(out, device=device, batch_size=batch_size, _run_checks=False)
else:
if out.batch_size != batch_size:
@@ -4174,15 +4183,50 @@ def get(
return self._default_get(key, default)

tensors = [td.get(key, default=default) for td in self.tensordicts]
shapes = {_shape(tensor) for tensor in tensors}
if len(shapes) != 1:
raise RuntimeError(
f"found more than one unique shape in the tensors to be "
f"stacked ({shapes}). This is likely due to a modification "
f"of one of the stacked TensorDicts, where a key has been "
f"updated/created with an uncompatible shape."
)
return torch.stack(tensors, self.stack_dim)
try:
return torch.stack(tensors, self.stack_dim)
except RuntimeError as err:
if "stack expects each tensor to be equal size" in str(err):
shapes = {_shape(tensor) for tensor in tensors}
raise RuntimeError(
f"Found more than one unique shape in the tensors to be "
f"stacked ({shapes}). This is likely due to a modification "
f"of one of the stacked TensorDicts, where a key has been "
f"updated/created with an uncompatible shape. If the entries "
f"are intended to have a different shape, use the get_nestedtensor "
f"method instead."
)
else:
raise err

def get_nestedtensor(
self,
key: NESTED_KEY,
default: Union[str, COMPATIBLE_TYPES] = "_no_default_",
) -> COMPATIBLE_TYPES:
# TODO: the stacking logic below works for nested keys, but the key in
# self.valid_keys check will fail and we'll return the default instead.
# For now we'll advise the user that nested keys aren't supported, but it should be
# fairly easy to add support if we could add nested keys to valid_keys.

# we can handle the case where the key is a tuple of length 1
if (type(key) is tuple) and len(key) == 1:
key = key[0]
elif type(key) is tuple:
tensordict, key = _get_leaf_tensordict(self, key)
return tensordict.get_nestedtensor(key)

keys = self.valid_keys
if not (key in keys):
# first, let's try to update the valid keys
self._update_valid_keys()
keys = self.valid_keys

if not (key in keys):
return self._default_get(key, default)

tensors = [td.get(key, default=default) for td in self.tensordicts]
return torch.nested.nested_tensor(tensors)

def _make_meta(self, key: str) -> MetaTensor:
return torch.stack(
@@ -4193,7 +4237,7 @@ def is_contiguous(self) -> bool:
return False

def contiguous(self) -> TensorDictBase:
source = {key: value for key, value in self.items()}
source = {key: value.contiguous() for key, value in self.items()}
batch_size = self.batch_size
device = self.device
out = TensorDict(
@@ -5558,7 +5602,7 @@ def _stack_onto_(

def _make_repr(key, item: MetaTensor, tensordict):
if item.is_tensordict():
return f"{key}: {repr(tensordict[key])}"
return f"{key}: {repr(tensordict.get(key))}"
return f"{key}: {item.get_repr()}"


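Besides the stacking logic, `tensordict.py` picks up two smaller changes: `lock()` and `unlock()` now return `self`, so they can be chained (as the new test does), and `LazyStackedTensorDict.contiguous()` materialises each value with `.contiguous()`. A short sketch of the chaining pattern (names are illustrative):

```python
import torch
from tensordict import TensorDict

# lock()/unlock() returning self allows fluent, one-line usage:
td = TensorDict({"a": torch.zeros(3)}, batch_size=[3]).lock()

# ...and later, unlocking and writing in a single expression:
td.unlock().set("a", torch.ones(3))
```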
39 changes: 36 additions & 3 deletions test/test_tensordict.py
@@ -285,13 +285,13 @@ def test_cat_td(device):
"key2": torch.randn(4, 5, 10, device=device),
"key3": {"key4": torch.randn(4, 5, 10, device=device)},
}
td1 = TensorDict(batch_size=(4, 5), source=d)
td1 = TensorDict(batch_size=(4, 5), source=d, device=device)
d = {
"key1": torch.randn(4, 10, 6, device=device),
"key2": torch.randn(4, 10, 10, device=device),
"key3": {"key4": torch.randn(4, 10, 10, device=device)},
}
td2 = TensorDict(batch_size=(4, 10), source=d)
td2 = TensorDict(batch_size=(4, 10), source=d, device=device)

td_cat = torch.cat([td1, td2], 1)
assert td_cat.batch_size == torch.Size([4, 15])
@@ -300,7 +300,7 @@ def test_cat_td(device):
"key2": torch.zeros(4, 15, 10, device=device),
"key3": {"key4": torch.zeros(4, 15, 10, device=device)},
}
td_out = TensorDict(batch_size=(4, 15), source=d)
td_out = TensorDict(batch_size=(4, 15), source=d, device=device)
torch.cat([td1, td2], 1, out=td_out)
assert td_out.batch_size == torch.Size([4, 15])
assert (td_out["key1"] != 0).all()
@@ -1224,6 +1224,39 @@ def test_inferred_view_size(self, td_name, device):
assert td.view(-1).view(*new_shape) is td
assert td.view(*new_shape) is td

@pytest.mark.parametrize("dim", [0, 1, -1])
@pytest.mark.parametrize(
"key", ["heterogeneous-entry", ("sub", "heterogeneous-entry")]
)
def test_nestedtensor_stack(self, td_name, device, dim, key):
torch.manual_seed(1)
td1 = getattr(self, td_name)(device).unlock()
td2 = getattr(self, td_name)(device).unlock()
td1[key] = torch.randn(*td1.shape, 2)
td2[key] = torch.randn(*td1.shape, 3)
td_stack = torch.stack([td1, td2], dim)
# get will fail
with pytest.raises(
RuntimeError, match="Found more than one unique shape in the tensors"
):
td_stack.get(key)
with pytest.raises(
RuntimeError, match="Found more than one unique shape in the tensors"
):
td_stack[key]
# this will work: it is the proper way to get that entry
td_stack.get_nestedtensor(key)
with pytest.raises(
RuntimeError, match="Found more than one unique shape in the tensors"
):
td_stack.contiguous()
with pytest.raises(
RuntimeError, match="Found more than one unique shape in the tensors"
):
td_stack.to_tensordict()
# cloning is type-preserving: we can do that operation
td_stack.clone()

def test_clone_td(self, td_name, device):
torch.manual_seed(1)
td = getattr(self, td_name)(device)
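The new test walks the full path: stacking heterogeneous entries, the failing `get`, the working `get_nestedtensor`, and the distinction between densifying ops (`contiguous`, `to_tensordict`), which raise, and type-preserving ones (`clone`), which pass. A condensed sketch of those expectations, with hypothetical shapes mirroring the test:

```python
import pytest
import torch
from tensordict import TensorDict

td1 = TensorDict({"x": torch.randn(4, 2)}, batch_size=[4])
td2 = TensorDict({"x": torch.randn(4, 3)}, batch_size=[4])
td_stack = torch.stack([td1, td2], 0)

td_stack.clone()  # type-preserving: keeps the lazy-stacked structure

with pytest.raises(RuntimeError, match="Found more than one unique shape"):
    td_stack.contiguous()  # densifying requires equal shapes per key
```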