From de31e4b8bf9b4a7e0668d19059a5ac4760dceee1 Mon Sep 17 00:00:00 2001
From: vfdev
Date: Mon, 25 Apr 2022 17:28:46 +0200
Subject: [PATCH] Added `resized_crop_segmentation_mask` op (#5855)

* [proto] Added crop_bounding_box op

* Added `crop_segmentation_mask` op

* Fixed failed mypy

* Added tests for resized_crop_bounding_box

* Fixed code formatting

* Added resized_crop_segmentation_mask op

* Added tests
---
 test/test_prototype_transforms_functional.py | 33 +++++++++++++++++++
 .../transforms/functional/__init__.py        |  3 +-
 .../transforms/functional/_geometry.py       | 12 +++++++
 3 files changed, 47 insertions(+), 1 deletion(-)

diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 2da3aa4696a..36d1677ede5 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -362,6 +362,14 @@ def resized_crop_bounding_box():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def resized_crop_segmentation_mask():
+    for mask, top, left, height, width, size in itertools.product(
+        make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
+    ):
+        yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -998,3 +1006,28 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
     output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)
 
     torch.testing.assert_close(output_boxes, expected_bboxes)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "top, left, height, width, size",
+    [
+        [0, 0, 30, 30, (60, 60)],
+        [5, 5, 35, 45, (32, 34)],
+    ],
+)
+def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
+    def _compute_expected(mask, top_, left_, height_, width_, size_):
+        output = mask.clone()
+        output = output[:, top_ : top_ + height_, left_ : left_ + width_]
+        output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
+        output = output[0, :].long()
+        return output
+
+    in_mask = torch.zeros(1, 100, 100, dtype=torch.long, device=device)
+    in_mask[0, 10:20, 10:20] = 1
+    in_mask[0, 5:15, 12:23] = 2
+
+    expected_mask = _compute_expected(in_mask, top, left, height, width, size)
+    output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
+    torch.testing.assert_close(output_mask, expected_mask)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 7069f17c414..dfbc81baea3 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -47,9 +47,10 @@
     resize_segmentation_mask,
     center_crop_image_tensor,
     center_crop_image_pil,
+    resized_crop_bounding_box,
     resized_crop_image_tensor,
     resized_crop_image_pil,
-    resized_crop_bounding_box,
+    resized_crop_segmentation_mask,
     affine_bounding_box,
     affine_image_tensor,
     affine_image_pil,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index fc1eddfd230..5f9e77fdbf4 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -555,6 +555,18 @@ def resized_crop_bounding_box(
     return resize_bounding_box(bounding_box, size, (height, width))
 
 
+def resized_crop_segmentation_mask(
+    mask: torch.Tensor,
+    top: int,
+    left: int,
+    height: int,
+    width: int,
+    size: List[int],
+) -> torch.Tensor:
+    mask = crop_segmentation_mask(mask, top, left, height, width)
+    return resize_segmentation_mask(mask, size)
+
+
 def _parse_five_crop_size(size: List[int]) -> List[int]:
     if isinstance(size, numbers.Number):
         size = [int(size), int(size)]
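
For reference, a minimal usage sketch of the new op (not part of the diff above). The op simply chains crop_segmentation_mask and resize_segmentation_mask; the sketch assumes the prototype functional namespace imports as `F` and that the mask resize uses nearest-neighbor interpolation, the same assumption the test's _compute_expected helper makes with mode="nearest":

import torch
from torchvision.prototype.transforms import functional as F

# A (1, H, W) integer mask with a single labeled square, mirroring the test input.
mask = torch.zeros(1, 100, 100, dtype=torch.long)
mask[0, 10:20, 10:20] = 1

# Crop the 35x45 window whose top-left corner is at (top=5, left=5),
# then resize the crop to 32x34.
out = F.resized_crop_segmentation_mask(mask, top=5, left=5, height=35, width=45, size=[32, 34])

print(out.shape)    # torch.Size([1, 32, 34])
print(out.unique()) # tensor([0, 1]) -- nearest-neighbor resize keeps labels integral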