Skip to content

Commit

Permalink
use untyped_storage if present (#5863)
Browse files Browse the repository at this point in the history
Fixes #5862.

### Description

if `untyped_storage()` is present (pytorch 2) use it, else use
`storage()`.

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

Signed-off-by: Richard Brown <[email protected]>
  • Loading branch information
rijobro authored Jan 17, 2023
1 parent 6803061 commit 79b8b0c
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,18 +117,19 @@
with contextlib.suppress(BaseException):
from multiprocessing.reduction import ForkingPickler

def _rebuild_meta(cls, storage, metadata):
def _rebuild_meta(cls, storage, dtype, metadata):
storage_offset, size, stride, meta_dict = metadata
t = cls([], dtype=storage.dtype, device=storage.device)
t = cls([], dtype=dtype, device=storage.device)
t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride)
t.__dict__ = meta_dict
return t

def reduce_meta_tensor(meta_tensor):
storage = meta_tensor.storage()
storage = meta_tensor.untyped_storage() if hasattr(meta_tensor, "untyped_storage") else meta_tensor.storage()
dtype = meta_tensor.dtype
if storage.is_cuda:
raise NotImplementedError("sharing CUDA metatensor across processes not implemented")
metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.__dict__)
return _rebuild_meta, (type(meta_tensor), storage, metadata)
return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata)

ForkingPickler.register(MetaTensor, reduce_meta_tensor)

0 comments on commit 79b8b0c

Please sign in to comment.