diff --git a/monai/data/__init__.py b/monai/data/__init__.py index 8a560e5f2f..2e8808d3d1 100644 --- a/monai/data/__init__.py +++ b/monai/data/__init__.py @@ -117,18 +117,19 @@ with contextlib.suppress(BaseException): from multiprocessing.reduction import ForkingPickler - def _rebuild_meta(cls, storage, metadata): + def _rebuild_meta(cls, storage, dtype, metadata): storage_offset, size, stride, meta_dict = metadata - t = cls([], dtype=storage.dtype, device=storage.device) + t = cls([], dtype=dtype, device=storage.device) t.set_(storage._untyped() if hasattr(storage, "_untyped") else storage, storage_offset, size, stride) t.__dict__ = meta_dict return t def reduce_meta_tensor(meta_tensor): - storage = meta_tensor.storage() + storage = meta_tensor.untyped_storage() if hasattr(meta_tensor, "untyped_storage") else meta_tensor.storage() + dtype = meta_tensor.dtype if storage.is_cuda: raise NotImplementedError("sharing CUDA metatensor across processes not implemented") metadata = (meta_tensor.storage_offset(), meta_tensor.size(), meta_tensor.stride(), meta_tensor.__dict__) - return _rebuild_meta, (type(meta_tensor), storage, metadata) + return _rebuild_meta, (type(meta_tensor), storage, dtype, metadata) ForkingPickler.register(MetaTensor, reduce_meta_tensor)