From 6b175db54643af16acc1de2594f58962c856dedb Mon Sep 17 00:00:00 2001 From: Soumith Chintala Date: Sun, 27 Nov 2016 13:31:48 -0500 Subject: [PATCH] adding lambda transform --- README.md | 9 ++++++++- test/test_transforms.py | 13 +++++++++++-- torchvision/transforms.py | 10 ++++++++++ 3 files changed, 29 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 2688e108811..f1ae795c69a 100644 --- a/README.md +++ b/README.md @@ -194,7 +194,6 @@ This is popularly used to train the Inception networks - size: size of the smaller edge - interpolation: Default: PIL.Image.BILINEAR - ### `Pad(padding, fill=0)` Pads the given image on each side with `padding` number of pixels, and the padding pixels are filled with pixel value `fill`. @@ -209,6 +208,14 @@ Given mean: (R, G, B) and std: (R, G, B), will normalize each channel of the tor - `ToTensor()` - Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] - `ToPILImage()` - Converts a torch.*Tensor of range [0, 1] and shape C x H x W or numpy ndarray of dtype=uint8, range[0, 255] and shape H x W x C to a PIL.Image of range [0, 255] +## Generic Transofrms +### `Lambda(lambda)` +Given a Python lambda, applies it to the input `img` and returns it. +For example: + +```python +transforms.Lambda(lambda x: x.add(10)) +``` # Utils diff --git a/test/test_transforms.py b/test/test_transforms.py index 0c7473b2991..fd4869e20a3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -100,10 +100,19 @@ def test_pad(self): transforms.Pad(padding), transforms.ToTensor(), ])(img) - print(height, width, padding) - print(result.size(1), result.size(2)) assert result.size(1) == height + 2*padding assert result.size(2) == width + 2*padding + + def test_lambda(self): + trans = transforms.Lambda(lambda x: x.add(10)) + x = torch.randn(10) + y = trans(x) + assert(y.equal(torch.add(x, 10))) + + trans = transforms.Lambda(lambda x: x.add_(10)) + x = torch.randn(10) + y = trans(x) + assert(y.equal(x)) if __name__ == '__main__': diff --git a/torchvision/transforms.py b/torchvision/transforms.py index 1c8b51075c6..48be812569b 100644 --- a/torchvision/transforms.py +++ b/torchvision/transforms.py @@ -5,6 +5,7 @@ from PIL import Image, ImageOps import numpy as np import numbers +import types class Compose(object): """ Composes several transforms together. @@ -126,6 +127,15 @@ def __init__(self, padding, fill=0): def __call__(self, img): return ImageOps.expand(img, border=self.padding, fill=self.fill) +class Lambda(object): + """Applies a lambda as a transform""" + def __init__(self, lambd): + assert type(lambd) is types.LambdaType + self.lambd = lambd + + def __call__(self, img): + return self.lambd(img) + class RandomCrop(object): """Crops the given PIL.Image at a random location to have a region of