From 589c711833f72309cd6a0641a4b16dc91a20b591 Mon Sep 17 00:00:00 2001
From: Dong Yang
Date: Tue, 14 Mar 2023 10:07:54 -0600
Subject: [PATCH] Improve GPU memory efficiency on sliding-window inferer
 (#6140)

### Description

Improve GPU memory efficiency on the sliding-window inferer: the padding
step is now skipped entirely when the input already matches or exceeds the
ROI size (no zero-length `F.pad` call), and the importance-map clamp is
performed in place (`torch.clamp_`), avoiding an unnecessary intermediate
tensor allocation on the GPU.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: dongy
Co-authored-by: dongy
---
 monai/inferers/utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/monai/inferers/utils.py b/monai/inferers/utils.py
index 882f0f9101..40ef3eab2d 100644
--- a/monai/inferers/utils.py
+++ b/monai/inferers/utils.py
@@ -142,7 +142,9 @@ def sliding_window_inference(
         diff = max(roi_size[k - 2] - inputs.shape[k], 0)
         half = diff // 2
         pad_size.extend([half, diff - half])
-    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
+
+    if max(pad_size) > 0:
+        inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
 
     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)
 
@@ -167,8 +169,8 @@ def sliding_window_inference(
     importance_map_ = convert_data_type(importance_map_, torch.Tensor, device, compute_dtype)[0]
 
     # handle non-positive weights
-    min_non_zero = max(importance_map_[importance_map_ != 0].min().item(), 1e-3)
-    importance_map_ = torch.clamp(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype)
+    min_non_zero = max(torch.min(importance_map_).item(), 1e-3)
+    importance_map_ = torch.clamp_(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype)
 
     # Perform predictions
     dict_key, output_image_list, count_map_list = None, [], []