Skip to content

Commit

Permalink
6104 remove deprecated tensor.storage usage (#6105)
Browse files Browse the repository at this point in the history
Fixes #6104

### Description

workaround adapted from
https://github.com/pytorch/pytorch/blob/44dac51/torch/_tensor.py#L231-L233

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Mar 6, 2023
1 parent e375f2a commit fa884a2
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 7 deletions.
23 changes: 18 additions & 5 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,18 +122,31 @@
from multiprocessing.reduction import ForkingPickler

def _rebuild_meta(cls, storage, dtype, metadata):
storage_offset, size, stride, meta_dict = metadata
storage_offset, size, stride, requires_grad, meta_dict = metadata
storage = storage._untyped_storage if hasattr(storage, "_untyped_storage") else storage
t = cls([], dtype=dtype, device=storage.device)
t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride)
t.set_(storage, storage_offset, size, stride)
t.requires_grad = requires_grad
t.__dict__ = meta_dict
return t

def reduce_meta_tensor(meta_tensor):
storage = meta_tensor.untyped_storage() if hasattr(meta_tensor, "untyped_storage") else meta_tensor.storage()
if hasattr(meta_tensor, "untyped_storage"):
storage = meta_tensor.untyped_storage()
elif hasattr(meta_tensor, "_typed_storage"): # gh pytorch 44dac51/torch/_tensor.py#L231-L233
storage = meta_tensor._typed_storage()
else:
storage = meta_tensor.storage()
dtype = meta_tensor.dtype
if storage.is_cuda:
if meta_tensor.is_cuda:
raise NotImplementedError("sharing CUDA metatensor across processes not implemented")
metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.__dict__)
metadata = (
meta_tensor.storage_offset(),
meta_tensor.size(),
meta_tensor.stride(),
meta_tensor.requires_grad,
meta_tensor.__dict__,
)
return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata)

# register the custom reducer so pickling a MetaTensor across processes
# goes through reduce_meta_tensor / _rebuild_meta (preserving its __dict__)
ForkingPickler.register(MetaTensor, reduce_meta_tensor)
4 changes: 2 additions & 2 deletions tests/test_meta_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def test_torchscript(self, device):
"your pytorch version if this is important to you."
)
im_conv = im_conv.as_tensor()
self.check(out, im_conv, ids=False)
self.check(out, im_conv, ids=False)

def test_pickling(self):
m, _ = self.get_im()
Expand All @@ -258,7 +258,7 @@ def test_pickling(self):
if not isinstance(m2, MetaTensor) and not pytorch_after(1, 8, 1):
warnings.warn("Old version of pytorch. pickling converts `MetaTensor` to `torch.Tensor`.")
m = m.as_tensor()
self.check(m2, m, ids=False)
self.check(m2, m, ids=False)

@skip_if_no_cuda
def test_amp(self):
Expand Down

0 comments on commit fa884a2

Please sign in to comment.