From d64f954ee66db8723a9eabf7e600e9f95dcecaed Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 9 Aug 2024 11:00:11 +0800 Subject: [PATCH] fix #8006 Signed-off-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/data/wsi_datasets.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/monai/data/wsi_datasets.py b/monai/data/wsi_datasets.py index 3488029a7a..2ee8c9d363 100644 --- a/monai/data/wsi_datasets.py +++ b/monai/data/wsi_datasets.py @@ -23,7 +23,7 @@ from monai.data.utils import iter_patch_position from monai.data.wsi_reader import BaseWSIReader, WSIReader from monai.transforms import ForegroundMask, Randomizable, apply_transform -from monai.utils import convert_to_dst_type, ensure_tuple_rep +from monai.utils import convert_to_dst_type, ensure_tuple, ensure_tuple_rep from monai.utils.enums import CommonKeys, ProbMapKeys, WSIPatchKeys __all__ = ["PatchWSIDataset", "SlidingPatchWSIDataset", "MaskedPatchWSIDataset"] @@ -123,9 +123,9 @@ def _get_label(self, sample: dict): def _get_location(self, sample: dict): if self.center_location: size = self._get_size(sample) - return [sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))] + return ensure_tuple(sample[WSIPatchKeys.LOCATION][i] - size[i] // 2 for i in range(len(size))) else: - return sample[WSIPatchKeys.LOCATION] + return ensure_tuple(sample[WSIPatchKeys.LOCATION]) def _get_level(self, sample: dict): if self.patch_level is None: