Skip to content

Commit

Permalink
909 Add NumpyReader for IO factory (#964)
Browse files Browse the repository at this point in the history
* [DLMED] add NumpyReader

* [DLMED] add NumpyReader

* [MONAI] python code formatting

Co-authored-by: monai-bot <[email protected]>
  • Loading branch information
Nic-Ma and monai-bot authored Aug 28, 2020
1 parent 5fff91e commit 35ef648
Show file tree
Hide file tree
Showing 3 changed files with 187 additions and 0 deletions.
5 changes: 5 additions & 0 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ NibabelReader
.. autoclass:: NibabelReader
:members:

NumpyReader
~~~~~~~~~~~
.. autoclass:: NumpyReader
:members:


Nifti format handling
---------------------
Expand Down
92 changes: 92 additions & 0 deletions monai/data/image_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

import numpy as np

from monai.config import KeysCollection
from monai.data.utils import correct_nifti_header_if_necessary
from monai.utils import ensure_tuple, optional_import

Expand Down Expand Up @@ -348,3 +349,94 @@ def _get_array_data(self, img: Nifti1Image) -> np.ndarray:
"""
return np.asarray(img.dataobj)


class NumpyReader(ImageReader):
    """
    Load NPY or NPZ format data based on Numpy library, they can be arrays or pickled objects.
    A typical usage is to load the `mask` data for classification task.
    It can load part of the npz file with specified `npz_keys`.

    Args:
        npz_keys: if loading npz file, only load the specified keys, if None, load all the items.
            stack the loaded items together to construct a new first dimension.

    """

    def __init__(self, npz_keys: Optional[KeysCollection] = None):
        super().__init__()
        # Items loaded by the latest `read()` call; None until `read()` has been called.
        self._img: Optional[List[np.ndarray]] = None
        if npz_keys is not None:
            npz_keys = ensure_tuple(npz_keys)
        self.npz_keys = npz_keys

    def verify_suffix(self, filename: Union[Sequence[str], str]) -> bool:
        """
        Verify whether the specified file or files format is supported by Numpy reader.

        Args:
            filename: file name or a list of file names to read.
                if a list of files, verify all the suffixes.

        """
        suffixes: Sequence[str] = ["npz", "npy"]
        return is_supported_format(filename, suffixes)

    def read(self, data: Union[Sequence[str], str, np.ndarray], **kwargs):
        """
        Read image data from specified file or files, or set a Numpy array.
        Note that the returned object is Numpy array or list of Numpy arrays.
        `self._img` is always a list, even only has 1 image.

        Args:
            data: file name or a list of file names to read, or an in-memory Numpy array.
            kwargs: additional args for `numpy.load` API except `allow_pickle`. more details about available args:
                https://numpy.org/doc/stable/reference/generated/numpy.load.html

        Returns:
            `data` itself if it is a Numpy array; the list of all loaded items if more than
            one file name was given; otherwise the first loaded item.

        Raises:
            KeyError: when one of the requested `npz_keys` does not exist in an NPZ archive.

        """
        self._img = []
        if isinstance(data, np.ndarray):
            self._img.append(data)
            return data

        filenames: Sequence[str] = ensure_tuple(data)
        for name in filenames:
            # SECURITY: `allow_pickle=True` executes arbitrary pickled code while loading,
            # only use this reader on files coming from a trusted source.
            loaded = np.load(name, allow_pickle=True, **kwargs)
            if isinstance(loaded, np.lib.npyio.NpzFile):
                # an NPZ archive keeps its file handle open until closed explicitly,
                # so extract the requested items inside a context manager.
                with loaded:
                    # `files` lists the real item names, so archives saved with
                    # keyword names work as well as the default `arr_0`, `arr_1`, ...
                    keys = loaded.files if self.npz_keys is None else self.npz_keys
                    self._img.extend(loaded[k] for k in keys)
            else:
                self._img.append(loaded)

        # NOTE(review): for a single NPZ file holding several items only the first one is
        # returned here; all of them are still available through `get_data()`.
        return self._img if len(filenames) > 1 else self._img[0]

    def get_data(self):
        """
        Extract data array and meta data from loaded data and return them.
        This function returns 2 objects, first is numpy array of image data, second is dict of meta data.
        It constructs `spatial_shape=data.shape` and stores in meta dict if the data is numpy array.
        If loading a list of files, stack them together and add a new dimension as first dimension,
        and use the meta data of the first image to represent the stacked result.

        Raises:
            RuntimeError: when `read()` has not been called yet, or when the loaded
                items do not all have the same `spatial_shape`.

        """
        if self._img is None:
            raise RuntimeError("please call read() first then use get_data().")

        img_array: List[np.ndarray] = []
        compatible_meta: Optional[Dict] = None
        for img in self._img:
            header = {}
            if isinstance(img, np.ndarray):
                header["spatial_shape"] = img.shape
            img_array.append(img)

            if compatible_meta is None:
                compatible_meta = header
            # compare the shape tuples directly: `np.allclose` would broadcast, so it accepts
            # e.g. (4, 4) vs (4,) and raises ValueError instead of RuntimeError on rank mismatch.
            elif header.get("spatial_shape") != compatible_meta.get("spatial_shape"):
                raise RuntimeError("spatial_shape of all images should be same.")

        img_array_ = np.stack(img_array, axis=0) if len(img_array) > 1 else img_array[0]
        return img_array_, compatible_meta
90 changes: 90 additions & 0 deletions tests/test_numpy_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest

import numpy as np

from monai.data import NumpyReader


class TestNumpyReader(unittest.TestCase):
    """Round-trip tests for `NumpyReader` with NPY, NPZ and pickled NPY files."""

    def test_npy(self):
        """A single array saved as NPY is returned unchanged with its shape as meta data."""
        test_data = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npy")
            np.save(filepath, test_data)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data.shape)
        self.assertTupleEqual(result[0].shape, test_data.shape)
        np.testing.assert_allclose(result[0], test_data)

    def test_npz1(self):
        """An NPZ archive holding a single array is returned without an extra stacked dimension."""
        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            # save a real NPZ archive so this case exercises the NPZ code path
            filepath = os.path.join(tempdir, "test_data.npz")
            np.savez(filepath, test_data1)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
        self.assertTupleEqual(result[0].shape, test_data1.shape)
        np.testing.assert_allclose(result[0], test_data1)

    def test_npz2(self):
        """Positional arrays of an NPZ archive are all loaded and stacked on a new first dimension."""
        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])
        test_data2 = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npz")
            np.savez(filepath, test_data1, test_data2)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
        self.assertTupleEqual(result[0].shape, (2, 3, 4, 4))
        np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))

    def test_npz3(self):
        """Named arrays of an NPZ archive are selected with `npz_keys` and stacked in key order."""
        test_data1 = np.random.randint(0, 256, size=[3, 4, 4])
        test_data2 = np.random.randint(0, 256, size=[3, 4, 4])
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npz")
            np.savez(filepath, test1=test_data1, test2=test_data2)

            reader = NumpyReader(npz_keys=["test1", "test2"])
            reader.read(filepath)
            result = reader.get_data()
        self.assertTupleEqual(result[1]["spatial_shape"], test_data1.shape)
        self.assertTupleEqual(result[0].shape, (2, 3, 4, 4))
        np.testing.assert_allclose(result[0], np.stack([test_data1, test_data2]))

    def test_npy_pickle(self):
        """A pickled dict saved as NPY is recovered through the object array's `item()`."""
        test_data = {"test": np.random.randint(0, 256, size=[3, 4, 4])}
        with tempfile.TemporaryDirectory() as tempdir:
            filepath = os.path.join(tempdir, "test_data.npy")
            np.save(filepath, test_data, allow_pickle=True)

            reader = NumpyReader()
            reader.read(filepath)
            result = reader.get_data()[0].item()
        self.assertTupleEqual(result["test"].shape, test_data["test"].shape)
        np.testing.assert_allclose(result["test"], test_data["test"])


if __name__ == "__main__":
    unittest.main()

0 comments on commit 35ef648

Please sign in to comment.