Skip to content

Commit

Permalink
909 Update LoadImage to use Nibabel as default (#1307)
Browse files Browse the repository at this point in the history
* [DLMED] update to use Nibabel as default reader

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] fix quick test issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8 issue

Signed-off-by: Nic Ma <[email protected]>

* [MONAI] python code formatting

Signed-off-by: monai-bot <[email protected]>

* [DLMED] update according to comments

Signed-off-by: Nic Ma <[email protected]>

* [DLMED] fix flake8

Signed-off-by: Nic Ma <[email protected]>

* incorporates the new api in tests

Signed-off-by: Wenqi Li <[email protected]>

* fixes compatible meta dict

Signed-off-by: Wenqi Li <[email protected]>

Co-authored-by: monai-bot <[email protected]>
Co-authored-by: Wenqi Li <[email protected]>
  • Loading branch information
3 people authored Dec 2, 2020
1 parent 8d61ae2 commit 2c903cd
Show file tree
Hide file tree
Showing 27 changed files with 215 additions and 182 deletions.
2 changes: 1 addition & 1 deletion docs/source/highlights.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ For example:
```py
# define a transform chain for pre-processing
train_transforms = monai.transforms.Compose([
LoadNiftid(keys=['image', 'label']),
LoadImaged(keys=['image', 'label']),
RandRotate90d(keys=['image', 'label'], prob=0.2, spatial_axes=[0, 2]),
... ...
])
Expand Down
6 changes: 3 additions & 3 deletions monai/apps/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
partition_dataset,
select_cross_validation_folds,
)
from monai.transforms import LoadNiftid, LoadPNGd, Randomizable
from monai.transforms import LoadImaged, Randomizable
from monai.utils import ensure_tuple


Expand Down Expand Up @@ -92,7 +92,7 @@ def __init__(
)
data = self._generate_data_list(dataset_dir)
if transform == ():
transform = LoadPNGd("image")
transform = LoadImaged("image")
super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)

def randomize(self, data: Optional[Any] = None) -> None:
Expand Down Expand Up @@ -268,7 +268,7 @@ def __init__(
]
self._properties = load_decathlon_properties(os.path.join(dataset_dir, "dataset.json"), property_keys)
if transform == ():
transform = LoadNiftid(["image", "label"])
transform = LoadImaged(["image", "label"])
super().__init__(data, transform, cache_num=cache_num, cache_rate=cache_rate, num_workers=num_workers)

def get_indices(self) -> np.ndarray:
Expand Down
123 changes: 62 additions & 61 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.config import KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
Expand Down Expand Up @@ -80,6 +81,29 @@ def get_data(self, img) -> Tuple[np.ndarray, Dict]:
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


def _copy_compatible_dict(from_dict: Dict, to_dict: Dict):
if not isinstance(to_dict, dict):
raise ValueError(f"to_dict must be a Dict, got {type(to_dict)}.")
if not to_dict:
for key in from_dict:
datum = from_dict[key]
if isinstance(datum, np.ndarray) and np_str_obj_array_pattern.search(datum.dtype.str) is not None:
continue
to_dict[key] = datum
else:
affine_key, shape_key = "affine", "spatial_shape"
if affine_key in from_dict and not np.allclose(from_dict[affine_key], to_dict[affine_key]):
raise RuntimeError(
"affine matrix of all images should be the same for channel-wise concatenation. "
f"Got {from_dict[affine_key]} and {to_dict[affine_key]}."
)
if shape_key in from_dict and not np.allclose(from_dict[shape_key], to_dict[shape_key]):
raise RuntimeError(
"spatial_shape of all images should be the same for channel-wise concatenation. "
f"Got {from_dict[shape_key]} and {to_dict[shape_key]}."
)


class ITKReader(ImageReader):
"""
Load medical images based on ITK library.
Expand Down Expand Up @@ -159,22 +183,15 @@ def get_data(self, img):
"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}

for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header["original_affine"] = self._get_affine(i)
header["affine"] = header["original_affine"].copy()
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(self._get_array_data(i))

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["affine"], compatible_meta["affine"]):
raise RuntimeError("affine matrix of all images should be same.")
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand All @@ -188,7 +205,7 @@ def _get_meta_dict(self, img) -> Dict:
"""
img_meta_dict = img.GetMetaDataDictionary()
meta_dict = dict()
meta_dict = {}
for key in img_meta_dict.GetKeys():
# ignore deprecated, legacy members that cause issues
if key.startswith("ITK_original_"):
Expand Down Expand Up @@ -220,7 +237,7 @@ def _get_affine(self, img) -> np.ndarray:
affine[(slice(-1), -1)] = origin
return affine

def _get_spatial_shape(self, img) -> Sequence:
def _get_spatial_shape(self, img) -> np.ndarray:
"""
Get the spatial shape of image data, it doesn't contain the channel dim.
Expand All @@ -230,7 +247,7 @@ def _get_spatial_shape(self, img) -> Sequence:
"""
shape = list(itk.size(img))
shape.reverse()
return shape
return np.asarray(shape)

def _get_array_data(self, img) -> np.ndarray:
"""
Expand All @@ -247,17 +264,15 @@ def _get_array_data(self, img) -> np.ndarray:
channels = img.GetNumberOfComponentsPerPixel()
if channels == 1:
return itk.array_view_from_image(img, keep_axes=False)
else:
# The memory layout of itk.Image has all pixel's channels adjacent
# in memory, i.e. R1G1B1R2G2B2R3G3B3. For PyTorch/MONAI, we need
# channels to be contiguous, i.e. R1R2R3G1G2G3B1B2B3.
arr = itk.array_view_from_image(img, keep_axes=False)
dest = list(range(img.ndim))
source = dest.copy()
end = source.pop()
source.insert(0, end)
arr_contiguous_channels = np.moveaxis(arr, source, dest)
return arr_contiguous_channels
# The memory layout of itk.Image has all pixel's channels adjacent
# in memory, i.e. R1G1B1R2G2B2R3G3B3. For PyTorch/MONAI, we need
# channels to be contiguous, i.e. R1R2R3G1G2G3B1B2B3.
arr = itk.array_view_from_image(img, keep_axes=False)
dest = list(range(img.ndim))
source = dest.copy()
end = source.pop()
source.insert(0, end)
return np.moveaxis(arr, source, dest)


class NibabelReader(ImageReader):
Expand All @@ -271,9 +286,10 @@ class NibabelReader(ImageReader):
"""

def __init__(self, as_closest_canonical: bool = False, **kwargs):
def __init__(self, as_closest_canonical: bool = False, dtype: Optional[np.dtype] = np.float32, **kwargs):
super().__init__()
self.as_closest_canonical = as_closest_canonical
self.dtype = dtype
self.kwargs = kwargs

def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
Expand Down Expand Up @@ -324,26 +340,19 @@ def get_data(self, img):
"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}

for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header["affine"] = self._get_affine(i)
header["original_affine"] = self._get_affine(i)
header["affine"] = header["original_affine"].copy()
header["as_closest_canonical"] = self.as_closest_canonical
if self.as_closest_canonical:
i = nib.as_closest_canonical(i)
header["affine"] = self._get_affine(i)
header["as_closest_canonical"] = self.as_closest_canonical
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(self._get_array_data(i))

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["affine"], compatible_meta["affine"]):
raise RuntimeError("affine matrix of all images should be same.")
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand All @@ -367,9 +376,9 @@ def _get_affine(self, img) -> np.ndarray:
img: a Nibabel image object loaded from a image file.
"""
return img.affine
return img.affine.copy()

def _get_spatial_shape(self, img) -> Sequence:
def _get_spatial_shape(self, img) -> np.ndarray:
"""
Get the spatial shape of image data, it doesn't contain the channel dim.
Expand All @@ -379,7 +388,7 @@ def _get_spatial_shape(self, img) -> Sequence:
"""
ndim = img.header["dim"][0]
spatial_rank = min(ndim, 3)
return list(img.header["dim"][1 : spatial_rank + 1])
return np.asarray(img.header["dim"][1 : spatial_rank + 1])

def _get_array_data(self, img) -> np.ndarray:
"""
Expand All @@ -389,7 +398,9 @@ def _get_array_data(self, img) -> np.ndarray:
img: a Nibabel image object loaded from a image file.
"""
return np.asarray(img.dataobj)
_array = np.array(img.get_fdata(dtype=self.dtype))
img.uncache()
return _array


class NumpyReader(ImageReader):
Expand Down Expand Up @@ -466,7 +477,7 @@ def get_data(self, img):
"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}
if isinstance(img, np.ndarray):
img = (img,)

Expand All @@ -475,12 +486,7 @@ def get_data(self, img):
if isinstance(i, np.ndarray):
header["spatial_shape"] = i.shape
img_array.append(i)

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand Down Expand Up @@ -551,18 +557,13 @@ def get_data(self, img):
"""
img_array: List[np.ndarray] = list()
compatible_meta: Dict = None
compatible_meta: Dict = {}

for i in ensure_tuple(img):
header = self._get_meta_dict(i)
header["spatial_shape"] = self._get_spatial_shape(i)
img_array.append(np.asarray(i))

if compatible_meta is None:
compatible_meta = header
else:
if not np.allclose(header["spatial_shape"], compatible_meta["spatial_shape"]):
raise RuntimeError("spatial_shape of all images should be same.")
_copy_compatible_dict(header, compatible_meta)

img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
return img_array_, compatible_meta
Expand All @@ -574,17 +575,17 @@ def _get_meta_dict(self, img) -> Dict:
img: a PIL Image object loaded from a image file.
"""
meta = dict()
meta["format"] = img.format
meta["mode"] = img.mode
meta["width"] = img.width
meta["height"] = img.height
return meta
return {
"format": img.format,
"mode": img.mode,
"width": img.width,
"height": img.height,
}

def _get_spatial_shape(self, img) -> Sequence:
def _get_spatial_shape(self, img) -> np.ndarray:
"""
Get the spatial shape of image data, it doesn't contain the channel dim.
Args:
img: a PIL Image object loaded from a image file.
"""
return [img.width, img.height]
return np.asarray((img.width, img.height))
1 change: 0 additions & 1 deletion monai/data/iterable_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,4 @@ def __iter__(self):
for data in self.source:
if self.transform is not None:
data = apply_transform(self.transform, data)

yield data
Loading

0 comments on commit 2c903cd

Please sign in to comment.