chore(components/pytorch): Samples - Cifar10 webdataset fix (#6516)
Signed-off-by: Shrinath Suresh <[email protected]>
shrinath-suresh authored Sep 8, 2021
1 parent 5c2b529 commit 9ed77f5
Showing 1 changed file with 32 additions and 45 deletions.
77 changes: 32 additions & 45 deletions samples/contrib/pytorch-samples/cifar10/cifar10_datamodule.py
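For orientation before the diff: the fix moves the sample's CIFAR-10 data module to the current webdataset API, replacing wds.Dataset(url, handler=..., length=...) with wds.WebDataset(url, handler=...), which no longer takes a length argument, and reflows the surrounding formatting. Below is a condensed sketch of the resulting train pipeline; the shard pattern and transform are made-up stand-ins (the real values come from the data module's configuration), only the pipeline shape mirrors the commit.

import webdataset as wds
from torchvision import transforms

# Hypothetical shard pattern and transform, stand-ins for the values the
# data module builds in setup().
train_url = "cifar10/train/train-{0..39}.tar"
train_transform = transforms.Compose([transforms.ToTensor()])

# Old API (removed in this commit):
#   wds.Dataset(train_url, handler=wds.warn_and_continue, length=40000 // 40)
# New API (added in this commit):
train_dataset = (
    wds.WebDataset(train_url, handler=wds.warn_and_continue)
    .shuffle(100)
    .decode("pil")
    .rename(image="ppm;jpg;jpeg;png", info="cls")
    .map_dict(image=train_transform)
    .to_tuple("image", "info")
    .batched(40)
)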
@@ -22,7 +22,6 @@

class CIFAR10DataModule(pl.LightningDataModule): # pylint: disable=too-many-instance-attributes
    """Data module class."""

    def __init__(self, **kwargs):
        """Initialization of inherited lightning data module."""
        super(CIFAR10DataModule, self).__init__() # pylint: disable=super-with-arguments
@@ -33,9 +32,8 @@ def __init__(self, **kwargs):
        self.train_data_loader = None
        self.val_data_loader = None
        self.test_data_loader = None
        self.normalize = transforms.Normalize(
            mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]
        )
        self.normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
                                              std=[0.5, 0.5, 0.5])
        self.valid_transform = transforms.Compose([
            transforms.ToTensor(),
            self.normalize,
@@ -79,50 +77,39 @@ def setup(self, stage=None):
        val_count = self.get_num_files(val_base_url)
        test_count = self.get_num_files(test_base_url)

        train_url = "{}/{}-{}".format(
            train_base_url, "train", "{0.." + str(train_count) + "}.tar"
        )
        valid_url = "{}/{}-{}".format(
            val_base_url, "val", "{0.." + str(val_count) + "}.tar"
        )
        test_url = "{}/{}-{}".format(
            test_base_url, "test", "{0.." + str(test_count) + "}.tar"
        )

        self.train_dataset = (
            wds.Dataset(
                train_url, handler=wds.warn_and_continue, length=40000 // 40
            ).shuffle(100).decode("pil").rename(
                image="ppm;jpg;jpeg;png", info="cls"
            ).map_dict(image=self.train_transform).to_tuple("image",
                                                            "info").batched(40)
        )

        self.valid_dataset = (
            wds.Dataset(
                valid_url, handler=wds.warn_and_continue, length=10000 // 20
            ).shuffle(100).decode("pil").rename(image="ppm",
                                                info="cls").map_dict(
                                                    image=self.valid_transform
                                                ).to_tuple("image",
                                                           "info").batched(20)
        )

        self.test_dataset = (
            wds.Dataset(
                test_url, handler=wds.warn_and_continue, length=10000 // 20
            ).shuffle(100).decode("pil").rename(image="ppm",
                                                info="cls").map_dict(
                                                    image=self.valid_transform
                                                ).to_tuple("image",
                                                           "info").batched(20)
        )
        train_url = "{}/{}-{}".format(train_base_url, "train",
                                      "{0.." + str(train_count) + "}.tar")
        valid_url = "{}/{}-{}".format(val_base_url, "val",
                                      "{0.." + str(val_count) + "}.tar")
        test_url = "{}/{}-{}".format(test_base_url, "test",
                                     "{0.." + str(test_count) + "}.tar")

        self.train_dataset = (wds.WebDataset(
            train_url,
            handler=wds.warn_and_continue).shuffle(100).decode("pil").rename(
                image="ppm;jpg;jpeg;png",
                info="cls").map_dict(image=self.train_transform).to_tuple(
                    "image", "info").batched(40))

        self.valid_dataset = (wds.WebDataset(
            valid_url,
            handler=wds.warn_and_continue).shuffle(100).decode("pil").rename(
                image="ppm",
                info="cls").map_dict(image=self.valid_transform).to_tuple(
                    "image", "info").batched(20))

        self.test_dataset = (wds.WebDataset(
            test_url,
            handler=wds.warn_and_continue).shuffle(100).decode("pil").rename(
                image="ppm",
                info="cls").map_dict(image=self.valid_transform).to_tuple(
                    "image", "info").batched(20))

    def create_data_loader(self, dataset, batch_size, num_workers): # pylint: disable=no-self-use
        """Creates data loader."""
        return DataLoader(
            dataset, batch_size=batch_size, num_workers=num_workers
        )
        return DataLoader(dataset,
                          batch_size=batch_size,
                          num_workers=num_workers)

    def train_dataloader(self):
        """Train Data loader.
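Aside (not part of the commit): the shard URLs assembled in setup() use webdataset's brace notation, which the library expands into the individual .tar shards. A tiny sketch with made-up values; the base URL and shard count below are hypothetical stand-ins for the dataset path and the get_num_files() result:

# Illustrative only; train_base_url and train_count are assumptions.
train_base_url = "https://example-bucket/cifar10/train"
train_count = 39
train_url = "{}/{}-{}".format(train_base_url, "train",
                              "{0.." + str(train_count) + "}.tar")
print(train_url)
# -> https://example-bucket/cifar10/train/train-{0..39}.tar
# webdataset brace-expands this into train-0.tar ... train-39.tar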
