From 56ee32e36c5c0c7a5cb10afa4ec5589c81171e6b Mon Sep 17 00:00:00 2001 From: David Carreto Fidalgo Date: Sat, 3 Aug 2024 22:29:45 +0200 Subject: [PATCH] Fix: Small logic mistake in the `AsDiscrete.__call__` method (#7984) Hi MONAI Team! Thank you very much for this super nice framework, really appreciate it! Just found a small logic mistake in one of the transform classes. To reproduce: ```python import torch from monai.transforms.post.array import AsDiscrete transform = AsDiscrete(argmax=True) prediction = torch.rand(2, 3, 3) transform(prediction, argmax=False) # will still apply argmax ``` ### Description Proposed fix: `argmax` is explicitly checked for `None` in the `__cal__` method. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: David Carreto Fidalgo Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com> --- monai/transforms/post/array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/monai/transforms/post/array.py b/monai/transforms/post/array.py index da9b23ce57..2e733c4f6c 100644 --- a/monai/transforms/post/array.py +++ b/monai/transforms/post/array.py @@ -211,7 +211,8 @@ def __call__( raise ValueError("`to_onehot=True/False` is deprecated, please use `to_onehot=num_classes` instead.") img = convert_to_tensor(img, track_meta=get_track_meta()) img_t, *_ = convert_data_type(img, torch.Tensor) - if argmax or self.argmax: + argmax = self.argmax if argmax is None else argmax + if argmax: img_t = torch.argmax(img_t, dim=self.kwargs.get("dim", 0), keepdim=self.kwargs.get("keepdim", True)) to_onehot = self.to_onehot if to_onehot is None else to_onehot