Skip to content

Commit

Permalink
Added resized_crop_segmentation_mask op (#5855)
Browse files Browse the repository at this point in the history
* [proto] Added crop_bounding_box op

* Added `crop_segmentation_mask` op

* Fixed failed mypy

* Added tests for resized_crop_bounding_box

* Fixed code formatting

* Added resized_crop_segmentation_mask op

* Added tests
  • Loading branch information
vfdev-5 authored Apr 25, 2022
1 parent 6d85d74 commit de31e4b
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 1 deletion.
33 changes: 33 additions & 0 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ def resized_crop_bounding_box():
)


@register_kernel_info_from_sample_inputs_fn
def resized_crop_segmentation_mask():
for mask, top, left, height, width, size in itertools.product(
make_segmentation_masks(), [-8, 0, 9], [-8, 0, 9], [12, 20], [12, 20], [(32, 32), (16, 18)]
):
yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size)


@pytest.mark.parametrize(
"kernel",
[
Expand Down Expand Up @@ -998,3 +1006,28 @@ def _compute_expected(bbox, top_, left_, height_, width_, size_):
output_boxes = convert_bounding_box_format(output_boxes, format, features.BoundingBoxFormat.XYXY)

torch.testing.assert_close(output_boxes, expected_bboxes)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"top, left, height, width, size",
[
[0, 0, 30, 30, (60, 60)],
[5, 5, 35, 45, (32, 34)],
],
)
def test_correctness_resized_crop_segmentation_mask(device, top, left, height, width, size):
def _compute_expected(mask, top_, left_, height_, width_, size_):
output = mask.clone()
output = output[:, top_ : top_ + height_, left_ : left_ + width_]
output = torch.nn.functional.interpolate(output[None, :].float(), size=size_, mode="nearest")
output = output[0, :].long()
return output

in_mask = torch.zeros(1, 100, 100, dtype=torch.long, device=device)
in_mask[0, 10:20, 10:20] = 1
in_mask[0, 5:15, 12:23] = 2

expected_mask = _compute_expected(in_mask, top, left, height, width, size)
output_mask = F.resized_crop_segmentation_mask(in_mask, top, left, height, width, size)
torch.testing.assert_close(output_mask, expected_mask)
3 changes: 2 additions & 1 deletion torchvision/prototype/transforms/functional/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,10 @@
resize_segmentation_mask,
center_crop_image_tensor,
center_crop_image_pil,
resized_crop_bounding_box,
resized_crop_image_tensor,
resized_crop_image_pil,
resized_crop_bounding_box,
resized_crop_segmentation_mask,
affine_bounding_box,
affine_image_tensor,
affine_image_pil,
Expand Down
12 changes: 12 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,18 @@ def resized_crop_bounding_box(
return resize_bounding_box(bounding_box, size, (height, width))


def resized_crop_segmentation_mask(
mask: torch.Tensor,
top: int,
left: int,
height: int,
width: int,
size: List[int],
) -> torch.Tensor:
mask = crop_segmentation_mask(mask, top, left, height, width)
return resize_segmentation_mask(mask, size)


def _parse_five_crop_size(size: List[int]) -> List[int]:
if isinstance(size, numbers.Number):
size = [int(size), int(size)]
Expand Down

0 comments on commit de31e4b

Please sign in to comment.