Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

909 develop ImageDataset to replace NiftiDataset #1461

Merged
merged 7 commits into from
Jan 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions docs/source/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ Generic Interfaces
:members:
:special-members: __getitem__

`ImageDataset`
~~~~~~~~~~~~~~
.. autoclass:: ImageDataset
:members:
:special-members: __getitem__

Patch-based dataset
-------------------
Expand Down Expand Up @@ -104,11 +109,6 @@ PILReader
Nifti format handling
---------------------

Reading
~~~~~~~
.. autoclass:: monai.data.NiftiDataset
:members:

Writing Nifti
~~~~~~~~~~~~~
.. autoclass:: monai.data.NiftiSaver
Expand Down
2 changes: 1 addition & 1 deletion monai/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@
)
from .decathlon_datalist import load_decathlon_datalist, load_decathlon_properties
from .grid_dataset import GridPatchDataset, PatchDataset
from .image_dataset import ImageDataset
from .image_reader import ImageReader, ITKReader, NibabelReader, NumpyReader, PILReader
from .iterable_dataset import IterableDataset
from .nifti_reader import NiftiDataset
from .nifti_saver import NiftiSaver
from .nifti_writer import write_nifti
from .png_saver import PNGSaver
Expand Down
51 changes: 28 additions & 23 deletions monai/data/nifti_reader.py → monai/data/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,31 +9,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Callable, Optional, Sequence
from typing import Any, Callable, Optional, Sequence, Union

import numpy as np
from torch.utils.data import Dataset

from monai.data.image_reader import ImageReader
from monai.transforms import LoadImage, Randomizable, apply_transform
from monai.utils import MAX_SEED, get_seed


class NiftiDataset(Dataset, Randomizable):
class ImageDataset(Dataset, Randomizable):
"""
Loads image/segmentation pairs of Nifti files from the given filename lists. Transformations can be specified
Loads image/segmentation pairs of files from the given filename lists. Transformations can be specified
for the image and segmentation arrays separately.
The difference between this dataset and `ArrayDataset` is that this dataset can apply a transform chain to images
and segs and return both the images and metadata, and there is no need to specify a transform to load images from files.

"""

def __init__(
self,
image_files: Sequence[str],
seg_files: Optional[Sequence[str]] = None,
labels: Optional[Sequence[float]] = None,
as_closest_canonical: bool = False,
transform: Optional[Callable] = None,
seg_transform: Optional[Callable] = None,
image_only: bool = True,
dtype: Optional[np.dtype] = np.float32,
reader: Optional[Union[ImageReader, str]] = None,
*args,
**kwargs,
) -> None:
"""
Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
Expand All @@ -43,14 +49,18 @@ def __init__(
image_files: list of image filenames
seg_files: if in segmentation task, list of segmentation filenames
labels: if in classification task, list of classification labels
as_closest_canonical: if True, load the image as closest to canonical orientation
transform: transform to apply to image arrays
seg_transform: transform to apply to segmentation arrays
image_only: if True return only the image volume, otherwise return image volume and header dict
image_only: if True return only the image volume, otherwise, return image volume and the metadata
dtype: if not None convert the loaded image to this data type
reader: a reader used to load the image file and metadata; if None, the default readers will be used.
If a reader name string is provided, a reader object will be constructed with the `*args` and `**kwargs`
parameters; supported reader names: "NibabelReader", "PILReader", "ITKReader", "NumpyReader"
args: additional parameters for reader if providing a reader name
kwargs: additional parameters for reader if providing a reader name

Raises:
ValueError: When ``seg_files`` length differs from ``image_files``.
ValueError: When ``seg_files`` length differs from ``image_files``

"""

