Skip to content

Commit

Permalink
fixes integration test and 6354 (#6353)
Browse files Browse the repository at this point in the history
- skip calls to torch.cuda.mem_get_info 
- fixes #6354 


### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: Wenqi Li <[email protected]>
  • Loading branch information
wyli authored Apr 13, 2023
1 parent 1a55ba5 commit d8d887f
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 18 deletions.
1 change: 1 addition & 0 deletions .github/workflows/pythonapp-gpu.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- "PT110+CUDA111"
- "PT112+CUDA118DOCKER"
- "PT113+CUDA116"
- "PT114+CUDA120DOCKER"
include:
# https://docs.nvidia.com/deeplearning/frameworks/pytorch-release-notes
- environment: PT19+CUDA114DOCKER
Expand Down
6 changes: 3 additions & 3 deletions monai/metrics/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,18 +309,18 @@ def prepare_spacing(
return list([spacing] * batch_size)
elif isinstance(spacing, (Sequence, np.ndarray)):
assert all(
[isinstance(s, type(spacing[0])) for s in list(spacing)]
isinstance(s, type(spacing[0])) for s in list(spacing)
), "if `spacing` is a sequence, its elements should be of same type."

if isinstance(spacing[0], (Sequence, np.ndarray)):
assert (
len(spacing) == batch_size
), "if `spacing` is a sequence of sequences, the outer sequence should have same length as batch size."
assert all(
[len(s) == img_dim for s in list(spacing)]
len(s) == img_dim for s in list(spacing)
), "each element of `spacing` list should either have same length as image dim."
assert all(
[isinstance(i, (int, float)) for s in list(spacing) for i in list(s)]
isinstance(i, (int, float)) for s in list(spacing) for i in list(s)
), "if `spacing` is a sequence of sequences or 2D np.ndarray, the elements should be integers or floats."
return list(spacing)
elif isinstance(spacing[0], (int, float)):
Expand Down
2 changes: 1 addition & 1 deletion monai/networks/blocks/backbone_fpn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def _resnet_fpn_extractor(
if trainable_layers == 5:
layers_to_train.append("bn1")
for name, parameter in backbone.named_parameters():
if all([not name.startswith(layer) for layer in layers_to_train]):
if all(not name.startswith(layer) for layer in layers_to_train):
parameter.requires_grad_(False)

if extra_blocks is None:
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,6 +624,7 @@
Fourier,
allow_missing_keys_mode,
attach_hook,
check_non_lazy_pending_ops,
compute_divisible_spatial_size,
convert_applied_interp_mode,
convert_pad_mode,
Expand Down
12 changes: 0 additions & 12 deletions monai/transforms/croppad/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -780,8 +780,6 @@ def compute_bounding_box(self, img: torch.Tensor) -> tuple[np.ndarray, np.ndarra
And adjust bounding box coords to be divisible by `k`.
"""
if isinstance(img, MetaTensor) and img.pending_operations:
warnings.warn("foreground computation may not be accurate if the image has pending operations.")
box_start, box_end = generate_spatial_bounding_box(
img, self.select_fn, self.channel_indices, self.margin, self.allow_smaller
)
Expand Down Expand Up @@ -869,8 +867,6 @@ def __init__(
self.centers: list[np.ndarray] = []

def randomize(self, weight_map: NdarrayOrTensor) -> None:
if isinstance(weight_map, MetaTensor) and weight_map.pending_operations:
warnings.warn("weight map has pending operations, the sampling may not be correct.")
self.centers = weighted_patch_samples(
spatial_size=self.spatial_size, w=weight_map[0], n_samples=self.num_samples, r_state=self.R
) # using only the first channel as weight map
Expand Down Expand Up @@ -1015,10 +1011,6 @@ def randomize(
fg_indices_ = self.fg_indices if fg_indices is None else fg_indices
bg_indices_ = self.bg_indices if bg_indices is None else bg_indices
if fg_indices_ is None or bg_indices_ is None:
if isinstance(label, MetaTensor) and label.pending_operations:
warnings.warn("label has pending operations, the fg/bg indices may be incorrect.")
if isinstance(image, MetaTensor) and image.pending_operations:
warnings.warn("image has pending operations, the fg/bg indices may be incorrect.")
if label is None:
raise ValueError("label must be provided.")
fg_indices_, bg_indices_ = map_binary_to_indices(label, image, self.image_threshold)
Expand Down Expand Up @@ -1195,10 +1187,6 @@ def randomize(
) -> None:
indices_ = self.indices if indices is None else indices
if indices_ is None:
if isinstance(label, MetaTensor) and label.pending_operations:
warnings.warn("label has pending operations, the fg/bg indices may be incorrect.")
if isinstance(image, MetaTensor) and image.pending_operations:
warnings.warn("image has pending operations, the fg/bg indices may be incorrect.")
if label is None:
raise ValueError("label must not be None.")
indices_ = map_classes_to_indices(
Expand Down
30 changes: 29 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
"compute_divisible_spatial_size",
"convert_applied_interp_mode",
"copypaste_arrays",
"check_non_lazy_pending_ops",
"create_control_grid",
"create_grid",
"create_rotate",
Expand Down Expand Up @@ -294,6 +295,27 @@ def resize_center(img: np.ndarray, *resize_dims: int | None, fill_value: float =
return img[srcslices]


def check_non_lazy_pending_ops(
input_array: NdarrayOrTensor, name: None | str = None, raise_error: bool = False
) -> None:
"""
Check whether the input array has pending operations, raise an error or warn when it has.
Args:
input_array: input array to be checked.
name: an optional name to be included in the error message.
raise_error: whether to raise an error, default to False, a warning message will be issued instead.
"""
if isinstance(input_array, monai.data.MetaTensor) and input_array.pending_operations:
msg = (
"The input image is a MetaTensor and has pending operations,\n"
f"but the function {name or ''} assumes non-lazy input, result may be incorrect."
)
if raise_error:
raise ValueError(msg)
warnings.warn(msg)


def map_binary_to_indices(
label: NdarrayOrTensor, image: NdarrayOrTensor | None = None, image_threshold: float = 0.0
) -> tuple[NdarrayOrTensor, NdarrayOrTensor]:
Expand All @@ -310,13 +332,14 @@ def map_binary_to_indices(
image_threshold: if enabled `image`, use ``image > image_threshold`` to
determine the valid image content area and select background only in this area.
"""

check_non_lazy_pending_ops(label, name="map_binary_to_indices")
# Prepare fg/bg indices
if label.shape[0] > 1:
label = label[1:] # for One-Hot format data, remove the background channel
label_flat = ravel(any_np_pt(label, 0)) # in case label has multiple dimensions
fg_indices = nonzero(label_flat)
if image is not None:
check_non_lazy_pending_ops(image, name="map_binary_to_indices")
img_flat = ravel(any_np_pt(image > image_threshold, 0))
img_flat, *_ = convert_to_dst_type(img_flat, label, dtype=bool)
bg_indices = nonzero(img_flat & ~label_flat)
Expand Down Expand Up @@ -357,8 +380,10 @@ def map_classes_to_indices(
Default is None, no subsampling.
"""
check_non_lazy_pending_ops(label, name="map_classes_to_indices")
img_flat: NdarrayOrTensor | None = None
if image is not None:
check_non_lazy_pending_ops(image, name="map_classes_to_indices")
img_flat = ravel((image > image_threshold).any(0))

# assuming the first dimension is channel
Expand Down Expand Up @@ -410,6 +435,7 @@ def weighted_patch_samples(
a list of `n_samples` N-D integers representing the spatial sampling location of patches.
"""
check_non_lazy_pending_ops(w, name="weighted_patch_samples")
if w is None:
raise ValueError("w must be an ND array, got None.")
if r_state is None:
Expand Down Expand Up @@ -937,6 +963,7 @@ def generate_spatial_bounding_box(
allow_smaller: when computing box size with `margin`, whether allow the image size to be smaller
than box size, default to `True`.
"""
check_non_lazy_pending_ops(img, name="generate_spatial_bounding_box")
spatial_size = img.shape[1:]
data = img[list(ensure_tuple(channel_indices))] if channel_indices is not None else img
data = select_fn(data).any(0)
Expand Down Expand Up @@ -1175,6 +1202,7 @@ def get_extreme_points(
Raises:
ValueError: When the input image does not have any foreground pixel.
"""
check_non_lazy_pending_ops(img, name="get_extreme_points")
if rand_state is None:
rand_state = np.random.random.__self__ # type: ignore
indices = where(img != background)
Expand Down
1 change: 1 addition & 0 deletions tests/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ def get_default_pattern(loader):
cases.append(f"tests.{test_module}")
else:
print(f"monai test runner: excluding tests.{test_module}")
print(cases)
tests = unittest.TestLoader().loadTestsFromNames(cases)
discovery_time = pc.total_time
print(f"time to discover tests: {discovery_time}s, total cases: {tests.countTestCases()}.")
Expand Down
2 changes: 1 addition & 1 deletion tests/test_integration_gpu_customization.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@


@skip_if_quick
@SkipIfBeforePyTorchVersion((1, 9, 1))
@SkipIfBeforePyTorchVersion((1, 11, 1)) # module 'torch.cuda' has no attribute 'mem_get_info'
@unittest.skipIf(not has_tb, "no tensorboard summary writer")
class TestEnsembleGpuCustomization(unittest.TestCase):
def setUp(self) -> None:
Expand Down

0 comments on commit d8d887f

Please sign in to comment.