diff --git a/torchvision/datasets/fakedata.py b/torchvision/datasets/fakedata.py index ccd99e03ed0..f079c1a92db 100644 --- a/torchvision/datasets/fakedata.py +++ b/torchvision/datasets/fakedata.py @@ -21,14 +21,11 @@ class FakeData(VisionDataset): def __init__(self, size=1000, image_size=(3, 224, 224), num_classes=10, transform=None, target_transform=None, random_offset=0): - super(FakeData, self).__init__(None) - self.transform = transform - self.target_transform = target_transform + super(FakeData, self).__init__(None, transform=transform, + target_transform=target_transform) self.size = size self.num_classes = num_classes self.image_size = image_size - self.transform = transform - self.target_transform = target_transform self.random_offset = random_offset def __getitem__(self, index):