
Random transforms for both input and target? #9

Closed
fmassa opened this issue Nov 28, 2016 · 36 comments

Comments

@fmassa
Member

fmassa commented Nov 28, 2016

In some scenarios (like semantic segmentation), we might want to apply the same random transform to both the input and the GT labels (cropping, flip, rotation, etc).
I think we can emulate this behaviour in a segmentation dataset class by resetting the random seed before calling the transform for the labels.
This sounds a bit fragile, though.

One other possibility is to have the transforms accept both inputs and targets as arguments.

Do you have any better solutions?
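For reference, a minimal sketch of the seed-reset emulation mentioned above (the wrapper class is hypothetical and assumes every transform draws only from the global random module):

import random


class SeedSharingSegmentationDataset(object):
    """Hypothetical wrapper: re-seed the global RNG before each transform call
    so the image and the label mask see the same random parameters."""

    def __init__(self, dataset, transform):
        self.dataset = dataset
        self.transform = transform

    def __getitem__(self, index):
        img, label = self.dataset[index]
        seed = random.randint(0, 2**32 - 1)
        random.seed(seed)                # seed before transforming the image
        img = self.transform(img)
        random.seed(seed)                # same seed before transforming the label
        label = self.transform(label)
        return img, label

    def __len__(self):
        return len(self.dataset)

This only works when every random draw inside the transform goes through the global random module, which is exactly why it feels fragile.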

@soumith
Member

soumith commented Nov 29, 2016

hmm, good point.
We could make a stateful transform object that resets the random seed every 2 steps.
To do this you simply create a class whose __call__ operator keeps state.

For example:

trans = transforms.tworandom(transforms.compose([...]))
dataset = dset.SegmentationDset(...,transform=trans, target_transform=trans)
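A minimal sketch of what such a stateful tworandom wrapper could look like (the class name and the reliance on the global random module are assumptions):

import random


class TwoRandom(object):
    """Hypothetical stateful wrapper: every pair of consecutive calls
    (input, then target) replays the same global random seed."""

    def __init__(self, transform):
        self.transform = transform
        self.count = 0
        self.seed = None

    def __call__(self, img):
        if self.count % 2 == 0:
            # first call of the pair: draw a fresh seed
            self.seed = random.randint(0, 2**32 - 1)
        random.seed(self.seed)  # both calls of the pair see the same draws
        self.count += 1
        return self.transform(img)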

@colesbury
Member

colesbury commented Nov 29, 2016

Ewww, that's a fragile ugly hack. I also don't see how it will help, since the target isn't an image.

Can't you just extend the CocoDataset? Generate the random parameters and then apply the same logical op to both image and target:

class TransformedCocoDataset(CocoDataset):
  def __getitem__(self, i):
    input, target = super(TransformedCocoDataset, self).__getitem__(i)
    hflip = random.random() < 0.5
    if hflip:
      input = input.transpose(Image.FLIP_LEFT_RIGHT)  # flip the input (PIL image assumed)
      # flip the target rectangle accordingly
    # etc.
    return input, target

Perhaps we should provide similar operations that work on bounding boxes?

@colesbury
Member

Oh -- I misread the issue. I guess the target is an image for segmentation. It still seems to me like you want to generate the random parameters once and apply the operation twice. For things that aren't trivial (like RandomSizedCrop), we may want to refactor out the part that generates the random parameters to the image op. Probably don't have to do anything for trivial ops like horizontal-flip.
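A sketch of that refactoring, with the random parameter generation pulled out of the crop itself so it can be reused on both the image and an image-like target (the helper name and usage are illustrative):

import random


def sample_crop_params(img, output_size):
    """Draw crop coordinates once for a PIL image of size (w, h)."""
    w, h = img.size
    th, tw = output_size
    x1 = random.randint(0, w - tw)
    y1 = random.randint(0, h - th)
    return x1, y1, tw, th


# inside __getitem__ (hypothetical usage):
# x1, y1, tw, th = sample_crop_params(input, (256, 256))
# input = input.crop((x1, y1, x1 + tw, y1 + th))
# target = target.crop((x1, y1, x1 + tw, y1 + th))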

@fmassa
Member Author

fmassa commented Dec 25, 2016

I think the solution proposed by @colesbury of sub-classing the dataset is the most general one.
Put perhaps more cleanly, this solution is actually equivalent to using a transformdataset from tnt with a single callable instead of a dict of callables.

Also, the current way of passing transform and target_transform in every dataset is equivalent to using a transformdataset with dicts of transforms as input (and returning dicts as well instead of tuples).

As such, are you ok if we merge tnt datasets into core, and remove transform and target_transform arguments from vision datasets?

@ClementPinard

The way I see it, with @colesbury's code we will have the same problem when trying to compose different transform functions, because the random parameters are created within the __call__ function. We won't be able to customize the transform functions, and will have to create a sub-dataset per set of transform functions we want to try.

What about special transformations for both inputs and targets? This may create some duplicate functions, like a random crop for image-based targets (maybe add target as an optional argument?), but I don't see how else we would apply properly coherent transformations to input and target.

We could also give a seed as an argument in addition to img in __getitem__, but we are not guaranteed that transform and target_transform are constructed in a dual manner such that the random variables have coherent effects (like flipping both inputs and targets, or translating BB coordinates according to the input image transformation).

I feel like there should be 3 types of transform: transform_input, which deals with transformations that are independent of the target, like flip/crop for classification; transform_target, idem for the target; and lastly co_transform (sorry about the bad terminology), which deals with dependent transformations and must take input and target as arguments. I believe this last kind covers the vast majority of data augmentation for more geometric problems, such as detection, depth estimation, flow, etc.
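A minimal sketch of what such a co_transform interface could look like (names are illustrative, PIL images assumed):

import random

from PIL import Image


class ComposeCoTransforms(object):
    """Apply a list of co-transforms, each taking and returning (input, target)."""

    def __init__(self, co_transforms):
        self.co_transforms = co_transforms

    def __call__(self, input, target):
        for t in self.co_transforms:
            input, target = t(input, target)
        return input, target


class RandomHorizontalFlipPair(object):
    """Flip input and target together with probability 0.5."""

    def __call__(self, input, target):
        if random.random() < 0.5:
            input = input.transpose(Image.FLIP_LEFT_RIGHT)
            target = target.transpose(Image.FLIP_LEFT_RIGHT)
        return input, target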

@abeschneider

abeschneider commented Feb 6, 2017

I actually just ran into this problem myself. Another potential solution is to allow transforms to perform the same operation over a list of images. Thus, CocoDataset would look like:

if self.transform is not None:
    img, label = self.transform(img, label)

and the transform itself may look like:

    def __call__(self, *images):
        # perform some check to make sure images are all the same size
        if self.padding > 0:
            images = [ImageOps.expand(im, border=self.padding, fill=0) for im in images]

        w, h = images[0].size
        th, tw = self.size
        if w == tw and h == th:
            return images

        x1 = random.randint(0, w - tw)
        y1 = random.randint(0, h - th)
        return [img.crop((x1, y1, x1 + tw, y1 + th)) for img in images]

It may be possible to further abstract this out and create a container class for an image that automatically applies the same operations across the collection. That way the operations could be agnostic to what they are actually operating on.
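A small sketch of that container idea, applying one PIL operation to every image in the group (class and method names are illustrative):

class ImageGroup(object):
    """Hold several aligned images and apply the same operation to all of them."""

    def __init__(self, *images):
        self.images = list(images)

    def apply(self, op):
        # op is any callable mapping a PIL image to a PIL image
        return ImageGroup(*[op(img) for img in self.images])

    @property
    def size(self):
        return self.images[0].size


# usage (hypothetical): group = ImageGroup(img, label)
# group = group.apply(lambda im: im.crop((x1, y1, x1 + tw, y1 + th)))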

@fmassa
Member Author

fmassa commented Feb 26, 2017

I gave it a shot at implementing a generic random transform that can be applied to several inputs (even, for example, images + bounding boxes). An example implementation can be found here.
Here are the main ideas:

  • Allow passing an arbitrarily nested tuple of inputs (images, labels, bboxes, masks, etc.); the Compose transform can accept the tuple instead of individual elements.
  • The sequence of transforms is a list as before, but if an entry is itself a list, each of its elements is applied to the corresponding element of the input tuple; if it's not, it's applied to all the elements at the same time.
  • Every random transform consists of 2 classes: a random generator, which draws the random parameters during the call, and the transform itself, which accepts the generator object and uses the parameters it generated. The same transformation can thus be applied to different inputs by reusing the same implementation, with no need to implement, say, a RandomCrop2Images (see the sketch at the end of this comment).
  • It's backward compatible with previous transforms.

It's a bit rough (and maybe a bit too complicated), but it could handle the cases we have mentioned in this thread.
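To make the generator/transform split concrete, a rough sketch (the real implementation lives in the linked branch; the class names follow the usage example further down the thread, and the internals here are guesses):

import random

from PIL import Image


class RandomFlipGenerator(object):
    """Draws the flip decision once per sample and keeps it as state."""

    def __call__(self, img):
        self.flip = random.random() < 0.5
        return img  # identity on the data; only the stored state changes


class RandomFlip(object):
    """Applies whatever the shared generator drew, so input and target match."""

    def __init__(self, generator):
        self.generator = generator

    def __call__(self, img):
        if self.generator.flip:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        return img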

@bodokaiser
Contributor

@fmassa does your proposal also consider transforms which depend on input and target image?

This would be helpful when, for example, you want to crop your input image to the size of the target image.

@fmassa
Member Author

fmassa commented Mar 2, 2017

@bodokaiser Yes, this should be handled as well.
All you need to do is create a transform that takes both input and target, performs the operation, and returns input and target. Here is an example:

class MyJointOp(object):
    def __call__(self, input, target):
        # perform something on input and target
        return input, target

and then you would use it as follows (using my transforms), supposing that your dataset outputs input and target:

flip_gen = mytransforms.RandomFlipGenerator()
mytransforms.Compose([
    MyJointOp(), # take all inputs/targets
    [transforms.ColorAugmentation(), None], # Color augmentation in input, no operation in target
    flip_gen, # get a random seed for the flip transform and returns the identity
    [transforms.RandomFlip(flip_gen), transforms.RandomFlip(flip_gen)], # apply the same flip to input and target
    ...
])

@bodokaiser
Contributor

@fmassa Do you plan to merge your implementation into torchvision?

@fmassa
Member Author

fmassa commented Mar 10, 2017

@bodokaiser for the moment I'll only merge the part that separates the random parameters from the transforms, so that you can apply the same random transform to different inputs / targets.
We decided that my proposal brings too much complexity, and that it's up to the user to subclass their dataset and add the random transforms they want.
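That separation eventually surfaced as the get_params static methods on the random transforms; with a recent torchvision the usual pattern looks roughly like this (a sketch, not the code merged at the time):

import random

import torchvision.transforms as T
import torchvision.transforms.functional as TF


def joint_transform(image, mask):
    # draw the crop parameters once, then apply them to both image and mask
    i, j, h, w = T.RandomCrop.get_params(image, output_size=(256, 256))
    image = TF.crop(image, i, j, h, w)
    mask = TF.crop(mask, i, j, h, w)

    # draw the flip decision once as well
    if random.random() < 0.5:
        image = TF.hflip(image)
        mask = TF.hflip(mask)
    return image, mask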

@bodokaiser
Contributor

@fmassa I understand the complexity argument; moreover, supporting the same random parameters for both transforms is already a huge plus.

I think for my problems I will end up just moving some preprocessing to map(preprocess, loader), so that I can filter out bad samples from my dataset and handle other transforms which require both input and target data.

@oeway

oeway commented Apr 3, 2017

Hi guys, for the same issue I proposed a similar solution in Keras, sharing the seed across the input and target transform functions: keras-team/keras#3338 . As mentioned above, that's too complicated to handle robustly; my implementation is buggy and not thread-safe (the input and target can become mis-synchronised).

When moving to torch, I found that an easy and robust way to implement this kind of transformation is by merging and splitting image channels.

The idea is to use one transform pipeline for both input and target images: with Merge() and Split() we concatenate the input and target along the channel dimension, apply the transforms, and then split them back.

Here is my implementation:

transform = EnhancedCompose([
    Merge(),              # merge input and target along the channel axis
    ElasticTransform(),
    RandomRotate(),
    Split([0,1],[1,2]),  # split into 2 images
    [CenterCropNumpy(size=input_shape), CenterCropNumpy(size=target_shape)],
    [NormalizeNumpy(), None],
    [Lambda(to_tensor), Lambda(to_tensor)]
])
import collections

import numpy as np


class EnhancedCompose(object):
    """Composes several transforms together, with support for separate transformations per input.
    """
    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        for t in self.transforms:
            if isinstance(t, collections.Sequence):
                assert isinstance(img, collections.Sequence) and len(img) == len(t), "size of image group and transform group does not fit"
                tmp_ = []
                for i, im_ in enumerate(img):
                    if callable(t[i]):
                        tmp_.append(t[i](im_))
                    else:
                        tmp_.append(im_)
                img = tmp_
            elif callable(t):
                img = t(img)
            elif t is None:
                continue
            else:
                raise Exception('unexpected type')                
        return img

class Merge(object):
    """Merge a group of images
    """
    def __init__(self, axis=-1):
        self.axis = axis
    def __call__(self, images):
        if isinstance(images, collections.Sequence) or isinstance(images, np.ndarray):
            assert all([isinstance(i, np.ndarray) for i in images]), 'only numpy array is supported'
            shapes = [list(i.shape) for i in images]
            for s in shapes:
                s[self.axis] = None
            assert all([s==shapes[0] for s in shapes]), 'shapes must be the same except the merge axis'
            return np.concatenate(images, axis=self.axis)
        else:
            raise Exception("obj is not a sequence (list, tuple, etc)")

class Split(object):
    """Split images into individual images
    """
    def __init__(self, *slices, **kwargs):
        assert isinstance(slices, collections.Sequence)
        slices_ = []
        for s in slices:
            if isinstance(s, collections.Sequence):
                slices_.append(slice(*s))
            else:
                slices_.append(s)
        assert all([isinstance(s, slice) for s in slices_]), 'slices must consist of slice instances'
        self.slices = slices_
        self.axis = kwargs.get('axis', -1)

    def __call__(self, image):
        if isinstance(image, np.ndarray):
            ret = []
            for s in self.slices:
                sl = [slice(None)] * image.ndim
                sl[self.axis] = s
                ret.append(image[tuple(sl)])  # index with a tuple; list indexing is deprecated in numpy
            return ret
        else:
            raise Exception("obj is not an numpy array")

Also note that, by doing this, I would propose to implement all the transformations in torchvision with numpy and scipy.ndimage, which are more powerful than PIL, and which would remove PIL's limitations on channel count and image mode.

And this implementation can also support more than two images per sample, meaning we can have multiple inputs and outputs, for example a sample weight map which needs to be transformed at the same time.

I have been using my implementation for a while and it has helped me a lot.
Let me know what you think. If enough people are interested, I could try to do a PR.

@lwye

lwye commented Apr 11, 2017

@oeway This idea is great for my current project, which requires a variable number of targets per image. The transform functions seem to be different from those in torchvision. Could you provide the whole implementation for this project? I would like to give it a try.
Thanks

@oeway

oeway commented Apr 12, 2017

@lwye Thanks for your interest. The interface of the transform functions is the same as torchvision's (__init__ and __call__), but you are right that they differ in that I handle numpy arrays. I will try to clean up my code and put it on GitHub, hopefully soon.

@lwye

lwye commented Apr 12, 2017

@oeway Great. Operating on numpy will be more flexible in my case. Looking forward to it.

@catalystfrank

So does anyone have any ideas about how to perform this transform.ColorAugmentation()? Any link to an unmerged PR / fork is OK. Thanks in advance.

@fmassa
Member Author

fmassa commented Apr 13, 2017

@catalystfrank you can find color augmentation transforms here.

@bermanmaxim

@oeway one issue I see with that approach concerns the image-resizing random transforms: while the input image typically uses bilinear interpolation, the discretely-labelled target needs nearest-neighbour assignment.
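With separate per-modality transforms this is straightforward to express; for example, with a recent torchvision functional API (a sketch unrelated to the Merge/Split code above):

import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode


def joint_resize(image, mask, size=(256, 256)):
    # identical geometry, but different resampling per modality:
    # bilinear for the continuous image, nearest for the discrete label map
    image = TF.resize(image, size, interpolation=InterpolationMode.BILINEAR)
    mask = TF.resize(mask, size, interpolation=InterpolationMode.NEAREST)
    return image, mask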

@oeway

oeway commented Apr 18, 2017 via email

@oeway

oeway commented Apr 26, 2017

@lwye and others interested in my solution: here is a standalone image transform module for dense prediction tasks that I used for my project; at the end you can find some code showing how to use the module:
https://gist.github.com/oeway/2e3b989e0343f0884388ed7ed82eb3b0

It's compatible with torch's transform function interface, but there is no dependency on torch functions, so you can use it with other DL libraries as well.

You may find bugs or have ideas for improvement; in that case, please comment in the gist or here.

@blackyang

Any progress on this thread?

@oeway's method is great for e.g. segmentation tasks, but for detection, where targets are not images, I think @fmassa's proposal about using lists of transforms is more general. Will that (or something similar) be merged into core? Or, as mentioned, is it up to users to subclass their datasets?

@fmassa
Member Author

fmassa commented May 8, 2017

@blackyang I need to write some tests for #115 to verify that it works properly even in multi-threaded settings, but I've been lacking time lately.

@kulikovv

kulikovv commented May 26, 2017

EDIT: modified to work with torchvision 0.7

I've solved this issue this way in my cityscape dataset wrapper:

def __getitem__(self,index):      
        img = Image.open(self.data[index]).convert('RGB')
        target = Image.open(self.data_labels[index])
        
        seed = np.random.randint(2147483647) # make a seed with numpy generator 
        random.seed(seed) # apply this seed to img transforms
        torch.manual_seed(seed) # needed for torchvision 0.7
        if self.transform is not None:
            img = self.transform(img)
            
        random.seed(seed) # apply this seed to target transforms
        torch.manual_seed(seed) # needed for torchvision 0.7
        if self.target_transform is not None:
            target = self.target_transform(target)

        target = torch.ByteTensor(np.array(target))
    
        return img, target

@bermanmaxim

In my segmentation application I have made some transformation functions that either accept an (image, label) pair or a single image, and that also work with (transformation_image, transformation_label) pairs, allowing me to use

transforms_train = JointCompose([RandomScale(0.5, 1.5) if args.random_scale else None, 
                                     RandomHorizontalFlip() if args.random_mirror else None,
                                     RandomCropPad(input_size, (0, IGNORE_LABEL)),
                                     [None, Scale((output_size[1], output_size[0]), NEAREST)],
                                     PILtoTensor(), 
                                     [Normalize(torch.from_numpy(IMG_MEAN)), None],
                                    ])

which expects an (image, label) pair as input. I wrote the code without trying to be generic, so it needs some work, but maybe it's a direction to pursue, rather than concatenating the image and label along a channel.

@perone

perone commented Jul 2, 2017

Another important feature that would be nice to have is a parameter on the constructors, or on the entire transformation pipeline, that accepts a RandomState instead of setting global seeds or using global random generators. Relying on the global seed is a very common design mistake, also seen in Keras and many other frameworks, where they keep setting the global random seed everywhere.
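A minimal sketch of a transform that takes an injected RandomState instead of touching global state (the class name is illustrative, PIL images assumed):

import numpy as np
from PIL import Image


class RandomHorizontalFlipWithRNG(object):
    """Draws from an injected numpy RandomState rather than the global RNG."""

    def __init__(self, p=0.5, random_state=None):
        self.p = p
        self.rng = random_state if random_state is not None else np.random.RandomState()

    def __call__(self, img):
        if self.rng.rand() < self.p:
            img = img.transpose(Image.FLIP_LEFT_RIGHT)
        return img


# one explicit RandomState per pipeline, no global seeding:
rng = np.random.RandomState(1234)
flip = RandomHorizontalFlipWithRNG(random_state=rng)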

@fmassa
Member Author

fmassa commented Sep 3, 2017

This issue is being addressed in #230, and a first PR was already sent in #240.
Let's continue the discussion in #230.

@lpuglia

lpuglia commented Apr 20, 2018

For whoever is still interested: I wrote a little hack that augments two images with the same transformation:

cj = transforms.ColorJitter(brightness=1.0, contrast=1.0, saturation=1.0, hue=0.5)  # hue is capped at 0.5
seed = np.random.randint(0,2**32)
np.random.seed(seed)
pl = cj(x0)
np.random.seed(seed)
pr = cj(x1)

Setting the same np.random seed gives you the same uniformly sampled values before the two color jitters are applied.

EDIT:

Apparently newer versions of torchvision have started to use the random package instead of numpy.random; solution:

cj = transforms.ColorJitter(brightness=1.0, contrast=1.0, saturation=1.0, hue=0.5)  # hue is capped at 0.5
seed = random.randint(0,2**32)
random.seed(seed)
pl = cj(x0)
random.seed(seed)
pr = cj(x1)

@SDNAFIO

SDNAFIO commented Jul 14, 2019

@lpuglia It should just be noted that this will give you the same random seed in each worker
if you use num_workers > 1 in your DataLoader.

See also here:
https://discuss.pytorch.org/t/does-getitem-of-dataloader-reset-random-seed/8097/19
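A common workaround (a sketch based on the recipe from that thread) is to re-seed each worker from torch's per-worker base seed via worker_init_fn:

import random

import numpy as np
import torch
from torch.utils.data import DataLoader


def seed_worker(worker_id):
    # torch.initial_seed() differs per worker (base_seed + worker_id),
    # so numpy and random get distinct, reproducible seeds as well
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# loader = DataLoader(dataset, num_workers=4, worker_init_fn=seed_worker)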

@weihaosky

I've solved this issue this way in my cityscape dataset wrapper:

def __getitem__(self,index):      
        img = Image.open(self.data[index]).convert('RGB')
        target = Image.open(self.data_labels[index])
        
        seed = np.random.randint(2147483647) # make a seed with numpy generator 
        random.seed(seed) # apply this seed to img transforms
        if self.transform is not None:
            img = self.transform(img)
            
        random.seed(seed) # apply this seed to target transforms
        if self.target_transform is not None:
            target = self.target_transform(target)

        target = torch.ByteTensor(np.array(target))
    
        return img, target

It does not work for me. The ColorJitter still performs differently for multiple images. Why?

@weihaosky

weihaosky commented Sep 4, 2019

Figured it out. Another random seed also needs to be set.

random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)

@nik-hil

nik-hil commented Sep 4, 2019

I find this helpful in image segmentation,

https://github.com/albu/albumentations/blob/master/notebooks/example_kaggle_salt.ipynb
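For reference, the albumentations pattern from that notebook applies the same spatial parameters to image and mask by passing both to a single call, roughly like this (placeholder arrays):

import albumentations as A
import numpy as np

# placeholder data: HxWxC image and HxW label map as numpy arrays
image = np.random.randint(0, 256, (512, 512, 3), dtype=np.uint8)
mask = np.random.randint(0, 2, (512, 512), dtype=np.uint8)

transform = A.Compose([
    A.HorizontalFlip(p=0.5),
    A.RandomCrop(height=256, width=256),
])

augmented = transform(image=image, mask=mask)  # same spatial params for both
image_aug, mask_aug = augmented["image"], augmented["mask"]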

@SebastienEske

SebastienEske commented Sep 20, 2019

Hello, I really think that if each transform had its own RNG object, we could provide it with a seed and that would solve the problem. Something along the lines of:

class MyRandomTransform(object):
    def __init__(self, seed=None):
        self.rng = CMWC()  # or any other per-instance RNG object
        if seed is not None:
            self.rng.seed(seed)

    def run(self):
        return self.rng.random()

a = MyRandomTransform(2)
b = MyRandomTransform(2)

Here is a more complete code example
And this is a fully backward compatible change since we only add one optional parameter when creating the transform.

@ginobilinie

(quoting @fmassa's proposal from Feb 26, 2017 above)
Yes, the main issue is that if you use transforms.Compose, the __call__ function of your transform operation can only take two positional arguments: self and a single input, which can however be a tuple, for example (img, gt). But it can return more than one value.

@LucaMarconato

Another possibility, without having to change the seed, is to save and load the RNG state. In this example I transform both the input and the target by applying the same instance of a random transformation. For me, the only state that made a difference was the torch one, not the random or numpy ones.

t = transforms.RandomRotation(degrees=360)
state = torch.get_rng_state()
x = t(x)
torch.set_rng_state(state)
y = t(y)

@MarioGalindoQ

I think I have a simple solution:
If the images are concatenated, the transformations are applied to all of them identically:

import torch
import torchvision.transforms as T

# Create two fake images (identical for test purposes):
image = torch.randn((3, 128, 128))
target = image.clone()

# This is the trick (concatenate the images):
both_images = torch.cat((image.unsqueeze(0), target.unsqueeze(0)),0)

# Apply the transformations to both images simultaneously:
transformed_images = T.RandomRotation(180)(both_images)

# Get the transformed images:
image_trans = transformed_images[0]
target_trans = transformed_images[1]

# Compare the transformed images:
torch.all(image_trans == target_trans).item()

>> True
