Check num of channels on adjust_* transformations #3069

Merged
merged 3 commits on Dec 2, 2020
4 changes: 2 additions & 2 deletions test/common_utils.py
@@ -339,13 +339,13 @@ def freeze_rng_state():
class TransformsTester(unittest.TestCase):

def _create_data(self, height=3, width=3, channels=3, device="cpu"):
- tensor = torch.randint(0, 255, (channels, height, width), dtype=torch.uint8, device=device)
+ tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8, device=device)
Contributor (Author) commented:

randint is exclusive on the upper bound, so the right value here is 256. I opted to fix this small bug in place to avoid doing unnecessary CI runs.
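A quick way to confirm the off-by-one (a standalone sketch for illustration, not part of the diff):

```python
import torch

# torch.randint draws from the half-open interval [low, high): with high=255
# the value 255 can never appear, so uint8 test images would never contain a
# fully saturated pixel. high=256 covers the whole uint8 range [0, 255].
capped = torch.randint(0, 255, (10_000,), dtype=torch.uint8)
full = torch.randint(0, 256, (10_000,), dtype=torch.uint8)

assert capped.max().item() <= 254
print(full.max().item())  # 255 with near certainty over 10,000 samples
```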

pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().cpu().numpy())
return tensor, pil_img

def _create_data_batch(self, height=3, width=3, channels=3, num_samples=4, device="cpu"):
batch_tensor = torch.randint(
- 0, 255,
+ 0, 256,
(num_samples, channels, height, width),
dtype=torch.uint8,
device=device
28 changes: 21 additions & 7 deletions torchvision/transforms/functional_tensor.py
@@ -1,5 +1,5 @@
import warnings
- from typing import Optional, Dict, Tuple
+ from typing import Optional, Tuple

import torch
from torch import Tensor
@@ -45,6 +45,12 @@ def _max_value(dtype: torch.dtype) -> float:
return max_value.item()


+ def _assert_channels(img: Tensor, permitted: List[int]) -> None:
+     c = _get_image_num_channels(img)
+     if c not in permitted:
+         raise TypeError("Input image tensor permitted channel values are {}, but found {}".format(permitted, c))


def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
"""PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly

@@ -210,9 +216,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:
"""
if img.ndim < 3:
raise TypeError("Input image tensor should have at least 3 dimensions, but found {}".format(img.ndim))
- c = img.shape[-3]
- if c != 3:
-     raise TypeError("Input image tensor should 3 channels, but found {}".format(c))
+ _assert_channels(img, [3])

if num_output_channels not in (1, 3):
raise ValueError('num_output_channels should be either 1 or 3')
@@ -230,7 +234,7 @@ def rgb_to_grayscale(img: Tensor, num_output_channels: int = 1) -> Tensor:


def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
"""PRIVATE METHOD. Adjust brightness of an RGB image.
"""PRIVATE METHOD. Adjust brightness of a Grayscale or RGB image.

.. warning::

@@ -252,6 +256,8 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

+ _assert_channels(img, [1, 3])

return _blend(img, torch.zeros_like(img), brightness_factor)


@@ -278,14 +284,16 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

+ _assert_channels(img, [3])

dtype = img.dtype if torch.is_floating_point(img) else torch.float32
mean = torch.mean(rgb_to_grayscale(img).to(dtype), dim=(-3, -2, -1), keepdim=True)

return _blend(img, mean, contrast_factor)


def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
"""PRIVATE METHOD. Adjust hue of an image.
"""PRIVATE METHOD. Adjust hue of an RGB image.

.. warning::

@@ -320,6 +328,8 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
if not (isinstance(img, torch.Tensor) and _is_tensor_a_torch_image(img)):
raise TypeError('Input img should be Tensor image')

+ _assert_channels(img, [3])

orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
@@ -359,11 +369,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
if not _is_tensor_a_torch_image(img):
raise TypeError('tensor is not a torch image.')

+ _assert_channels(img, [3])

return _blend(img, rgb_to_grayscale(img), saturation_factor)


def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
r"""PRIVATE METHOD. Adjust gamma of an RGB image.
r"""PRIVATE METHOD. Adjust gamma of a Grayscale or RGB image.

.. warning::

@@ -391,6 +403,8 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
if not isinstance(img, torch.Tensor):
raise TypeError('Input img should be a Tensor.')

+ _assert_channels(img, [1, 3])

if gamma < 0:
raise ValueError('Gamma should be a non-negative real number')

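Taken together, the tensor adjust_* kernels now fail fast with a clear TypeError when the channel count is unsupported, while grayscale input stays valid where it makes sense. A rough usage sketch (illustrative only, against the private functional_tensor module as changed by this PR):

```python
import torch
from torchvision.transforms import functional_tensor as F_t

gray = torch.randint(0, 256, (1, 16, 16), dtype=torch.uint8)
rgba = torch.randint(0, 256, (4, 16, 16), dtype=torch.uint8)

F_t.adjust_brightness(gray, 1.5)  # OK: brightness accepts 1 or 3 channels
F_t.adjust_gamma(gray, 0.8)       # OK: gamma accepts 1 or 3 channels

# contrast, hue and saturation require RGB, so these now raise TypeError up front:
# F_t.adjust_contrast(gray, 2.0)
# F_t.adjust_hue(rgba, 0.1)
# F_t.adjust_saturation(rgba, 0.5)
```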