From ea46973a1e87c8d081f28e7dc357df2454c86687 Mon Sep 17 00:00:00 2001 From: David Carreto Fidalgo Date: Fri, 2 Aug 2024 00:05:00 +0200 Subject: [PATCH] Fix AsDiscrete.__call__ function when self.argmax is True and argmax is False Signed-off-by: David Carreto Fidalgo --- monai/transforms/post/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index da9b23ce57..2e733c4f6c 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -211,7 +211,8 @@ def __call__( raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor) - if argmax or self.argmax: + argmax = self.argmax if argmax is None else argmax + if argmax: img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True)) to_onehot = self.to_onehot if to_onehot is None else to_onehot