diff --git a/monai/apps/datasets.py b/monai/apps/datasets.py index 54829cb946..bb10eb6b11 100644 --- a/monai/apps/datasets.py +++ b/monai/apps/datasets.py @@ -22,6 +22,7 @@ import numpy as np from monai.apps.tcia import ( + DCM_FILENAME_REGEX, download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, @@ -442,6 +443,10 @@ class TciaDataset(Randomizable, CacheDataset): specific_tags: tags that will be loaded for "SEG" series. This argument will be used in `monai.data.PydicomReader`. Default is [(0x0008, 0x1115), (0x0008,0x1140), (0x3006, 0x0010), (0x0020,0x000D), (0x0010,0x0010), (0x0010,0x0020), (0x0020,0x0011), (0x0020,0x0012)]. + fname_regex: a regular expression to match the file names when the input is a folder. + If provided, only the matched files will be included. For example, to include the file name + "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. + Default to `"^(?!.*LICENSE).*"`, ignoring any file name containing `"LICENSE"`. val_frac: percentage of validation fraction in the whole dataset, default is 0.2. seed: random seed to randomly shuffle the datalist before splitting into training and validation, default is 0. note to set same seed for `training` and `validation` sections. @@ -509,6 +514,7 @@ def __init__( (0x0020, 0x0011), # Series Number (0x0020, 0x0012), # Acquisition Number ), + fname_regex: str = DCM_FILENAME_REGEX, seed: int = 0, val_frac: float = 0.2, cache_num: int = sys.maxsize, @@ -548,12 +554,13 @@ def __init__( if not os.path.exists(download_dir): raise RuntimeError(f"Cannot find dataset directory: {download_dir}.") + self.fname_regex = fname_regex self.indices: np.ndarray = np.array([]) self.datalist = self._generate_data_list(download_dir) if transform == (): - transform = LoadImaged(reader="PydicomReader", keys=["image"]) + transform = LoadImaged(keys=["image"], reader="PydicomReader", fname_regex=self.fname_regex) CacheDataset.__init__( self, data=self.datalist, diff --git a/monai/apps/tcia/__init__.py b/monai/apps/tcia/__init__.py index af3d44fd14..e33d4abfbc 100644 --- a/monai/apps/tcia/__init__.py +++ b/monai/apps/tcia/__init__.py @@ -12,4 +12,11 @@ from __future__ import annotations from .label_desc import TCIA_LABEL_DICT -from .utils import download_tcia_series_instance, get_tcia_metadata, get_tcia_ref_uid, match_tcia_ref_uid_in_study +from .utils import ( + BASE_URL, + DCM_FILENAME_REGEX, + download_tcia_series_instance, + get_tcia_metadata, + get_tcia_ref_uid, + match_tcia_ref_uid_in_study, +) diff --git a/monai/apps/tcia/utils.py b/monai/apps/tcia/utils.py index 9c120f0072..5524b488e9 100644 --- a/monai/apps/tcia/utils.py +++ b/monai/apps/tcia/utils.py @@ -21,10 +21,18 @@ requests_get, has_requests = optional_import("requests", name="get") pd, has_pandas = optional_import("pandas") -__all__ = ["get_tcia_metadata", "download_tcia_series_instance", "get_tcia_ref_uid", "match_tcia_ref_uid_in_study"] - +DCM_FILENAME_REGEX = r"^(?!.*LICENSE).*" # excluding the file with "LICENSE" in its name BASE_URL = "https://services.cancerimagingarchive.net/nbia-api/services/v1/" +__all__ = [ + "get_tcia_metadata", + "download_tcia_series_instance", + "get_tcia_ref_uid", + "match_tcia_ref_uid_in_study", + "DCM_FILENAME_REGEX", + "BASE_URL", +] + def get_tcia_metadata(query: str, attribute: str | None = None) -> list: """ diff --git a/monai/data/image_reader.py b/monai/data/image_reader.py index 4c7f2c8c3b..fe199d9570 100644 --- a/monai/data/image_reader.py +++ b/monai/data/image_reader.py @@ -13,6 +13,7 @@ import glob import os +import re import warnings from abc import ABC, abstractmethod from collections.abc import Callable, Iterable, Iterator, Sequence @@ -403,8 +404,12 @@ class PydicomReader(ImageReader): label_dict: label of the dicom data. If provided, it will be used when loading segmentation data. Keys of the dict are the classes, and values are the corresponding class number. For example: for TCIA collection "C4KC-KiTS", it can be: {"Kidney": 0, "Renal Tumor": 1}. + fname_regex: a regular expression to match the file names when the input is a folder. + If provided, only the matched files will be included. For example, to include the file name + "image_0001.dcm", the regular expression could be `".*image_(\\d+).dcm"`. Default to `""`. + Set it to `None` to use `pydicom.misc.is_dicom` to match valid files. kwargs: additional args for `pydicom.dcmread` API. more details about available args: - https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html#pydicom.filereader.dcmread + https://pydicom.github.io/pydicom/stable/reference/generated/pydicom.filereader.dcmread.html If the `get_data` function will be called (for example, when using this reader with `monai.transforms.LoadImage`), please ensure that the argument `stop_before_pixels` is `True`, and `specific_tags` covers all necessary tags, such as `PixelSpacing`, @@ -418,6 +423,7 @@ def __init__( swap_ij: bool = True, prune_metadata: bool = True, label_dict: dict | None = None, + fname_regex: str = "", **kwargs, ): super().__init__() @@ -427,6 +433,7 @@ def __init__( self.swap_ij = swap_ij self.prune_metadata = prune_metadata self.label_dict = label_dict + self.fname_regex = fname_regex def verify_suffix(self, filename: Sequence[PathLike] | PathLike) -> bool: """ @@ -467,9 +474,16 @@ def read(self, data: Sequence[PathLike] | PathLike, **kwargs): name = f"{name}" if Path(name).is_dir(): # read DICOM series - series_slcs = glob.glob(os.path.join(name, "*")) - series_slcs = [slc for slc in series_slcs if "LICENSE" not in slc] - slices = [pydicom.dcmread(fp=slc, **kwargs_) for slc in series_slcs] + if self.fname_regex is not None: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if re.match(self.fname_regex, slc)] + else: + series_slcs = [slc for slc in glob.glob(os.path.join(name, "*")) if pydicom.misc.is_dicom(slc)] + slices = [] + for slc in series_slcs: + try: + slices.append(pydicom.dcmread(fp=slc, **kwargs_)) + except pydicom.errors.InvalidDicomError as e: + warnings.warn(f"Failed to read {slc} with exception: \n{e}.", stacklevel=2) img_.append(slices if len(slices) > 1 else slices[0]) if len(slices) > 1: self.has_series = True diff --git a/tests/test_load_image.py b/tests/test_load_image.py index 6f29e7ac50..b6a10bceb4 100644 --- a/tests/test_load_image.py +++ b/tests/test_load_image.py @@ -226,6 +226,11 @@ def test_itk_dicom_series_reader(self, input_param, filenames, expected_shape, e ) self.assertTupleEqual(result.shape, expected_np_shape) + def test_no_files(self): + with self.assertRaisesRegex(RuntimeError, "list index out of range"): # fname_regex excludes everything + LoadImage(image_only=True, reader="PydicomReader", fname_regex=r"^(?!.*).*")("tests/testing_data/CT_DICOM") + LoadImage(image_only=True, reader="PydicomReader", fname_regex=None)("tests/testing_data/CT_DICOM") + def test_itk_dicom_series_reader_single(self): result = LoadImage(image_only=True, reader="ITKReader")(self.data_dir) self.assertEqual(result.meta["filename_or_obj"], f"{Path(self.data_dir)}") diff --git a/tests/test_tciadataset.py b/tests/test_tciadataset.py index 7a14262587..2a3928f9aa 100644 --- a/tests/test_tciadataset.py +++ b/tests/test_tciadataset.py @@ -16,7 +16,7 @@ import unittest from monai.apps import TciaDataset -from monai.apps.tcia import TCIA_LABEL_DICT +from monai.apps.tcia import DCM_FILENAME_REGEX, TCIA_LABEL_DICT from monai.data import MetaTensor from monai.transforms import Compose, EnsureChannelFirstd, LoadImaged, ScaleIntensityd from tests.utils import skip_if_downloading_fails, skip_if_quick @@ -32,7 +32,12 @@ def test_values(self): transform = Compose( [ - LoadImaged(keys=["image", "seg"], reader="PydicomReader", label_dict=TCIA_LABEL_DICT[collection]), + LoadImaged( + keys=["image", "seg"], + reader="PydicomReader", + fname_regex=DCM_FILENAME_REGEX, + label_dict=TCIA_LABEL_DICT[collection], + ), EnsureChannelFirstd(keys="image", channel_dim="no_channel"), ScaleIntensityd(keys="image"), ] @@ -82,10 +87,24 @@ def _test_dataset(dataset): self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24)) self.assertEqual(len(data), int(download_len * val_frac)) data = TciaDataset( - root_dir=testing_dir, collection=collection, section="validation", download=False, val_frac=val_frac + root_dir=testing_dir, + collection=collection, + section="validation", + download=False, + fname_regex=DCM_FILENAME_REGEX, + val_frac=val_frac, ) self.assertTupleEqual(data[0]["image"].shape, (256, 256, 24)) self.assertEqual(len(data), download_len) + with self.assertWarns(UserWarning): + data = TciaDataset( + root_dir=testing_dir, + collection=collection, + section="validation", + fname_regex=".*", # all files including 'LICENSE' is not a valid input + download=False, + val_frac=val_frac, + )[0] shutil.rmtree(os.path.join(testing_dir, collection)) try: