6104 remove deprecated tensor.storage usage #6105

Merged 1 commit on Mar 6, 2023

monai/data/__init__.py (23 changes: 18 additions & 5 deletions)

@@ -122,18 +122,31 @@
 from multiprocessing.reduction import ForkingPickler

 def _rebuild_meta(cls, storage, dtype, metadata):
-    storage_offset, size, stride, meta_dict = metadata
+    storage_offset, size, stride, requires_grad, meta_dict = metadata
+    storage = storage._untyped_storage if hasattr(storage, "_untyped_storage") else storage
     t = cls([], dtype=dtype, device=storage.device)
-    t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride)
+    t.set_(storage, storage_offset, size, stride)
+    t.requires_grad = requires_grad
     t.__dict__ = meta_dict
     return t

 def reduce_meta_tensor(meta_tensor):
-    storage = meta_tensor.untyped_storage() if hasattr(meta_tensor, "untyped_storage") else meta_tensor.storage()
+    if hasattr(meta_tensor, "untyped_storage"):
+        storage = meta_tensor.untyped_storage()
+    elif hasattr(meta_tensor, "_typed_storage"):  # gh pytorch 44dac51/torch/_tensor.py#L231-L233
+        storage = meta_tensor._typed_storage()
+    else:
+        storage = meta_tensor.storage()
     dtype = meta_tensor.dtype
-    if storage.is_cuda:
+    if meta_tensor.is_cuda:
         raise NotImplementedError("sharing CUDA metatensor across processes not implemented")
-    metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.__dict__)
+    metadata = (
+        meta_tensor.storage_offset(),
+        meta_tensor.size(),
+        meta_tensor.stride(),
+        meta_tensor.requires_grad,
+        meta_tensor.__dict__,
+    )
     return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata)

 ForkingPickler.register(MetaTensor, reduce_meta_tensor)
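
For readers unfamiliar with the reduce/rebuild protocol that `ForkingPickler.register` hooks into, here is a minimal, self-contained sketch of the pattern the code above follows. `MiniMeta` and `rebuild` are illustrative stand-ins for `MetaTensor` and `_rebuild_meta`, not MONAI API, and the sketch assumes a recent PyTorch where `Tensor.untyped_storage` exists:

```python
import torch

class MiniMeta(torch.Tensor):
    """Toy stand-in for MetaTensor: a Tensor subclass carrying extra state in __dict__."""
    def __new__(cls, x, *args, **kwargs):
        return torch.as_tensor(x, *args, **kwargs).as_subclass(cls)

def rebuild(cls, storage, dtype, metadata):
    # Mirrors _rebuild_meta: make an empty tensor of the right dtype/device,
    # point it at the (possibly shared) storage, then restore the autograd
    # flag and the instance __dict__.
    storage_offset, size, stride, requires_grad, meta_dict = metadata
    t = cls([], dtype=dtype, device=storage.device)
    t.set_(storage, storage_offset, size, stride)
    t.requires_grad = requires_grad
    t.__dict__ = meta_dict
    return t

x = MiniMeta(torch.rand(2, 3))
x.meta = {"fname": "example.nii"}
x.requires_grad = True

# What the reduce side packs: a storage handle plus everything needed to rebuild.
storage = x.untyped_storage() if hasattr(x, "untyped_storage") else x.storage()
metadata = (x.storage_offset(), x.size(), x.stride(), x.requires_grad, x.__dict__)

y = rebuild(type(x), storage, x.dtype, metadata)
assert isinstance(y, MiniMeta) and y.requires_grad
assert y.meta == {"fname": "example.nii"}
assert torch.equal(y.detach(), x.detach())
```

The detail this PR adds is `requires_grad` in the metadata tuple; without it, a rebuilt tensor silently dropped its autograd flag.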

tests/test_meta_tensor.py (4 changes: 2 additions & 2 deletions)
@@ -247,7 +247,7 @@ def test_torchscript(self, device):
                 "your pytorch version if this is important to you."
             )
             im_conv = im_conv.as_tensor()
-            self.check(out, im_conv, ids=False)
+        self.check(out, im_conv, ids=False)

     def test_pickling(self):
         m, _ = self.get_im()
@@ -258,7 +258,7 @@ def test_pickling(self):
         if not isinstance(m2, MetaTensor) and not pytorch_after(1, 8, 1):
             warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.")
             m = m.as_tensor()
-            self.check(m2, m, ids=False)
+        self.check(m2, m, ids=False)

     @skip_if_no_cuda
     def test_amp(self):
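
For completeness, a hedged sketch of the path this registration targets: handing a `MetaTensor` to a child process, where the multiprocessing queue serializes it with `ForkingPickler` (calling `reduce_meta_tensor`) and the child rebuilds it with `_rebuild_meta`. This assumes a `fork` start method and CPU tensors, since the CUDA branch above raises `NotImplementedError`:

```python
import torch
import torch.multiprocessing as mp  # importing this registers torch's storage reducers
from monai.data import MetaTensor   # importing monai.data runs ForkingPickler.register above

def worker(q):
    t = q.get()  # rebuilt in the child process by _rebuild_meta
    assert isinstance(t, MetaTensor)
    print(type(t).__name__, tuple(t.shape), t.meta.get("fname"))

if __name__ == "__main__":
    ctx = mp.get_context("fork")  # fork/forkserver queues pickle with ForkingPickler
    q = ctx.Queue()
    p = ctx.Process(target=worker, args=(q,))
    p.start()
    q.put(MetaTensor(torch.rand(2, 3), meta={"fname": "example.nii"}))
    p.join()
```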