[Feature] tensordict.transpose (#467)
vmoens authored Jun 29, 2023
1 parent c699d6c commit 128f42a
Showing 3 changed files with 212 additions and 20 deletions.
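In short: this commit adds a lazy `TensorDict.transpose(dim0, dim1)` and a `create_nested` helper. A minimal usage sketch assembled from the docstrings and tests added below (shapes chosen for illustration):

>>> import torch
>>> from tensordict import TensorDict
>>> td = TensorDict({"a": torch.randn(3, 4, 5)}, batch_size=[3, 4])
>>> tdt = td.transpose(0, 1)  # lazy view; memory is shared with td
>>> tdt.shape
torch.Size([4, 3])
>>> tdt.set("b", torch.randn(4, 3))  # writes are mapped back onto td
>>> td.get("b").shape
torch.Size([3, 4])
>>> td.create_nested(("some", "nested"))  # empty sub-tensordict with the same batch size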
5 changes: 5 additions & 0 deletions tensordict/persistent.py
@@ -612,6 +612,11 @@ def fill_(self, key: str, value: float | bool) -> TensorDictBase:
nested.fill_(subkey, value)
return self

def _create_nested_str(self, key):
self.file.create_group(key)
target_td = self._get_str(key)
return target_td

def select(
self, *keys: str, inplace: bool = False, strict: bool = True
) -> PersistentTensorDict:
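For the HDF5-backed PersistentTensorDict, the hook added above materializes the nested entry as a group in the underlying file. A hedged sketch of how it is reached; the filename/mode constructor arguments are assumptions about this version's API:

>>> td = PersistentTensorDict(filename="data.h5", mode="w", batch_size=[3])  # assumed ctor args
>>> td.create_nested("sub")  # dispatches to _create_nested_str -> self.file.create_group("sub")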
188 changes: 168 additions & 20 deletions tensordict/tensordict.py
@@ -2501,6 +2501,35 @@ def cuda(self, device: int = None) -> TensorDictBase:
return self.to(torch.device("cuda"))
return self.to(f"cuda:{device}")

def _create_nested_str(self, key):
self.set(key, self.select())

def _create_nested_tuple(self, key):
self._create_nested_str(key[0])
if len(key) > 1:
td = self._get_str(key[0], NO_DEFAULT)
td._create_nested_tuple(key[1:])

@lock_blocked
def create_nested(self, key):
"""Creates a nested tensordict of the same shape, device and dim names as the current tensordict.
If the value already exists, it will be overwritten by this operation.
This operation is blocked in locked tensordicts.
Examples:
>>> data = TensorDict({}, [3, 4, 5])
>>> data.create_nested("root")
>>> data.create_nested(("some", "nested", "value"))
>>> nested = data.get(("some", "nested", "value"))
"""
key = unravel_keys(key)
if isinstance(key, str):
self._create_nested_str(key)
else:
self._create_nested_tuple(key)
return self

@abc.abstractmethod
def masked_fill_(self, mask: Tensor, value: float | bool) -> TensorDictBase:
"""Fills the values corresponding to the mask with the desired value.
@@ -2868,6 +2897,40 @@ def view(
inv_op_kwargs={"size": self.batch_size},
)

def transpose(self, dim0, dim1):
"""Returns a tensordit that is a transposed version of input. The given dimensions ``dim0`` and ``dim1`` are swapped.
In-place or out-place modifications of the transposed tensordict will
impact the original tensordict too as the memory is shared and the operations
are mapped back on the original tensordict.
Examples:
>>> tensordict = TensorDict({"a": torch.randn(3, 4, 5)}, [3, 4])
>>> tensordict_transpose = tensordict.transpose(0, 1)
>>> print(tensordict_transpose.shape)
torch.Size([4, 3])
>>> tensordict_transpose.set("b",, torch.randn(4, 3))
>>> print(tensordict.get("b").shape)
torch.Size([3, 4])
"""
if dim0 < 0:
dim0 = self.ndim + dim0
if dim1 < 0:
dim1 = self.ndim + dim1
if any((dim0 < 0, dim1 < 0)):
raise ValueError(
"The provided dimensions are incompatible with the tensordict batch-size."
)
if dim0 == dim1:
return self
return _TransposedTensorDict(
source=self,
custom_op="transpose",
inv_op="transpose",
custom_op_kwargs={"dim0": dim0, "dim1": dim1},
inv_op_kwargs={"dim0": dim0, "dim1": dim1},
)
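A worked example of the negative-index normalization above, as a sketch with illustrative shapes:

>>> td = TensorDict({}, batch_size=[3, 4, 5])
>>> td.transpose(-1, -2).shape  # dims normalized to (2, 1)
torch.Size([3, 5, 4])
>>> td.transpose(0, 0) is td  # equal dims: no-op, returns self
True
>>> td.transpose(-4, 0)  # -4 + ndim == -1 is still negative: raises ValueError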

def permute(
self,
*dims_list: int,
@@ -4155,7 +4218,7 @@ def _get_str(self, key, default):
def _get_tuple(self, key, default):
first = self._get_str(key[0], None)
if first is None:
- return self._default_get(first, default)
+ return self._default_get(key[0], default)
if len(key) == 1:
return first
try:
@@ -5154,7 +5217,7 @@ def __init__(
idx: IndexType,
batch_size: Sequence[int] | None = None,
) -> None:
- if not isinstance(source, TensorDictBase):
+ if not _is_tensor_collection(source.__class__):
raise TypeError(
f"Expected source to be a subclass of TensorDictBase, "
f"got {type(source)}"
@@ -5205,15 +5268,7 @@ def _convert_ellipsis(idx, shape):
def exclude(self, *keys: str, inplace: bool = False) -> TensorDictBase:
if inplace:
return super().exclude(*keys, inplace=True)
- return TensorDict(
-     {key: value for key, value in self.items()},
-     batch_size=self.batch_size,
-     device=self.device,
-     names=self._td_dim_names,
-     _run_checks=False,
-     _is_memmap=self.is_memmap(),
-     _is_shared=self.is_shared(),
- ).exclude(*keys, inplace=True)
+ return self.to_tensordict().exclude(*keys, inplace=True)

@property
def batch_size(self) -> torch.Size:
@@ -5311,12 +5366,16 @@ def set(
) -> TensorDictBase:
key = self._validate_key(key)

- if isinstance(key, tuple):
-     parent = self.get_parent_tensordict()
-     subparent, subkey = _get_leaf_tensordict(parent, key, _default_hook)
-     subparent.get_sub_tensordict(self.idx).set(subkey, tensor, inplace=inplace)
+ if isinstance(key, tuple) and len(key) > 1:
+     source = self._source
+     td = source._get_str(key[0], None)
+     if td is None:
+         source.create_nested(key[0])
+         td = self._get_str(key[0], NO_DEFAULT)
+     td.set(key[1:], tensor, inplace)
return self
+ elif isinstance(key, tuple):
+     key = key[0]
key_present = key in self.keys()
inplace = inplace and key_present
if not inplace:
@@ -5419,6 +5478,8 @@ def get(
return self._source.get_at(key, self.idx, default=default)

def _get_str(self, key, default):
+ if key in self.keys() and _is_tensor_collection(self.entry_class(key)):
+     return SubTensorDict(self._source._get_str(key, NO_DEFAULT), self.idx)
return self._source._get_at_str(key, self.idx, default=default)
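With the branch added above, nested entries of a SubTensorDict come back as SubTensorDict views sharing the parent's index rather than materialized slices, so writes into them reach the source. A hedged sketch (get_sub_tensordict is the accessor used in the old code path above):

>>> td = TensorDict({"a": TensorDict({"b": torch.zeros(4, 3)}, [4])}, [4])
>>> std = td.get_sub_tensordict(slice(0, 2))
>>> std.get("a")  # a SubTensorDict sharing storage with td["a"]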

def _get_tuple(self, key, default):
@@ -7354,19 +7415,24 @@ def set(
if self.is_locked:
raise RuntimeError(TensorDictBase.LOCK_ERROR)

- if isinstance(key, tuple):
-     subsource, subkey = _get_leaf_tensordict(self._source, key, _default_hook)
+ if isinstance(key, tuple) and len(key) > 1:
+     subsource = self._source._get_str(key[0], None)
+     if subsource is None:
+         self._source.create_nested(key[0])
+         subsource = self._source._get_str(key[0], NO_DEFAULT)

td = self.__class__(
source=subsource,
custom_op=self.custom_op,
inv_op=self.inv_op,
custom_op_kwargs=self._update_custom_op_kwargs(subsource),
inv_op_kwargs=self._update_inv_op_kwargs(subsource),
)
- td.set(subkey, value, inplace=inplace)
+ td.set(key[1:], value, inplace=inplace)
return self
+ elif isinstance(key, tuple):
+     key = key[0]

key = self._validate_key(key)
value = self._validate_value(value)
return self._set(key, value, inplace=inplace)

@@ -7760,6 +7826,87 @@ def names(self, value):
)


class _TransposedTensorDict(_CustomOpTensorDict):
"""A lazy view on a TensorDict with two batch dimensions transposed.
When calling `tensordict.transpose(dim0, dim1)`, a lazy view of this operation is
returned such that the following code snippet works without raising an
exception:
>>> assert tensordict.transpose(dim0, dim1).transpose(dim0, dim1) is tensordict
"""

def transpose(self, dim0, dim1) -> TensorDictBase:
if dim0 < 0:
dim0 = self.ndim + dim0
if dim1 < 0:
dim1 = self.ndim + dim1
if any((dim0 < 0, dim1 < 0)):
raise ValueError(
"The provided dimensions are incompatible with the tensordict batch-size."
)
if dim0 == dim1:
return self
dims = (self.inv_op_kwargs.get("dim0"), self.inv_op_kwargs.get("dim1"))
if dim0 in dims and dim1 in dims:
return self._source
return super().transpose(dim0, dim1)
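An illustration of the short-circuit above: transposing a transposed tensordict back over the same pair of dims returns the original tensordict itself rather than stacking another lazy view (a sketch, mirroring the class docstring):

>>> td = TensorDict({}, batch_size=[2, 3])
>>> tdt = td.transpose(0, 1)
>>> tdt.transpose(0, 1) is td
True
>>> tdt.transpose(1, 0) is td  # either order of the same pair
True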

def add_missing_dims(
self, num_dims: int, batch_dims: tuple[int, ...]
) -> tuple[int, ...]:
dim_diff = num_dims - len(batch_dims)
all_dims = list(range(num_dims))
for i, x in enumerate(batch_dims):
if x < 0:
x = x - dim_diff
all_dims[i] = x
return tuple(all_dims)

def _update_custom_op_kwargs(self, source_tensor: Tensor) -> dict[str, Any]:
return self.custom_op_kwargs

def _update_inv_op_kwargs(self, tensor: Tensor) -> dict[str, Any]:
return self.custom_op_kwargs

def _stack_onto_(
self,
key: str,
list_item: list[CompatibleType],
dim: int,
) -> TensorDictBase:

trsp = self.custom_op_kwargs["dim0"], self.custom_op_kwargs["dim1"]
if dim == trsp[0]:
dim = trsp[1]
elif dim == trsp[1]:
dim = trsp[0]

list_permuted_items = []
for item in list_item:
list_permuted_items.append(item.transpose(*trsp))
self._source._stack_onto_(key, list_permuted_items, dim)
return self

@property
def names(self):
names = copy(self._source.names)
dim0 = self.custom_op_kwargs["dim0"]
dim1 = self.custom_op_kwargs["dim1"]
names = [
names[dim0] if i == dim1 else names[dim1] if i == dim0 else name
for i, name in enumerate(names)
]
return names

@names.setter
def names(self, value):
raise RuntimeError(
"Names of a lazy tensordict cannot be modified. Call to_tensordict() first."
)
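A brief sketch of the two properties above; the names constructor argument is assumed available in this version:

>>> td = TensorDict({}, batch_size=[3, 4], names=["a", "b"])
>>> td.transpose(0, 1).names
['b', 'a']
>>> td.transpose(0, 1).names = ["x", "y"]  # raises RuntimeError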


class _PermutedTensorDict(_CustomOpTensorDict):
"""A lazy view on a TensorDict with the batch dimensions permuted.
@@ -7805,6 +7952,7 @@ def permute(
def add_missing_dims(
self, num_dims: int, batch_dims: tuple[int, ...]
) -> tuple[int, ...]:
# Adds the feature dimensions to the permute dims
dim_diff = num_dims - len(batch_dims)
all_dims = list(range(num_dims))
for i, x in enumerate(batch_dims):
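Both `set` rewrites in this file follow one pattern: when a nested key is written through a lazy view and the intermediate tensordict is missing, it is first created on the source with create_nested, then the write recurses into it. A hedged end-to-end sketch, mirroring the new test below:

>>> td = TensorDict({}, batch_size=[3, 4])
>>> tdt = td.transpose(0, 1)
>>> tdt.set(("a", "b"), torch.zeros(4, 3))  # creates td["a"] on the source first
>>> td.get(("a", "b")).shape
torch.Size([3, 4])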
39 changes: 39 additions & 0 deletions test/test_tensordict.py
@@ -1707,6 +1707,45 @@ def test_setitem_nested_dict_value(self, td_name, device):
td_clone2["d"] = nested_tensordict_value
assert (td_clone1 == td_clone2).all()

def test_transpose(self, td_name, device):
td = getattr(self, td_name)(device)
tdt = td.transpose(0, 1)
assert tdt.shape == torch.Size([td.shape[1], td.shape[0], *td.shape[2:]])
for key, value in tdt.items(True):
assert value.shape == torch.Size(
[td.get(key).shape[1], td.get(key).shape[0], *td.get(key).shape[2:]]
)
tdt = td.transpose(-1, -2)
for key, value in tdt.items(True):
assert value.shape == td.get(key).transpose(2, 3).shape
assert tdt.transpose(-1, -2) is td
with td.unlock_():
tdt.set(("some", "transposed", "tensor"), torch.zeros(tdt.shape))
assert td.get(("some", "transposed", "tensor")).shape == td.shape
assert td.transpose(0, 0) is td
with pytest.raises(
ValueError, match="The provided dimensions are incompatible"
):
td.transpose(-5, -6)
with pytest.raises(
ValueError, match="The provided dimensions are incompatible"
):
tdt.transpose(-5, -6)

def test_create_nested(self, td_name, device):
td = getattr(self, td_name)(device)
with td.unlock_():
td.create_nested("root")
assert td.get("root").shape == td.shape
assert is_tensor_collection(td.get("root"))
td.create_nested(("some", "nested", "key"))
assert td.get(("some", "nested", "key")).shape == td.shape
assert is_tensor_collection(td.get(("some", "nested", "key")))
if td_name in ("sub_td", "sub_td2"):
return
with td.lock_(), pytest.raises(RuntimeError):
td.create_nested("root")

def test_tensordict_set(self, td_name, device):
torch.manual_seed(1)
np.random.seed(1)
