Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cleanup functional_tensor.py (#3159) #3171

Merged
merged 13 commits into from
Dec 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from common_utils import TransformsTester

from typing import Dict, List, Tuple


NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC

Expand All @@ -34,6 +36,28 @@ def _test_fn_on_batch(self, batch_tensors, fn, **fn_kwargs):
s_transformed_batch = scripted_fn(batch_tensors, **fn_kwargs)
self.assertTrue(transformed_batch.allclose(s_transformed_batch))

def test_assert_image_tensor(self):
shape = (100,)
tensor = torch.rand(*shape, dtype=torch.float, device=self.device)

list_of_methods = [(F_t._get_image_size, (tensor, )), (F_t.vflip, (tensor, )),
(F_t.hflip, (tensor, )), (F_t.crop, (tensor, 1, 2, 4, 5)),
(F_t.adjust_brightness, (tensor, 0.)), (F_t.adjust_contrast, (tensor, 1.)),
(F_t.adjust_hue, (tensor, -0.5)), (F_t.adjust_saturation, (tensor, 2.)),
(F_t.center_crop, (tensor, [10, 11])), (F_t.five_crop, (tensor, [10, 11])),
(F_t.ten_crop, (tensor, [10, 11])), (F_t.pad, (tensor, [2, ], 2, "constant")),
(F_t.resize, (tensor, [10, 11])), (F_t.perspective, (tensor, [0.2, ])),
(F_t.gaussian_blur, (tensor, (2, 2), (0.7, 0.5))),
(F_t.invert, (tensor, )), (F_t.posterize, (tensor, 0)),
(F_t.solarize, (tensor, 0.3)), (F_t.adjust_sharpness, (tensor, 0.3)),
(F_t.autocontrast, (tensor, )), (F_t.equalize, (tensor, ))]

for func, args in list_of_methods:
with self.assertRaises(Exception) as context:
func(*args)

self.assertTrue('Tensor is not a torch image.' in str(context.exception))

def test_vflip(self):
script_vflip = torch.jit.script(F.vflip)

Expand Down
91 changes: 47 additions & 44 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:
return x.ndim >= 2


def _assert_image_tensor(img):
if not _is_tensor_a_torch_image(img):
raise TypeError("Tensor is not a torch image.")


def _get_image_size(img: Tensor) -> List[int]:
"""Returns (w, h) of tensor image"""
if _is_tensor_a_torch_image(img):
return [img.shape[-1], img.shape[-2]]
raise TypeError("Unexpected input type")
_assert_image_tensor(img)
return [img.shape[-1], img.shape[-2]]


def _get_image_num_channels(img: Tensor) -> int:
Expand Down Expand Up @@ -143,8 +147,7 @@ def vflip(img: Tensor) -> Tensor:
Returns:
Tensor: Vertically flipped image Tensor.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

return img.flip(-2)

Expand All @@ -163,8 +166,7 @@ def hflip(img: Tensor) -> Tensor:
Returns:
Tensor: Horizontally flipped image Tensor.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

return img.flip(-1)

Expand All @@ -187,8 +189,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
Returns:
Tensor: Cropped image.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError("tensor is not a torch image.")
_assert_image_tensor(img)

return img[..., top:top + height, left:left + width]

Expand Down Expand Up @@ -254,8 +255,7 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if brightness_factor < 0:
raise ValueError('brightness_factor ({}) is not non-negative.'.format(brightness_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [1, 3])

Expand All @@ -282,8 +282,7 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if contrast_factor < 0:
raise ValueError('contrast_factor ({}) is not non-negative.'.format(contrast_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [3])

Expand Down Expand Up @@ -326,9 +325,11 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
if not (isinstance(img, torch.Tensor)):
avijit9 marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError('Input img should be Tensor image')

_assert_image_tensor(img)

_assert_channels(img, [3])

orig_dtype = img.dtype
Expand Down Expand Up @@ -367,8 +368,7 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if saturation_factor < 0:
raise ValueError('saturation_factor ({}) is not non-negative.'.format(saturation_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [3])

Expand Down Expand Up @@ -447,8 +447,7 @@ def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
"Please, use ``F.center_crop`` instead."
)

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_, image_width, image_height = img.size()
crop_height, crop_width = output_size
Expand Down Expand Up @@ -497,8 +496,7 @@ def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
"Please, use ``F.five_crop`` instead."
)

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

assert len(size) == 2, "Please provide only two dimensions (h, w) for size."

Expand Down Expand Up @@ -553,8 +551,7 @@ def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = Fa
"Please, use ``F.ten_crop`` instead."
)

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
first_five = five_crop(img, size)
Expand Down Expand Up @@ -703,8 +700,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
Returns:
Tensor: Padded image.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError("tensor is not a torch image.")
_assert_image_tensor(img)

if not isinstance(padding, (int, tuple, list)):
raise TypeError("Got inappropriate padding arg")
Expand Down Expand Up @@ -796,8 +792,7 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
Returns:
Tensor: Resized image.
"""
if not _is_tensor_a_torch_image(img):
raise TypeError("tensor is not a torch image.")
_assert_image_tensor(img)

if not isinstance(size, (int, tuple, list)):
raise TypeError("Got inappropriate size arg")
Expand Down Expand Up @@ -855,8 +850,11 @@ def _assert_grid_transform_inputs(
supported_interpolation_modes: List[str],
coeffs: Optional[List[float]] = None,
):
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError("Input img should be Tensor Image")

if not (isinstance(img, torch.Tensor)):
avijit9 marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError("Input img should be Tensor")

_assert_image_tensor(img)

if matrix is not None and not isinstance(matrix, list):
raise TypeError("Argument matrix should be a list")
Expand Down Expand Up @@ -1112,8 +1110,11 @@ def perspective(
Returns:
Tensor: transformed image.
"""
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('Input img should be Tensor Image')

if not (isinstance(img, torch.Tensor)):
avijit9 marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError('Input img should be Tensor.')

_assert_image_tensor(img)

_assert_grid_transform_inputs(
img,
Expand Down Expand Up @@ -1165,8 +1166,11 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te
Returns:
Tensor: An image that is blurred using gaussian kernel of given parameters
"""
if not (isinstance(img, torch.Tensor) or _is_tensor_a_torch_image(img)):
raise TypeError('img should be Tensor Image. Got {}'.format(type(img)))

if not (isinstance(img, torch.Tensor)):
avijit9 marked this conversation as resolved.
Show resolved Hide resolved
raise TypeError('img should be Tensor. Got {}'.format(type(img)))

_assert_image_tensor(img)

dtype = img.dtype if torch.is_floating_point(img) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
Expand All @@ -1184,8 +1188,8 @@ def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: List[float]) -> Te


def invert(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand All @@ -1197,8 +1201,8 @@ def invert(img: Tensor) -> Tensor:


def posterize(img: Tensor, bits: int) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand All @@ -1211,8 +1215,8 @@ def posterize(img: Tensor, bits: int) -> Tensor:


def solarize(img: Tensor, threshold: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand Down Expand Up @@ -1245,8 +1249,7 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:
if sharpness_factor < 0:
raise ValueError('sharpness_factor ({}) is not non-negative.'.format(sharpness_factor))

if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')
_assert_image_tensor(img)

_assert_channels(img, [1, 3])

Expand All @@ -1257,8 +1260,8 @@ def adjust_sharpness(img: Tensor, sharpness_factor: float) -> Tensor:


def autocontrast(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
Expand Down Expand Up @@ -1297,8 +1300,8 @@ def _equalize_single_image(img: Tensor) -> Tensor:


def equalize(img: Tensor) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

_assert_image_tensor(img)

if not (3 <= img.ndim <= 4):
raise TypeError("Input image tensor should have 3 or 4 dimensions, but found {}".format(img.ndim))
Expand Down