
new transform functions #71

Closed
ncullen93 opened this issue Feb 26, 2017 · 10 comments

@ncullen93

just curious.. is the plan to keep transforms such as RandomCrop, HorizontalFlip, etc only working on PIL images, or would you prefer *Tensor support as well?

Also, are you all open to adding new transforms such as rotations, shearing, shifting, zooming, etc. -- for instance transforms.RandomRotation(-30,30), transforms.HorizontalShift(0.2), or transforms.RandomZoom(0.8,1.2) -- or are these already supported elsewhere?

I would be willing to contribute in these areas. It is particularly important to combine these transforms together so that only one interpolation is required, and I have experience with that.

@fmassa
Member

fmassa commented Feb 26, 2017

I think that some of these transforms are going to be very handy indeed!
I think that at the moment we are trying to figure out a better way of handling random transforms applied to several images at once (as discussed in #9), and once that is sorted out we will be adding more transforms; any help is welcome!
One note though: I wonder if RandomZoom isn't redundant with RandomCrop + Scale. Is that the case?

@ncullen93
Author

ncullen93 commented Feb 26, 2017

Cool. Yes, you're right about the zoom; it was just an example :). I use transforms in my own code that apply to both input and target images (segmentation), so I definitely understand that issue. This would require changes to the sampling code and/or the Compose function. The way I do it is to first build up the affine matrix as a combination of sub-transforms, then take the target image as an optional argument and apply the transform to it as needed.

e.g.

def random_transform(x, y=None):
    # build up sub-transforms
    if random_shift:
        random_shift_matrix = ..affine transform matrix..
    if random_rotation:
        ...
    # combine sub-transforms into a single transform for one interpolation
    final_transform = np.eye(3)  # 3x3 identity, so composition starts neutral
    for sub_transform in sub_transforms:
        final_transform = np.dot(final_transform, sub_transform)

    # apply transform to x
    x = apply_transform(x, final_transform)

    # apply to y if necessary
    if y is not None:
        y = apply_transform(y, final_transform)

    return x, y

You could also build up this transform matrix, then store it if you still want to keep Compose working on a single image.
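To make the matrix-composition step above concrete, here is a minimal runnable sketch. The rotation_matrix and shift_matrix helpers are hypothetical names I'm using for illustration, not part of any library:

```python
import numpy as np

def rotation_matrix(theta):
    # 3x3 homogeneous rotation matrix for a 2D image (hypothetical helper)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.],
                     [s,  c, 0.],
                     [0., 0., 1.]])

def shift_matrix(tx, ty):
    # 3x3 homogeneous translation matrix (hypothetical helper)
    return np.array([[1., 0., tx],
                     [0., 1., ty],
                     [0., 0., 1.]])

# Start from the identity and fold each sub-transform in, so the image
# only needs to be resampled once with the combined matrix.
final_transform = np.eye(3)
for sub_transform in [rotation_matrix(np.pi / 6), shift_matrix(5., -3.)]:
    final_transform = np.dot(final_transform, sub_transform)
```

The point is that matrix multiplication composes the sub-transforms, so the single interpolation (apply_transform above) runs once no matter how many augmentations were drawn.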

I attached a full function taken from my code (largely adapted from keras/preprocessing/image.py) which shows basically how this works.

It would be easy to pick these transforms out and combine them in the Compose function, perhaps by adding self.requires_interpolation = True to the relevant transform classes.

transform.zip

@apaszke
Contributor

apaszke commented Feb 27, 2017

Yes, we definitely would like to add more transforms! Thanks!

Could you please upload the function as a gist?

@ncullen93
Author

Oh, good call. The main takeaway from this function is just to present one possible way to combine multiple transforms.

https://gist.github.com/ncullen93/66c1803f9a3dccd1d63b041c90ecf784

@mattmacy

At least in my own private repo, I need to add the data augmentation pieces used by the 3-D U-Net work: "Besides rotation, scaling and gray value augmentation, we apply a smooth dense deformation field on both data and ground truth labels. For this, we sample random vectors from a normal distribution with standard deviation of 4 in a grid with a spacing of 32 voxels in each direction and then apply a B-spline interpolation."
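For reference, the quoted augmentation can be sketched with scipy: sample coarse normal displacements (std 4, one control point every 32 voxels), upsample with cubic B-spline interpolation, and warp data and labels with the same field. This is a rough sketch under those assumptions, not the authors' implementation; the function name and defaults are mine:

```python
import numpy as np
from scipy.ndimage import map_coordinates, zoom

def smooth_deformation(volume, labels=None, sigma=4.0, spacing=32, seed=None):
    """Warp `volume` (and optionally `labels`) with a smooth random
    deformation: coarse normal displacements (std `sigma`) on a grid with
    `spacing` voxels per direction, upsampled by cubic B-spline interpolation."""
    rng = np.random.default_rng(seed)
    shape = volume.shape
    coarse = [s // spacing + 2 for s in shape]
    displacement = []
    for _ in range(volume.ndim):
        field = rng.normal(0.0, sigma, size=coarse)
        # order=3 gives cubic B-spline upsampling of the coarse field
        field = zoom(field, [s / c for s, c in zip(shape, coarse)], order=3)
        displacement.append(field[tuple(slice(0, s) for s in shape)])
    coords = np.meshgrid(*[np.arange(s) for s in shape], indexing="ij")
    warped = [c + d for c, d in zip(coords, displacement)]
    out = map_coordinates(volume, warped, order=1)
    if labels is not None:
        # nearest-neighbour for labels so class values stay discrete
        return out, map_coordinates(labels, warped, order=0)
    return out
```

Using the same `warped` coordinates for both arrays is what keeps data and ground truth aligned.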

They have a patch for caffe, but I decided I'd rather go through the pain of reimplementing it and have the relative sanity of the pytorch API than use caffe. My question is: is augmentation that is (mostly) specific to medical imaging too exotic for general consumption?

http://lmb.informatik.uni-freiburg.de/resources/opensource/unet.en.html

@ncullen93
Author

ncullen93 commented Feb 28, 2017

Hey, I work with structural brain images as well! The U-Net sampling is very similar to what my lab uses. But to answer your question: yes, I think something like a "smooth dense deformation field" is quite exotic, but there is definitely a growing need for good sampling/transforms for 3D images (especially taking 3D sub-volumes or 2D slices). If you're asking whether that type of thing will ever be included in pytorch, I can't answer that, but I hope so! Good, comprehensive sampling is a second-class citizen in most of the big DL packages.

It is very straightforward to add your own transforms and dataset pre-processing steps in pytorch, so you should go for it and make at least that part publicly available! People will find it useful and may contribute!

To make a transform, just create a callable class:

class SmoothDeformation(object):

    def __init__(self, params):
        self.params = params

    def __call__(self, input):
        # ... apply smooth deformation ...
        return input

You can string multiple transforms together in the Compose class.
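For instance, here is a complete toy transform (a stand-in, since the actual deformation code is the hard part) together with a minimal Compose in the same style as torchvision's:

```python
import numpy as np

class AddGaussianNoise(object):
    """Illustrative callable transform (a stand-in for SmoothDeformation)."""
    def __init__(self, std):
        self.std = std

    def __call__(self, input):
        return input + np.random.normal(0.0, self.std, size=input.shape)

class Compose(object):
    """Mirrors torchvision's Compose: apply each transform in sequence."""
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, input):
        for t in self.transforms:
            input = t(input)
        return input

# Chain two transforms into one callable pipeline.
pipeline = Compose([AddGaussianNoise(0.1), AddGaussianNoise(0.05)])
```

Because every transform is just a callable taking one image, Compose needs no knowledge of what each step does.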

Unfortunately, there is no supported way in pytorch to perform a transform on both the input and target images at the same time, but hopefully that will be supported soon.

@mattmacy

@ncullen93 Although the caffe syntax itself (see below) is horrific, I think the notion of transforms being just another layer at train time is very appealing: it maps cleanly to running the transformations on the GPU concurrently with training, and it makes it logically trivial to apply the same transformation to both input and labels. Since the current setup doesn't support either and hasn't planned for it, I'm inclined to make it an add-on module. I'll redo it "the pytorch way" down the line once it can satisfy those two requirements.

http://lmb.informatik.uni-freiburg.de/resources/opensource/3dUnet_miccai2016_no_BN.prototxt

@ncullen93
Author

ncullen93 commented Feb 28, 2017

wow that is some spaghetti code..

Good point! You can definitely make transforms another layer by subclassing the nn.Module class or the Function class and implementing the transform in the forward() function (see the part about creating extensions in the tutorials: https://github.com/pytorch/tutorials). Note that these functions don't have to be differentiable or have gradients, so it's completely possible to do what you're saying.
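A minimal sketch of that idea, using a simple random horizontal flip as the augmentation (the class name is mine, purely illustrative):

```python
import torch
import torch.nn as nn

class RandomFlip(nn.Module):
    """Train-time augmentation expressed as a layer: no parameters,
    no gradients needed, and it runs wherever its input tensor lives
    (CPU or GPU), so it can overlap with training on the device."""
    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    def forward(self, x):
        # Only augment in training mode; identity at eval time.
        if self.training and torch.rand(1).item() < self.p:
            return torch.flip(x, dims=[-1])  # flip the width axis
        return x
```

Keeping the random draw inside forward() means the same module instance can be dropped in front of any network, and model.eval() disables it automatically.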

Although transforms at the sampler level do get the benefit of multi-processing and queueing, so you can run them in parallel with model training, which is probably more efficient. It is definitely trivial to apply the same transforms to input/labels with sampling; it's just not implemented right now.

@mattmacy

mattmacy commented Feb 28, 2017

Thanks for the pointers, I will definitely take a stab at that. I really appreciate the prompt (and enthusiastic) response. At least for a first pass I'm more comfortable just implementing it as a simple forward function.

You may be right, but how would you apply the transforms to the labels too? It needs to be stateful to apply the same transformation to the labeled data, and it doesn't seem like the API is currently set up for that. Also, in terms of using SMP: my CPU is just a 6-core Broadwell (I'm told an Intel core peaks at ~75 GFLOPS) and my GPU is a Pascal Titan X, which has a peak of 10.8 TFLOPS and an achievable 7.4 TFLOPS. IMO, if one has a reasonable GPU, the CPU's job is just to keep the GPU fed with data.

@ncullen93 ncullen93 reopened this Feb 28, 2017
@ncullen93
Author

ncullen93 commented Feb 28, 2017

Sorry, I accidentally closed the issue. I can think of one way: adding a co_transform argument to a Dataset subclass. You can pretty much do whatever you want to the input/target in the transform's __call__ method.

Adapted from the current ImageFolder class:

class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, target_transform=None,
                co_transform=None, loader=default_loader):
        classes, class_to_idx = find_classes(root)
        imgs = make_dataset(root, class_to_idx)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in subfolders of: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.classes = classes
        self.class_to_idx = class_to_idx
        self.transform = transform
        self.target_transform = target_transform
        self.co_transform = co_transform # ADDED THIS
        self.loader = loader

    def __getitem__(self, index):
        path, target = self.imgs[index]
        img = self.loader(os.path.join(self.root, path))
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        if self.co_transform is not None:
            img, target = self.co_transform(img, target) # ADDED THIS

        return img, target

    def __len__(self):
        return len(self.imgs)

Easy enough :). Now just create a transform that takes in both input and target..

class MyCoTransform(object):
    def __init__(self):
        pass
    def __call__(self, input, target):
        # do something to both images
        return input, target
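A concrete example of such a co-transform: a random horizontal flip where the coin is tossed once and the result applied to both arrays, so image and mask stay aligned (assuming numpy inputs; the class name is illustrative):

```python
import random
import numpy as np

class RandomFlipCoTransform(object):
    """Draw the random decision once, then apply it to both arrays,
    keeping the image and its segmentation mask spatially aligned."""
    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, input, target):
        if random.random() < self.p:
            # .copy() because np.flip returns a read-only-strided view
            input = np.flip(input, axis=-1).copy()
            target = np.flip(target, axis=-1).copy()
        return input, target
```

If the decision were drawn inside two separate transforms, the image and mask could flip independently, which is exactly the bug co_transform avoids.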

There's a good post about this on the pytorch forum if you search for it.
