Skip to content

Commit

Permalink
Merge pull request #5 from pytorch/tests
Browse files Browse the repository at this point in the history
adding unit tests for image transforms
  • Loading branch information
soumith authored Nov 16, 2016
2 parents 44da562 + bd62df6 commit 650eb32
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 28 deletions.
50 changes: 42 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,14 +144,6 @@ The data is preprocessed [as described here](https://github.com/facebook/fb.resn
Transforms are common image transforms.
They can be chained together using `transforms.Compose`

- `ToTensor()` - converts PIL Image to Tensor
- `Normalize(mean, std)` - normalizes the image given mean, std (for example: mean = [0.3, 1.2, 2.1])
- `Scale(size, interpolation=Image.BILINEAR)` - Scales the smaller image edge to the given size. Interpolation modes are options from PIL
- `CenterCrop(size)` - center-crops the image to the given size
- `RandomCrop(size)` - Random crops the image to the given size.
- `RandomHorizontalFlip()` - hflip the image with probability 0.5
- `RandomSizedCrop(size, interpolation=Image.BILINEAR)` - Random crop with size 0.08-1 and aspect ratio 3/4 - 4/3 (Inception-style)

### `transforms.Compose`

One can compose several transforms together.
Expand All @@ -166,3 +158,45 @@ transform = transforms.Compose([
std = [ 0.229, 0.224, 0.225 ]),
])
```

## Transforms on PIL.Image

### `Scale(size, interpolation=Image.BILINEAR)`
Rescales the input PIL.Image to the given 'size'.
'size' will be the size of the smaller edge.

For example, if height > width, then image will be
rescaled to (size * height / width, size)
- size: size of the smaller edge
- interpolation: Default: PIL.Image.BILINEAR

### `CenterCrop(size)` - center-crops the image to the given size
Crops the given PIL.Image at the center to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)

### `RandomCrop(size)`
Crops the given PIL.Image at a random location to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)

### `RandomHorizontalFlip()`
Randomly horizontally flips the given PIL.Image with a probability of 0.5

### `RandomSizedCrop(size, interpolation=Image.BILINEAR)`
Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio

This is popularly used to train the Inception networks
- size: size of the smaller edge
- interpolation: Default: PIL.Image.BILINEAR

## Transforms on torch.*Tensor

### `Normalize(mean, std)`
Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the torch.*Tensor, i.e. channel = (channel - mean) / std

## Conversion Transforms
- `ToTensor()` - Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]
- `ToPILImage()` - Converts a torch.*Tensor of range [0, 1] and shape C x H x W or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C to a PIL.Image of range [0, 255]

87 changes: 87 additions & 0 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import numpy as np
import unittest
import random

class Tester(unittest.TestCase):
def test_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2

img = torch.ones(3, height, width)
oh1 = (height - oheight) / 2
ow1 = (width - owidth) / 2
imgnarrow = img[:, oh1 :oh1 + oheight, ow1 :ow1 + owidth]
imgnarrow.fill_(0)
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
assert result.sum() == 0, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
sum1 = result.sum()
assert sum1 > 1, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
oheight += 1
owidth += 1
result = transforms.Compose([
transforms.ToPILImage(),
transforms.CenterCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
sum2 = result.sum()
assert sum2 > 0, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
assert sum2 > sum1, "height: " + str(height) + " width: " \
+ str( width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)

def test_scale(self):
height = random.randint(24, 32) * 2
width = random.randint(24, 32) * 2
osize = random.randint(5, 12) * 2

img = torch.ones(3, height, width)
result = transforms.Compose([
transforms.ToPILImage(),
transforms.Scale(osize),
transforms.ToTensor(),
])(img)
# print img.size()
# print 'output size:', osize
# print result.size()
assert osize in result.size()
if height < width:
assert result.size(1) <= result.size(2)
elif width < height:
assert result.size(1) >= result.size(2)

def test_random_crop(self):
height = random.randint(10, 32) * 2
width = random.randint(10, 32) * 2
oheight = random.randint(5, (height - 2) / 2) * 2
owidth = random.randint(5, (width - 2) / 2) * 2
img = torch.ones(3, height, width)
result = transforms.Compose([
transforms.ToPILImage(),
transforms.RandomCrop((oheight, owidth)),
transforms.ToTensor(),
])(img)
assert result.size(1) == oheight
assert result.size(2) == owidth



if __name__ == '__main__':
unittest.main()
100 changes: 80 additions & 20 deletions torchvision/transforms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,19 @@
from __future__ import division
import torch
import math
import random
from PIL import Image
import numpy as np

import numbers

class Compose(object):
""" Composes several transforms together.
For example:
>>> transforms.Compose([
>>> transforms.CenterCrop(10),
>>> transforms.ToTensor(),
>>> ])
"""
def __init__(self, transforms):
self.transforms = transforms

Expand All @@ -16,6 +24,8 @@ def __call__(self, img):


class ToTensor(object):
""" Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
Expand All @@ -24,24 +34,50 @@ def __call__(self, pic):
# handle PIL Image
img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
img = img.view(pic.size[0], pic.size[1], 3)
# put it in CHW format
# put it from WHC to CHW format
# yikes, this transpose takes 80% of the loading time/CPU
img = img.transpose(0, 2).transpose(1, 2).contiguous()
return img.float()
img = img.transpose(0, 2).contiguous()
return img.float().div(255)

class ToPILImage(object):
""" Converts a torch.*Tensor of range [0, 1] and shape C x H x W
or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C
to a PIL.Image of range [0, 255]
"""
def __call__(self, pic):
if isinstance(pic, np.ndarray):
# handle numpy array
img = Image.fromarray(pic)
else:
npimg = pic.mul(255).byte().numpy()
npimg = np.transpose(npimg, (1,2,0))
img = Image.fromarray(npimg)
return img

class Normalize(object):
""" Given mean: (R, G, B) and std: (R, G, B),
will normalize each channel of the torch.*Tensor, i.e.
channel = (channel - mean) / std
"""
def __init__(self, mean, std):
self.mean = mean
self.std = std

def __call__(self, tensor):
# TODO: make efficient
for t, m, s in zip(tensor, self.mean, self.std):
t.sub_(m).div_(s)
return tensor


class Scale(object):
"Scales the smaller edge to size"
""" Rescales the input PIL.Image to the given 'size'.
'size' will be the size of the smaller edge.
For example, if height > width, then image will be
rescaled to (size * height / width, size)
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
Expand All @@ -51,52 +87,76 @@ def __call__(self, img):
if (w <= h and w == self.size) or (h <= w and h == self.size):
return img
if w < h:
return img.resize((w, int(round(h / w * self.size))), self.interpolation)
ow = self.size
oh = int(self.size * h / w)
return img.resize((ow, oh), self.interpolation)
else:
return img.resize((int(round(w / h * self.size)), h), self.interpolation)
oh = self.size
ow = int(self.size * w / h)
return img.resize((ow, oh), self.interpolation)


class CenterCrop(object):
"Crop to centered rectangle"
"""Crops the given PIL.Image at the center to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""
def __init__(self, size):
self.size = size
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size

def __call__(self, img):
w, h = img.size
x1 = int(round((w - self.size) / 2))
y1 = int(round((h - self.size) / 2))
return img.crop((x1, y1, x1 + self.size, y1 + self.size))
th, tw = self.size
x1 = int(round((w - tw) / 2))
y1 = int(round((h - th) / 2))
return img.crop((x1, y1, x1 + tw, y1 + th))


class RandomCrop(object):
"Random crop form larger image with optional zero padding"
"""Crops the given PIL.Image at a random location to have a region of
the given size. size can be a tuple (target_height, target_width)
or an integer, in which case the target will be of a square shape (size, size)
"""
def __init__(self, size, padding=0):
self.size = size
if isinstance(size, numbers.Number):
self.size = (int(size), int(size))
else:
self.size = size
self.padding = padding

def __call__(self, img):
if self.padding > 0:
raise NotImplementedError()

w, h = img.size
if w == self.size and h == self.size:
th, tw = self.size
if w == tw and h == th:
return img

x1 = random.randint(0, w - self.size)
y1 = random.randint(0, h - self.size)
return img.crop((x1, y1, x1 + self.size, y1 + self.size))
x1 = random.randint(0, w - tw)
y1 = random.randint(0, h - th)
return img.crop((x1, y1, x1 + tw, y1 + th))


class RandomHorizontalFlip(object):
"Horizontal flip with 0.5 probability"
"""Randomly horizontally flips the given PIL.Image with a probability of 0.5
"""
def __call__(self, img):
if random.random() < 0.5:
return img.transpose(Image.FLIP_LEFT_RIGHT)
return img


class RandomSizedCrop(object):
"Random crop with size 0.08-1 and aspect ratio 3/4 - 4/3 (Inception-style)"
"""Random crop the given PIL.Image to a random size of (0.08 to 1.0) of the original size
and and a random aspect ratio of 3/4 to 4/3 of the original aspect ratio
This is popularly used to train the Inception networks
size: size of the smaller edge
interpolation: Default: PIL.Image.BILINEAR
"""
def __init__(self, size, interpolation=Image.BILINEAR):
self.size = size
self.interpolation = interpolation
Expand Down

0 comments on commit 650eb32

Please sign in to comment.