diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 3488029a7a..2ee8c9d363 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -23,7 +23,7 @@ from monai.data.utils import iter_patch_position from monai.data.wsi_reader import BaseWSIReader, WSIReader from monai.transforms import ForegroundMask, Randomizable, apply_transform -from monai.utils import convert_to_dst_type, ensure_tuple_rep +from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys __all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"] @@ -123,9 +123,9 @@ def _get_label(self, sample: dict): def _get_location(self, sample: dict): if self.center_location: size = self._get_size(sample) - return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))] + return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))) else: - return sample[WSIPatchKeys.LOCATION] + return ensure_tuple(sample[WSIPatchKeys.LOCATION]) def _get_level(self, sample: dict): if self.patch_level is None: