Skip to content

Commit

Permalink
Cleanup functional_tensor.py (#3159) (#3171)
Browse files Browse the repository at this point in the history
* added the helper method for dimension checks

* unit tests for dimensio check function in functional_tensor

* code formatting and typing

* moved torch image check after tensor check

* unit testcases for test_assert_image_tensor added and refactored

* separate unit testcase file deleted

* assert_image_tensor added to newly created 6 methods

* test cases added for new 6 mthohds

* removed wrongly pasted posterize method and added solarize method for testing

Co-authored-by: Vasilis Vryniotis <[email protected]>

Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
avijit9 and datumbox authored Dec 15, 2020
1 parent 90645cc commit 1a300d8
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 44 deletions.
24 changes: 24 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from common_utils import TransformsTester

from typing import Dict, List, Tuple


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC

Expand All @@ -34,6 +36,28 @@ def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
self.assertTrue(transformed_batch.allclose(s_transformed_batch))

def test_assert_image_tensor(self):
shape = (100,)
tensor = torch.rand(*shape, dtype=torch.float, device=self.device)

list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
(F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
(F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
(F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
(F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
(F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
(F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
(F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
(F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
(F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
(F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]

for func, args in list_of_methods:
with self.assertRaises(Exception) as context:
func(*args)

self.assertTrue('Tensor is not a torch image.' in str(context.exception))

def test_vflip(self):
script_vflip = torch.jit.script(F.vflip)

Expand Down
91 changes: 47 additions & 44 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2


def _assert_image_tensor(img):
if not _is_tensor_a_torch_image(img):
raise TypeError("Tensor is not a torch image.")


def _get_image_size(img: Tensor) -> List[int]:
"""Returns (w, h) of tensor image"""
if _is_tensor_a_torch_image(img):
return [img.shape[-1], img.shape[-2]]
raise TypeError("Unexpected input type")
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]


def _get_image_num_channels(img: Tensor) -> int:
Expand Down Expand Up @@ -143,8 +147,7 @@ def vflip(img: Tensor) -> Tensor:
Returns:
Tensor: Vertically flipped image Tensor.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

return img.flip(-2)

Expand All @@ -163,8 +166,7 @@ def hflip(img: Tensor) -> Tensor:
Returns:
Tensor: Horizontally flipped image Tensor.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

return img.flip(-1)

Expand All @@ -187,8 +189,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
Returns:
Tensor: Cropped image.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError("tensor is not a torch image.")
_assert_image_tensor(img)

return img[..., top:top + height, left:left + width]

Expand Down Expand Up @@ -254,8 +255,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [1, 3])

Expand All @@ -282,8 +282,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [3])

Expand Down Expand Up @@ -326,9 +325,11 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor image')

_assert_image_tensor(img)

_assert_channels(img, [3])

orig_dtype = img.dtype
Expand Down Expand Up @@ -367,8 +368,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [3])

Expand Down Expand Up @@ -447,8 +447,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"Please, use ``F.center_crop`` instead."
)

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_, image_width, image_height = img.size()
crop_height, crop_width = output_size
Expand Down Expand Up @@ -497,8 +496,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
"Please, use ``F.five_crop`` instead."
)

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

Expand Down Expand Up @@ -553,8 +551,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
"Please, use ``F.ten_crop`` instead."
)

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
first_five = five_crop(img, size)
Expand Down Expand Up @@ -703,8 +700,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
Returns:
Tensor: Padded image.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError("tensor is not a torch image.")
_assert_image_tensor(img)

if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
Expand Down Expand Up @@ -796,8 +792,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
Returns:
Tensor: Resized image.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError("tensor is not a torch image.")
_assert_image_tensor(img)

if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
Expand Down Expand Up @@ -855,8 +850,11 @@ def _assert_grid_transform_inputs(
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
):
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError("Input img should be Tensor Image")

if not (isinstance(img, torch.Tensor)):
raise TypeError("Input img should be Tensor")

_assert_image_tensor(img)

if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
Expand Down Expand Up @@ -1112,8 +1110,11 @@ def perspective(
Returns:
Tensor: transformed image.
"""
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('Input img should be Tensor Image')

if not (isinstance(img, torch.Tensor)):
raise TypeError('Input img should be Tensor.')

_assert_image_tensor(img)

_assert_grid_transform_inputs(
img,
Expand Down Expand Up @@ -1165,8 +1166,11 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
Returns:
Tensor: An image that is blurred using gaussian kernel of given parameters
"""
if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

if not (isinstance(img, torch.Tensor)):
raise TypeError('img should be Tensor. Got {}'.format(type(img)))

_assert_image_tensor(img)

dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
Expand All @@ -1184,8 +1188,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te


def invert(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand All @@ -1197,8 +1201,8 @@ def invert(img: Tensor) -> Tensor:


def posterize(img: Tensor, bits: int) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand All @@ -1211,8 +1215,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:


def solarize(img: Tensor, threshold: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand Down Expand Up @@ -1245,8 +1249,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [1, 3])

Expand All @@ -1257,8 +1260,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:


def autocontrast(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand Down Expand Up @@ -1297,8 +1300,8 @@ def _equalize_single_image(img: Tensor) -> Tensor:


def equalize(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if not (3 <= img.ndim <= 4):
raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))
Expand Down

0 comments on commit 1a300d8

Please sign in to comment.