Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

909 Update LoadImage to use Nibabel as default #1307

Merged
merged 20 commits into from
Dec 2, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/source/highlights.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ For example:
```py
# define a transform chain for pre-processing
train_transforms = monai.transforms.Compose([
LoadNiftid(keys=['image', 'label']),
LoadImaged(keys=['image', 'label']),
RandRotate90d(keys=['image', 'label'], prob=0.2, spatial_axes=[0, 2]),
... ...
])
Expand Down
6 changes: 3 additions & 3 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
partition_dataset,
select_cross_validation_folds,
)
from monai.transforms import LoadNiftid, LoadPNGd, Randomizable
from monai.transforms import LoadImaged, Randomizable
from monai.utils import ensure_tuple


Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
)
data = self._generate_data_list(dataset_dir)
if transform == ():
transform = LoadPNGd("image")
transform = LoadImaged("image")
super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)

def randomize(self, data: Optional[Any] = None) -> None:
Expand Down Expand Up @@ -268,7 +268,7 @@ def __init__(
]
self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
if transform == ():
transform = LoadNiftid(["image", "label"])
transform = LoadImaged(["image", "label"])
super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)

def get_indices(self) -> np.ndarray:
Expand Down
123 changes: 62 additions & 61 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.config import KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
Expand Down Expand Up @@ -80,6 +81,29 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]:
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


def _copy_compatible_dict(from_dict: Dict, to_dict: Dict):
if not isinstance(to_dict, dict):
raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.")
if not to_dict:
for key in from_dict:
datum = from_dict[key]
if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None:
continue
to_dict[key] = datum
else:
affine_key, shape_key = "affine", "spatial_shape"
if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]):
raise RuntimeError(
"affine matrix of all images should be the same for channel-wise concatenation. "
f"Got {from_dict[affine_key]} and {to_dict[affine_key]}."
)
if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]):
raise RuntimeError(
"spatial_shape of all images should be the same for channel-wise concatenation. "
f"Got {from_dict[shape_key]} and {to_dict[shape_key]}."
)


class ITKReader(ImageReader):
"""
Load medical images based on ITK library.
Expand Down Expand Up @@ -159,22 +183,15 @@ def get_data(self, img):

"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}

for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header["original_affine"] = self._get_affine(i)
header["affine"] = header["original_affine"].copy()
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(self._get_array_data(i))

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["affine"], compatible_meta["affine"]):
raise RuntimeError("affine matrix of all images should be same.")
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand All @@ -188,7 +205,7 @@ def _get_meta_dict(self, img) -> Dict:

"""
img_meta_dict = img.GetMetaDataDictionary()
meta_dict = dict()
meta_dict = {}
for key in img_meta_dict.GetKeys():
# ignore deprecated, legacy members that cause issues
if key.startswith("ITK_original_"):
Expand Down Expand Up @@ -220,7 +237,7 @@ def _get_affine(self, img) -> np.ndarray:
affine[(slice(-1), -1)] = origin
return affine

def _get_spatial_shape(self, img) -> Sequence:
def _get_spatial_shape(self, img) -> np.ndarray:
"""
Get the spatial shape of image data, it doesn't contain the channel dim.

Expand All @@ -230,7 +247,7 @@ def _get_spatial_shape(self, img) -> Sequence:
"""
shape = list(itk.size(img))
shape.reverse()
return shape
return np.asarray(shape)

def _get_array_data(self, img) -> np.ndarray:
"""
Expand All @@ -247,17 +264,15 @@ def _get_array_data(self, img) -> np.ndarray:
channels = img.GetNumberOfComponentsPerPixel()
if channels == 1:
return itk.array_view_from_image(img, keep_axes=False)
else:
# The memory layout of itk.Image has all pixel's channels adjacent
# in memory, i.e. R1G1B1R2G2B2R3G3B3. For PyTorch/MONAI, we need
# channels to be contiguous, i.e. R1R2R3G1G2G3B1B2B3.
arr = itk.array_view_from_image(img, keep_axes=False)
dest = list(range(img.ndim))
source = dest.copy()
end = source.pop()
source.insert(0, end)
arr_contiguous_channels = np.moveaxis(arr, source, dest)
return arr_contiguous_channels
# The memory layout of itk.Image has all pixel's channels adjacent
# in memory, i.e. R1G1B1R2G2B2R3G3B3. For PyTorch/MONAI, we need
# channels to be contiguous, i.e. R1R2R3G1G2G3B1B2B3.
arr = itk.array_view_from_image(img, keep_axes=False)
dest = list(range(img.ndim))
source = dest.copy()
end = source.pop()
source.insert(0, end)
return np.moveaxis(arr, source, dest)


class NibabelReader(ImageReader):
Expand All @@ -271,9 +286,10 @@ class NibabelReader(ImageReader):

"""

def __init__(self, as_closest_canonical: bool = False, **kwargs):
def __init__(self, as_closest_canonical: bool = False, dtype: Optional[np.dtype] = np.float32, **kwargs):
super().__init__()
self.as_closest_canonical = as_closest_canonical
self.dtype = dtype
self.kwargs = kwargs

def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
Expand Down Expand Up @@ -324,26 +340,19 @@ def get_data(self, img):

"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}

for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header["affine"] = self._get_affine(i)
header["original_affine"] = self._get_affine(i)
header["affine"] = header["original_affine"].copy()
header["as_closest_canonical"] = self.as_closest_canonical
if self.as_closest_canonical:
i = nib.as_closest_canonical(i)
header["affine"] = self._get_affine(i)
header["as_closest_canonical"] = self.as_closest_canonical
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(self._get_array_data(i))

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["affine"], compatible_meta["affine"]):
raise RuntimeError("affine matrix of all images should be same.")
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand All @@ -367,9 +376,9 @@ def _get_affine(self, img) -> np.ndarray:
img: a Nibabel image object loaded from a image file.

"""
return img.affine
return img.affine.copy()

def _get_spatial_shape(self, img) -> Sequence:
def _get_spatial_shape(self, img) -> np.ndarray:
"""
Get the spatial shape of image data, it doesn't contain the channel dim.

Expand All @@ -379,7 +388,7 @@ def _get_spatial_shape(self, img) -> Sequence:
"""
ndim = img.header["dim"][0]
spatial_rank = min(ndim, 3)
return list(img.header["dim"][1 : spatial_rank + 1])
return np.asarray(img.header["dim"][1 : spatial_rank + 1])

def _get_array_data(self, img) -> np.ndarray:
"""
Expand All @@ -389,7 +398,9 @@ def _get_array_data(self, img) -> np.ndarray:
img: a Nibabel image object loaded from a image file.

"""
return np.asarray(img.dataobj)
_array = np.array(img.get_fdata(dtype=self.dtype))
img.uncache()
return _array


class NumpyReader(ImageReader):
Expand Down Expand Up @@ -466,7 +477,7 @@ def get_data(self, img):

"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}
if isinstance(img, np.ndarray):
img = (img,)

Expand All @@ -475,12 +486,7 @@ def get_data(self, img):
if isinstance(i, np.ndarray):
header["spatial_shape"] = i.shape
img_array.append(i)

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand Down Expand Up @@ -551,18 +557,13 @@ def get_data(self, img):

"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}

for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(np.asarray(i))

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand All @@ -574,17 +575,17 @@ def _get_meta_dict(self, img) -> Dict:
img: a PIL Image object loaded from a image file.

"""
meta = dict()
meta["format"] = img.format
meta["mode"] = img.mode
meta["width"] = img.width
meta["height"] = img.height
return meta
return {
"format": img.format,
"mode": img.mode,
"width": img.width,
"height": img.height,
}

def _get_spatial_shape(self, img) -> Sequence:
def _get_spatial_shape(self, img) -> np.ndarray:
"""
Get the spatial shape of image data, it doesn't contain the channel dim.
Args:
img: a PIL Image object loaded from a image file.
"""
return [img.width, img.height]
return np.asarray((img.width, img.height))
1 change: 0 additions & 1 deletion monai/data/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,4 @@ def __iter__(self):
for data in self.source:
if self.transform is not None:
data = apply_transform(self.transform, data)

yield data
Loading