Skip to content

Commit

Permalink
Fix documentation and imports.
Browse files Browse the repository at this point in the history
  • Loading branch information
datumbox committed Dec 8, 2020
1 parent bbe6590 commit 8ed500f
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 23 deletions.
5 changes: 2 additions & 3 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from torch._utils_internal import get_file_path_2
from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy
from numpy.testing import assert_array_almost_equal
import unittest
import math
Expand Down Expand Up @@ -1866,11 +1865,11 @@ def test_random_equalize(self):
)

def test_autoaugment(self):
for policy in AutoAugmentPolicy:
for policy in transforms.AutoAugmentPolicy:
for fill in [None, 85, (128, 128, 128)]:
random.seed(42)
img = Image.open(GRACE_HOPPER)
transform = AutoAugment(policy=policy, fill=fill)
transform = transforms.AutoAugment(policy=policy, fill=fill)
for _ in range(100):
img = transform(img)
transform.__repr__()
Expand Down
5 changes: 2 additions & 3 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from torchvision import transforms as T
from torchvision.transforms import functional as F
from torchvision.transforms import InterpolationMode
from torchvision.transforms.autoaugment import AutoAugment, AutoAugmentPolicy

import numpy as np

Expand Down Expand Up @@ -631,10 +630,10 @@ def test_autoaugment(self):
tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)
batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)

for policy in AutoAugmentPolicy:
for policy in T.AutoAugmentPolicy:
for fill in [None, 85, (10, -10, 10), 0.7, [0.0, 0.0, 0.0], [1, ], 1]:
for _ in range(100):
transform = AutoAugment(policy=policy, fill=fill)
transform = T.AutoAugment(policy=policy, fill=fill)
s_transform = torch.jit.script(transform)

self._test_transform_vs_scripted(transform, s_transform, tensor)
Expand Down
1 change: 1 addition & 0 deletions torchvision/transforms/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .transforms import *
from .autoaugment import *
29 changes: 12 additions & 17 deletions torchvision/transforms/autoaugment.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from torch.jit.annotations import List, Tuple
from typing import Optional

from . import functional as F, InterpolationMode
from . import functional as F


class AutoAugmentPolicy(Enum):
Expand All @@ -26,26 +26,21 @@ class AutoAugment(torch.nn.Module):
Args:
policy (AutoAugmentPolicy): Desired policy enum defined by
:class:`torchvision.transforms.autoaugment.AutoAugmentPolicy`. Default is ``AutoAugmentPolicy.IMAGENET``.
translate (tuple, optional): tuple of maximum absolute fraction for horizontal
and vertical translations. For example translate=(a, b), then horizontal shift
is randomly sampled in the range -img_width * a < dx < img_width * a and vertical shift is
randomly sampled in the range -img_height * b < dy < img_height * b. Will not translate by default.
scale (tuple, optional): scaling factor interval, e.g (a, b), then scale is
randomly sampled from the range a <= scale <= b. Will keep original scale by default.
shear (sequence or float or int, optional): Range of degrees to select from.
If shear is a number, a shear parallel to the x axis in the range (-shear, +shear)
will be applied. Else if shear is a tuple or list of 2 values a shear parallel to the x axis in the
range (shear[0], shear[1]) will be applied. Else if shear is a tuple or list of 4 values,
a x-axis shear in (shear[0], shear[1]) and y-axis shear in (shear[2], shear[3]) will be applied.
Will not apply shear by default.
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
fill (sequence or int or float, optional): Pixel fill value for the area outside the transformed
image. If int or float, the value is used for all bands respectively.
This option is supported for PIL image and Tensor inputs.
If input is PIL Image, the options is only available for ``Pillow>=5.0.0``.
Example:
>>> t = transforms.AutoAugment()
>>> transformed = t(image)
Example as a PyTorch Transform:
>>> transform=transforms.Compose([
>>> transforms.Resize(256),
>>> transforms.AutoAugment(),
>>> transforms.ToTensor()])
"""

def __init__(self, policy: AutoAugmentPolicy = AutoAugmentPolicy.IMAGENET, fill: Optional[List[float]] = None):
Expand Down

0 comments on commit 8ed500f

Please sign in to comment.