diff --git a/monai/data/__init__.py b/monai/data/__init__.py
index 3678270232..8d8297deaf 100644
--- a/monai/data/__init__.py
+++ b/monai/data/__init__.py
@@ -122,18 +122,31 @@ from multiprocessing.reduction import ForkingPickler
 def _rebuild_meta(cls, storage, dtype, metadata):
-    storage_offset, size, stride, meta_dict = metadata
+    storage_offset, size, stride, requires_grad, meta_dict = metadata
+    storage = storage._untyped_storage if hasattr(storage, "_untyped_storage") else storage
     t = cls([], dtype=dtype, device=storage.device)
-    t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride)
+    t.set_(storage, storage_offset, size, stride)
+    t.requires_grad = requires_grad
     t.__dict__ = meta_dict
     return t
 
 
 def reduce_meta_tensor(meta_tensor):
-    storage = meta_tensor.untyped_storage() if hasattr(meta_tensor, "untyped_storage") else meta_tensor.storage()
+    if hasattr(meta_tensor, "untyped_storage"):
+        storage = meta_tensor.untyped_storage()
+    elif hasattr(meta_tensor, "_typed_storage"):  # gh pytorch 44dac51/torch/_tensor.py#L231-L233
+        storage = meta_tensor._typed_storage()
+    else:
+        storage = meta_tensor.storage()
     dtype = meta_tensor.dtype
-    if storage.is_cuda:
+    if meta_tensor.is_cuda:
         raise NotImplementedError("sharing CUDA metatensor across processes not implemented")
-    metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.__dict__)
+    metadata = (
+        meta_tensor.storage_offset(),
+        meta_tensor.size(),
+        meta_tensor.stride(),
+        meta_tensor.requires_grad,
+        meta_tensor.__dict__,
+    )
     return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata)
 
 
 ForkingPickler.register(MetaTensor, reduce_meta_tensor)
diff --git a/tests/test_meta_tensor.py b/tests/test_meta_tensor.py
index 793a845063..4f2cb9636a 100644
--- a/tests/test_meta_tensor.py
+++ b/tests/test_meta_tensor.py
@@ -247,7 +247,7 @@ def test_torchscript(self, device):
                 "your pytorch version if this is important to you."
             )
             im_conv = im_conv.as_tensor()
-        self.check(out, im_conv, ids=False)
+        self.check(out, im_conv, ids=False)
 
     def test_pickling(self):
         m, _ = self.get_im()
@@ -258,7 +258,7 @@ def test_pickling(self):
            if not isinstance(m2, MetaTensor) and not pytorch_after(1, 8, 1):
                warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.")
                m = m.as_tensor()
-        self.check(m2, m, ids=False)
+        self.check(m2, m, ids=False)
 
     @skip_if_no_cuda
     def test_amp(self):