Improve GPU memory efficiency on sliding-window inferer (#6140)
### Description
Improve GPU memory efficiency of the sliding-window inferer.

This change reduces peak GPU memory in `sliding_window_inference` in two places. First, the input volume is only passed through `F.pad` when at least one spatial dimension actually needs padding, so no padded copy of the full volume is allocated in the common case where the input already covers the ROI. Second, the minimum of the window importance map is taken with `torch.min` over the whole map instead of boolean-mask indexing (which materializes a mask plus a gathered copy), and the map is clamped in place with `torch.clamp_` rather than into a new tensor.

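To illustrate the first change, here is a minimal runnable sketch of the new conditional-padding guard. The pad-size loop mirrors the code in the diff below; the input shape is hypothetical, and the literal `mode`/`value` arguments stand in for the inferer's `padding_mode`/`cval` parameters:

```python
import torch
import torch.nn.functional as F

# Hypothetical NCDHW volume that already covers the ROI, so no padding is needed.
inputs = torch.randn(1, 1, 96, 96, 96)
roi_size = (64, 64, 64)

# Same pad computation as in the inferer: symmetric padding per spatial dim,
# built in reverse dimension order as F.pad expects.
pad_size = []
for k in range(len(inputs.shape) - 1, 1, -1):
    diff = max(roi_size[k - 2] - inputs.shape[k], 0)
    half = diff // 2
    pad_size.extend([half, diff - half])

# The new guard: only allocate a padded copy when some dimension really needs it.
if max(pad_size) > 0:
    inputs = F.pad(inputs, pad=pad_size, mode="constant", value=0.0)

print(pad_size)      # [0, 0, 0, 0, 0, 0] -> F.pad (and its full-volume copy) is skipped
print(inputs.shape)  # torch.Size([1, 1, 96, 96, 96])
```
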
### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: dongy <[email protected]>
Co-authored-by: dongy <[email protected]>
dongyang0122 and dongy authored Mar 14, 2023
1 parent 0a904fb commit 589c711
Showing 1 changed file with 5 additions and 3 deletions.
monai/inferers/utils.py
```diff
@@ -142,7 +142,9 @@ def sliding_window_inference(
         diff = max(roi_size[k - 2] - inputs.shape[k], 0)
         half = diff // 2
         pad_size.extend([half, diff - half])
-    inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)
+
+    if max(pad_size) > 0:
+        inputs = F.pad(inputs, pad=pad_size, mode=look_up_option(padding_mode, PytorchPadMode), value=cval)

     scan_interval = _get_scan_interval(image_size, roi_size, num_spatial_dims, overlap)

@@ -167,8 +169,8 @@
     importance_map_ = convert_data_type(importance_map_, torch.Tensor, device, compute_dtype)[0]

     # handle non-positive weights
-    min_non_zero = max(importance_map_[importance_map_ != 0].min().item(), 1e-3)
-    importance_map_ = torch.clamp(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype)
+    min_non_zero = max(torch.min(importance_map_).item(), 1e-3)
+    importance_map_ = torch.clamp_(importance_map_.to(torch.float32), min=min_non_zero).to(compute_dtype)

     # Perform predictions
     dict_key, output_image_list, count_map_list = None, [], []
```
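The second hunk trades allocations for in-place work: the old boolean-mask indexing `importance_map_[importance_map_ != 0]` builds a mask plus a gathered copy before taking the min, and `torch.clamp` writes into a fresh tensor. A sketch of the new pattern (the map size is hypothetical; note one small behavioral difference, in that if the map contains zeros the clamp floor is now `1e-3` rather than the smallest non-zero weight):

```python
import torch

# Hypothetical window importance map (e.g. a Gaussian weight map; in the
# inferer this lives on the GPU at window resolution).
importance_map = torch.rand(64, 64, 64, dtype=torch.float32)

# New pattern from the diff: min over the whole map (no boolean mask, no
# gathered copy), then clamp in place so no second buffer is allocated.
min_non_zero = max(torch.min(importance_map).item(), 1e-3)
importance_map = torch.clamp_(importance_map, min=min_non_zero)
```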

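Both optimizations are internal to `monai.inferers.sliding_window_inference`, so existing call sites benefit without changes. For reference, a usage sketch (the model, shapes, and `overlap` value are illustrative, not from this PR):

```python
import torch
from monai.inferers import sliding_window_inference

# Illustrative predictor: any callable mapping a window batch to predictions.
model = torch.nn.Conv3d(1, 2, kernel_size=3, padding=1).eval()

volume = torch.randn(1, 1, 96, 96, 96)  # hypothetical NCDHW volume

with torch.no_grad():
    # Since the volume already covers roi_size, this commit skips the padded
    # copy of `volume`; the importance map is also clamped in place.
    logits = sliding_window_inference(
        inputs=volume,
        roi_size=(64, 64, 64),
        sw_batch_size=4,
        predictor=model,
        overlap=0.25,
    )

print(logits.shape)  # torch.Size([1, 2, 96, 96, 96])
```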