From 334359ed3d2b1f4110634ad358d07c02a4280df7 Mon Sep 17 00:00:00 2001 From: KumoLiu Date: Fri, 14 Jul 2023 16:27:18 +0800 Subject: [PATCH] fix #6717 Signed-off-by: KumoLiu --- monai/transforms/utils.py | 6 +++--- tests/test_invert.py | 2 +- tests/test_invertd.py | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 9b24d7038a..a6c60052e1 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -1398,10 +1398,10 @@ def convert_applied_interp_mode(trans_info, mode: str = "nearest", align_corners trans_info = dict(trans_info) if "mode" in trans_info: current_mode = trans_info["mode"] - if current_mode[0] in _interp_modes: - trans_info["mode"] = [mode for _ in range(len(mode))] - elif current_mode in _interp_modes: + if isinstance(current_mode, int) or current_mode in _interp_modes: trans_info["mode"] = mode + elif isinstance(current_mode[0], int) or current_mode[0] in _interp_modes: + trans_info["mode"] = [mode for _ in range(len(mode))] if "align_corners" in trans_info: _align_corners = TraceKeys.NONE if align_corners is None else align_corners current_value = trans_info["align_corners"] diff --git a/tests/test_invert.py b/tests/test_invert.py index b7c11362ce..9c57b11331 100644 --- a/tests/test_invert.py +++ b/tests/test_invert.py @@ -50,7 +50,7 @@ def test_invert(self): LoadImage(image_only=True), EnsureChannelFirst(), Orientation("RPS"), - Spacing(pixdim=(1.2, 1.01, 0.9), mode="bilinear", dtype=np.float32), + Spacing(pixdim=(1.2, 1.01, 0.9), mode=1, dtype=np.float32), RandFlip(prob=0.5, spatial_axis=[1, 2]), RandAxisFlip(prob=0.5), RandRotate90(prob=0, spatial_axes=(1, 2)), diff --git a/tests/test_invertd.py b/tests/test_invertd.py index 5ad8735fdc..cd2e91257a 100644 --- a/tests/test_invertd.py +++ b/tests/test_invertd.py @@ -58,7 +58,7 @@ def test_invert(self): RandRotate90d(KEYS, prob=0, spatial_axes=(1, 2)), RandZoomd(KEYS, prob=0.5, min_zoom=0.5, max_zoom=1.1, keep_size=True), RandRotated(KEYS, prob=0.5, range_x=np.pi, mode="bilinear", align_corners=True, dtype=np.float64), - RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode="nearest"), + RandAffined(KEYS, prob=0.5, rotate_range=np.pi, mode=["nearest", 0]), ResizeWithPadOrCropd(KEYS, 100), CastToTyped(KEYS, dtype=[torch.uint8, np.uint8]), CopyItemsd("label", times=2, names=["label_inverted", "label_inverted1"]),