Expand All @@ -63,13 +73,11 @@ def __init__(
self.image_files = image_files
self.seg_files = seg_files
self.labels = labels
self.as_closest_canonical = as_closest_canonical
self.transform = transform
self.seg_transform = seg_transform
self.image_only = image_only
self.dtype = dtype
self.loader = LoadImage(reader, image_only, dtype, *args, **kwargs)
self.set_random_state(seed=get_seed())

self._seed = 0 # transform synchronization seed

def __len__(self) -> int:
Expand All @@ -81,21 +89,18 @@ def randomize(self, data: Optional[Any] = None) -> None:
def __getitem__(self, index: int):
self.randomize()
meta_data = None
img_loader = LoadImage(
reader="NibabelReader",
image_only=self.image_only,
dtype=self.dtype,
as_closest_canonical=self.as_closest_canonical,
)
if self.image_only:
img = img_loader(self.image_files[index])
else:
img, meta_data = img_loader(self.image_files[index])
seg = None
if self.seg_files is not None:
seg_loader = LoadImage(image_only=True)
seg = seg_loader(self.seg_files[index])
label = None

if self.image_only:
img = self.loader(self.image_files[index])
if self.seg_files is not None:
seg = self.loader(self.seg_files[index])
else:
img, meta_data = self.loader(self.image_files[index])
if self.seg_files is not None:
seg, _ = self.loader(self.seg_files[index])

if self.labels is not None:
label = self.labels[index]

Expand Down
2 changes: 1 addition & 1 deletion tests/min_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def run_testsuit():
"test_load_imaged",
"test_load_spacing_orientation",
"test_mednistdataset",
"test_nifti_dataset",
"test_image_dataset",
"test_nifti_header_revise",
"test_nifti_rw",
"test_nifti_saver",
Expand Down
24 changes: 12 additions & 12 deletions tests/test_nifti_dataset.py → tests/test_image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import nibabel as nib
import numpy as np

from monai.data import NiftiDataset
from monai.data import ImageDataset
from monai.transforms import Randomizable

FILENAMES = ["test1.nii.gz", "test2.nii", "test3.nii.gz"]
Expand All @@ -35,7 +35,7 @@ def __call__(self, data):
return data + self._a


class TestNiftiDataset(unittest.TestCase):
class TestImageDataset(unittest.TestCase):
def test_dataset(self):
with tempfile.TemporaryDirectory() as tempdir:
full_names, ref_data = [], []
Expand All @@ -47,46 +47,46 @@ def test_dataset(self):
nib.save(nib.Nifti1Image(test_image, np.eye(4)), save_path)

# default loading no meta
dataset = NiftiDataset(full_names)
dataset = ImageDataset(full_names)
for d, ref in zip(dataset, ref_data):
np.testing.assert_allclose(d, ref, atol=1e-3)

# loading no meta, int
dataset = NiftiDataset(full_names, dtype=np.float16)
dataset = ImageDataset(full_names, dtype=np.float16)
for d, _ in zip(dataset, ref_data):
self.assertEqual(d.dtype, np.float16)

# loading with meta, no transform
dataset = NiftiDataset(full_names, image_only=False)
dataset = ImageDataset(full_names, image_only=False)
for d_tuple, ref in zip(dataset, ref_data):
d, meta = d_tuple
np.testing.assert_allclose(d, ref, atol=1e-3)
np.testing.assert_allclose(meta["original_affine"], np.eye(4))

# loading image/label, no meta
dataset = NiftiDataset(full_names, seg_files=full_names, image_only=True)
dataset = ImageDataset(full_names, seg_files=full_names, image_only=True)
for d_tuple, ref in zip(dataset, ref_data):
img, seg = d_tuple
np.testing.assert_allclose(img, ref, atol=1e-3)
np.testing.assert_allclose(seg, ref, atol=1e-3)

# loading image/label, no meta
dataset = NiftiDataset(full_names, transform=lambda x: x + 1, image_only=True)
dataset = ImageDataset(full_names, transform=lambda x: x + 1, image_only=True)
for d, ref in zip(dataset, ref_data):
np.testing.assert_allclose(d, ref + 1, atol=1e-3)

# set seg transform, but no seg_files
with self.assertRaises(RuntimeError):
dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True)
dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True)
_ = dataset[0]

# set seg transform, but no seg_files
with self.assertRaises(RuntimeError):
dataset = NiftiDataset(full_names, seg_transform=lambda x: x + 1, image_only=True)
dataset = ImageDataset(full_names, seg_transform=lambda x: x + 1, image_only=True)
_ = dataset[0]

# loading image/label, with meta
dataset = NiftiDataset(
dataset = ImageDataset(
full_names,
transform=lambda x: x + 1,
seg_files=full_names,
Expand All @@ -100,7 +100,7 @@ def test_dataset(self):
np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)

# loading image/label, with meta
dataset = NiftiDataset(
dataset = ImageDataset(
full_names, transform=lambda x: x + 1, seg_files=full_names, labels=[1, 2, 3], image_only=False
)
for idx, (d_tuple, ref) in enumerate(zip(dataset, ref_data)):
Expand All @@ -111,7 +111,7 @@ def test_dataset(self):
np.testing.assert_allclose(meta["original_affine"], np.eye(4), atol=1e-3)

# loading image/label, with sync. transform
dataset = NiftiDataset(
dataset = ImageDataset(
full_names, transform=RandTest(), seg_files=full_names, seg_transform=RandTest(), image_only=False
)
for d_tuple, ref in zip(dataset, ref_data):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_integration_sliding_window.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from ignite.engine import Engine
from torch.utils.data import DataLoader

from monai.data import NiftiDataset, create_test_image_3d
from monai.data import ImageDataset, create_test_image_3d
from monai.handlers import SegmentationSaver
from monai.inferers import sliding_window_inference
from monai.networks import eval_mode, predict_segmentation
Expand All @@ -30,7 +30,7 @@


def run_test(batch_size, img_name, seg_name, output_dir, device="cuda:0"):
ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
ds = ImageDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())

net = UNet(
Expand Down