Skip to content

Commit

Permalink
[Feature] TD+NJT to(device) support
Browse files Browse the repository at this point in the history
ghstack-source-id: 71812497f1efb9d20f67a7561e74d5111c4cc3f0
Pull Request resolved: #1022
  • Loading branch information
vmoens committed Oct 2, 2024
1 parent 5a50f89 commit 7528728
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 3 deletions.
14 changes: 14 additions & 0 deletions tensordict/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10346,6 +10346,20 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
untyped_storage = storage_cast.untyped_storage()

def set_(x):
if x.is_nested:
if x.layout != torch.jagged:
raise RuntimeError(
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
"Please raise an issue on GitHub."
)
values = x._values
lengths = x._lengths
offsets = x._offsets
return torch.nested.nested_tensor_from_jagged(
set_(values),
offsets=set_(offsets),
lengths=set_(lengths) if lengths is not None else None,
)
storage_offset = x.storage_offset()
stride = x.stride()
return torch.empty_like(x, device=device).set_(
Expand Down
11 changes: 8 additions & 3 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1542,9 +1542,14 @@ def assert_close(
elif not isinstance(input1, torch.Tensor):
continue
if input1.is_nested:
input1 = input1._base
input2 = input2._base
mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
input1v = input1.values()
input2v = input2.values()
mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum()
input1o = input1.offsets()
input2o = input2.offsets()
mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum()
else:
mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum()
mse = mse.div(input1.numel()).sqrt().item()

local_msg = f"key {key} does not match, got mse = {mse:4.4f}"
Expand Down
41 changes: 41 additions & 0 deletions test/test_tensordict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7963,6 +7963,47 @@ def test_consolidate_to_device(self):
assert td_c_device["d"] == [["a string!"] * 3]
assert len(dataptrs) == 1

@pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected")
def test_consolidate_to_device_njt(self):
td = TensorDict(
{
"a": torch.arange(3).expand(4, 3).clone(),
"d": "a string!",
"njt": torch.nested.nested_tensor_from_jagged(
torch.arange(10), offsets=torch.tensor([0, 2, 5, 8, 10])
),
"njt_lengths": torch.nested.nested_tensor_from_jagged(
torch.arange(10),
offsets=torch.tensor([0, 2, 5, 8, 10]),
lengths=torch.tensor([2, 3, 3, 2]),
),
},
device="cpu",
batch_size=[4],
)
device = torch.device("cuda:0")
td_c = td.consolidate()
assert td_c.device == torch.device("cpu")
td_c_device = td_c.to(device)
assert td_c_device.device == device
assert td_c_device.is_consolidated()
dataptrs = set()
for tensor in td_c_device.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS):
assert tensor.device == device
if tensor.is_nested:
vals = tensor._values
dataptrs.add(vals.untyped_storage().data_ptr())
offsets = tensor._offsets
dataptrs.add(offsets.untyped_storage().data_ptr())
lengths = tensor._lengths
if lengths is not None:
dataptrs.add(lengths.untyped_storage().data_ptr())
else:
dataptrs.add(tensor.untyped_storage().data_ptr())
assert len(dataptrs) == 1
assert assert_allclose_td(td_c_device.cpu(), td)
assert td_c_device["njt_lengths"]._lengths is not None

def test_create_empty(self):
td = LazyStackedTensorDict(stack_dim=0)
assert td.device is None
Expand Down

0 comments on commit 7528728

Please sign in to comment.