From 7528728a19c5987b887eb3b269bdeddddf0ba4e0 Mon Sep 17 00:00:00 2001 From: Vincent Moens Date: Wed, 2 Oct 2024 11:30:20 +0100 Subject: [PATCH] [Feature] TD+NJT to(device) support ghstack-source-id: 71812497f1efb9d20f67a7561e74d5111c4cc3f0 Pull Request resolved: https://github.com/pytorch/tensordict/pull/1022 --- tensordict/base.py | 14 ++++++++++++++ tensordict/utils.py | 11 ++++++++--- test/test_tensordict.py | 41 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/tensordict/base.py b/tensordict/base.py index b0a6b3fe0..222e9fe20 100644 --- a/tensordict/base.py +++ b/tensordict/base.py @@ -10346,6 +10346,20 @@ def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking): untyped_storage = storage_cast.untyped_storage() def set_(x): + if x.is_nested: + if x.layout != torch.jagged: + raise RuntimeError( + "to(device) with nested tensors that do not have a jagged layout is not implemented yet. " + "Please raise an issue on GitHub." + ) + values = x._values + lengths = x._lengths + offsets = x._offsets + return torch.nested.nested_tensor_from_jagged( + set_(values), + offsets=set_(offsets), + lengths=set_(lengths) if lengths is not None else None, + ) storage_offset = x.storage_offset() stride = x.stride() return torch.empty_like(x, device=device).set_( diff --git a/tensordict/utils.py b/tensordict/utils.py index fd3140401..dae4fa95f 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -1542,9 +1542,14 @@ def assert_close( elif not isinstance(input1, torch.Tensor): continue if input1.is_nested: - input1 = input1._base - input2 = input2._base - mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() + input1v = input1.values() + input2v = input2.values() + mse = (input1v.to(torch.float) - input2v.to(torch.float)).pow(2).sum() + input1o = input1.offsets() + input2o = input2.offsets() + mse = mse + (input1o.to(torch.float) - input2o.to(torch.float)).pow(2).sum() + else: + mse = (input1.to(torch.float) - input2.to(torch.float)).pow(2).sum() mse = mse.div(input1.numel()).sqrt().item() local_msg = f"key {key} does not match, got mse = {mse:4.4f}" diff --git a/test/test_tensordict.py b/test/test_tensordict.py index ee12d969d..e185be174 100644 --- a/test/test_tensordict.py +++ b/test/test_tensordict.py @@ -7963,6 +7963,47 @@ def test_consolidate_to_device(self): assert td_c_device["d"] == [["a string!"] * 3] assert len(dataptrs) == 1 + @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device detected") + def test_consolidate_to_device_njt(self): + td = TensorDict( + { + "a": torch.arange(3).expand(4, 3).clone(), + "d": "a string!", + "njt": torch.nested.nested_tensor_from_jagged( + torch.arange(10), offsets=torch.tensor([0, 2, 5, 8, 10]) + ), + "njt_lengths": torch.nested.nested_tensor_from_jagged( + torch.arange(10), + offsets=torch.tensor([0, 2, 5, 8, 10]), + lengths=torch.tensor([2, 3, 3, 2]), + ), + }, + device="cpu", + batch_size=[4], + ) + device = torch.device("cuda:0") + td_c = td.consolidate() + assert td_c.device == torch.device("cpu") + td_c_device = td_c.to(device) + assert td_c_device.device == device + assert td_c_device.is_consolidated() + dataptrs = set() + for tensor in td_c_device.values(True, True, is_leaf=_NESTED_TENSORS_AS_LISTS): + assert tensor.device == device + if tensor.is_nested: + vals = tensor._values + dataptrs.add(vals.untyped_storage().data_ptr()) + offsets = tensor._offsets + dataptrs.add(offsets.untyped_storage().data_ptr()) + lengths = tensor._lengths + if lengths is not None: + dataptrs.add(lengths.untyped_storage().data_ptr()) + else: + dataptrs.add(tensor.untyped_storage().data_ptr()) + assert len(dataptrs) == 1 + assert assert_allclose_td(td_c_device.cpu(), td) + assert td_c_device["njt_lengths"]._lengths is not None + def test_create_empty(self): td = LazyStackedTensorDict(stack_dim=0) assert td.device is None