Skip to content

Commit

Permalink
bugfix: batch_size parameter for DataModules remaining (Lightning-Uni…
Browse files Browse the repository at this point in the history
…verse#344)

* bugfix: batch_size for DataModules remaining

* Update sklearn datamodule tests

* Fix default_transforms. Keep internal for every data module

* fix typo on binary_mnist_datamodule

thanks @akihironitta

Co-authored-by: Akihiro Nitta <[email protected]>

Co-authored-by: Akihiro Nitta <[email protected]>
  • Loading branch information
2 people authored and chris-clem committed Dec 9, 2020
1 parent 9a18028 commit cd94b9e
Show file tree
Hide file tree
Showing 8 changed files with 57 additions and 49 deletions.
4 changes: 2 additions & 2 deletions docs/source/introduction_guide.rst
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ Here's an example for logistic regression
# use any numpy or sklearn dataset
X, y = load_iris(return_X_y=True)
dm = SklearnDataModule(X, y)
dm = SklearnDataModule(X, y, batch_size=12)
# build model
model = LogisticRegression(input_dim=4, num_classes=3)
Expand All @@ -434,7 +434,7 @@ Here's an example for logistic regression
trainer = pl.Trainer(tpu_cores=8, precision=16)
trainer.fit(model, dm.train_dataloader(), dm.val_dataloader())
trainer.test(test_dataloaders=dm.test_dataloader(batch_size=12))
trainer.test(test_dataloaders=dm.test_dataloader())
Any input will be flattened across all dimensions except the first one (batch).
This means images, sound, etc... work out of the box.
Expand Down
16 changes: 8 additions & 8 deletions pl_bolts/datamodules/cityscapes_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,8 @@ def train_dataloader(self):
"""
Cityscapes train set
"""
transforms = self.train_transforms or self.default_transforms()
target_transforms = self.target_transforms or self.default_target_transforms()
transforms = self.train_transforms or self._default_transforms()
target_transforms = self.target_transforms or self._default_target_transforms()

dataset = Cityscapes(self.data_dir,
split='train',
Expand All @@ -136,8 +136,8 @@ def val_dataloader(self):
"""
Cityscapes val set
"""
transforms = self.val_transforms or self.default_transforms()
target_transforms = self.target_transforms or self.default_target_transforms()
transforms = self.val_transforms or self._default_transforms()
target_transforms = self.target_transforms or self._default_target_transforms()

dataset = Cityscapes(self.data_dir,
split='val',
Expand All @@ -161,8 +161,8 @@ def test_dataloader(self):
"""
Cityscapes test set
"""
transforms = self.test_transforms or self.default_transforms()
target_transforms = self.target_transforms or self.default_target_transforms()
transforms = self.test_transforms or self._default_transforms()
target_transforms = self.target_transforms or self._default_target_transforms()

dataset = Cityscapes(self.data_dir,
split='test',
Expand All @@ -181,7 +181,7 @@ def test_dataloader(self):
)
return loader

def default_transforms(self):
def _default_transforms(self):
cityscapes_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Normalize(
Expand All @@ -191,7 +191,7 @@ def default_transforms(self):
])
return cityscapes_transforms

def default_target_transforms(self):
def _default_target_transforms(self):
cityscapes_target_trasnforms = transform_lib.Compose([
transform_lib.ToTensor(),
transform_lib.Lambda(lambda t: t.squeeze())
Expand Down
16 changes: 9 additions & 7 deletions pl_bolts/datamodules/kitti_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,8 @@ def __init__(
self.num_workers = num_workers
self.seed = seed

self.default_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
std=[0.32064945, 0.32098866, 0.32325324])
])

# split into train, val, test
kitti_dataset = KittiDataset(self.data_dir, transform=self.default_transforms)
kitti_dataset = KittiDataset(self.data_dir, transform=self._default_transforms())

val_len = round(val_split * len(kitti_dataset))
test_len = round(test_split * len(kitti_dataset))
Expand Down Expand Up @@ -111,3 +105,11 @@ def test_dataloader(self):
shuffle=False,
num_workers=self.num_workers)
return loader

def _default_transforms(self):
kitti_transforms = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753],
std=[0.32064945, 0.32098866, 0.32325324])
])
return kitti_transforms
22 changes: 12 additions & 10 deletions pl_bolts/datamodules/sklearn_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,22 +118,22 @@ class SklearnDataModule(LightningDataModule):
>>> from pl_bolts.datamodules import SklearnDataModule
...
>>> X, y = load_boston(return_X_y=True)
>>> loaders = SklearnDataModule(X, y)
>>> loaders = SklearnDataModule(X, y, batch_size=32)
...
>>> # train set
>>> train_loader = loaders.train_dataloader(batch_size=32)
>>> train_loader = loaders.train_dataloader()
>>> len(train_loader.dataset)
355
>>> len(train_loader)
11
>>> # validation set
>>> val_loader = loaders.val_dataloader(batch_size=32)
>>> val_loader = loaders.val_dataloader()
>>> len(val_loader.dataset)
100
>>> len(val_loader)
3
>>> # test set
>>> test_loader = loaders.test_dataloader(batch_size=32)
>>> test_loader = loaders.test_dataloader()
>>> len(test_loader.dataset)
51
>>> len(test_loader)
Expand All @@ -150,12 +150,14 @@ def __init__(
num_workers=2,
random_state=1234,
shuffle=True,
batch_size: int = 16,
*args,
**kwargs,
):

super().__init__(*args, **kwargs)
self.num_workers = num_workers
self.batch_size = batch_size

# shuffle x and y
if shuffle and _SKLEARN_AVAILABLE:
Expand Down Expand Up @@ -193,32 +195,32 @@ def _init_datasets(self, X, y, x_val, y_val, x_test, y_test):
self.val_dataset = SklearnDataset(x_val, y_val)
self.test_dataset = SklearnDataset(x_test, y_test)

def train_dataloader(self, batch_size: int = 16):
def train_dataloader(self):
loader = DataLoader(
self.train_dataset,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

def val_dataloader(self, batch_size: int = 16):
def val_dataloader(self):
loader = DataLoader(
self.val_dataset,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

def test_dataloader(self, batch_size: int = 16):
def test_dataloader(self):
loader = DataLoader(
self.test_dataset,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
Expand Down
14 changes: 8 additions & 6 deletions pl_bolts/datamodules/ssl_imagenet_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(
data_dir,
meta_dir=None,
num_workers=16,
batch_size: int = 32,
*args,
**kwargs,
):
Expand All @@ -39,6 +40,7 @@ def __init__(
self.data_dir = data_dir
self.num_workers = num_workers
self.meta_dir = meta_dir
self.batch_size = batch_size

@property
def num_classes(self):
Expand Down Expand Up @@ -74,7 +76,7 @@ def prepare_data(self):
UnlabeledImagenet.generate_meta_bins(path)
""")

def train_dataloader(self, batch_size, num_images_per_class=-1, add_normalize=False):
def train_dataloader(self, num_images_per_class=-1, add_normalize=False):
transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = UnlabeledImagenet(self.data_dir,
Expand All @@ -84,15 +86,15 @@ def train_dataloader(self, batch_size, num_images_per_class=-1, add_normalize=Fa
transform=transforms)
loader = DataLoader(
dataset,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=True,
num_workers=self.num_workers,
drop_last=True,
pin_memory=True
)
return loader

def val_dataloader(self, batch_size, num_images_per_class=50, add_normalize=False):
def val_dataloader(self, num_images_per_class=50, add_normalize=False):
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms

dataset = UnlabeledImagenet(self.data_dir,
Expand All @@ -102,14 +104,14 @@ def val_dataloader(self, batch_size, num_images_per_class=50, add_normalize=Fals
transform=transforms)
loader = DataLoader(
dataset,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
pin_memory=True
)
return loader

def test_dataloader(self, batch_size, num_images_per_class, add_normalize=False):
def test_dataloader(self, num_images_per_class, add_normalize=False):
transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = UnlabeledImagenet(self.data_dir,
Expand All @@ -119,7 +121,7 @@ def test_dataloader(self, batch_size, num_images_per_class, add_normalize=False)
transform=transforms)
loader = DataLoader(
dataset,
batch_size=batch_size,
batch_size=self.batch_size,
shuffle=False,
num_workers=self.num_workers,
drop_last=True,
Expand Down
16 changes: 8 additions & 8 deletions pl_bolts/datamodules/stl10_datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def train_dataloader(self):
"""
Loads the 'unlabeled' split minus a portion set aside for validation via `unlabeled_val_split`.
"""
transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

dataset = STL10(self.data_dir, split='unlabeled', download=False, transform=transforms)
train_length = len(dataset)
Expand All @@ -132,7 +132,7 @@ def train_dataloader_mixed(self):
batch_size: the batch size
transforms: a sequence of transforms
"""
transforms = self.default_transforms() if self.train_transforms is None else self.train_transforms
transforms = self._default_transforms() if self.train_transforms is None else self.train_transforms

unlabeled_dataset = STL10(self.data_dir,
split='unlabeled',
Expand Down Expand Up @@ -170,7 +170,7 @@ def val_dataloader(self):
batch_size: the batch size
transforms: a sequence of transforms
"""
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms

dataset = STL10(self.data_dir, split='unlabeled', download=False, transform=transforms)
train_length = len(dataset)
Expand Down Expand Up @@ -202,7 +202,7 @@ def val_dataloader_mixed(self):
batch_size: the batch size
transforms: a sequence of transforms
"""
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
unlabeled_dataset = STL10(self.data_dir,
split='unlabeled',
download=False,
Expand Down Expand Up @@ -237,7 +237,7 @@ def test_dataloader(self):
batch_size: the batch size
transforms: the transforms
"""
transforms = self.default_transforms() if self.test_transforms is None else self.test_transforms
transforms = self._default_transforms() if self.test_transforms is None else self.test_transforms

dataset = STL10(self.data_dir, split='test', download=False, transform=transforms)
loader = DataLoader(
Expand All @@ -251,7 +251,7 @@ def test_dataloader(self):
return loader

def train_dataloader_labeled(self):
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms

dataset = STL10(self.data_dir, split='train', download=False, transform=transforms)
train_length = len(dataset)
Expand All @@ -268,7 +268,7 @@ def train_dataloader_labeled(self):
return loader

def val_dataloader_labeled(self):
transforms = self.default_transforms() if self.val_transforms is None else self.val_transforms
transforms = self._default_transforms() if self.val_transforms is None else self.val_transforms
dataset = STL10(self.data_dir,
split='train',
download=False,
Expand All @@ -288,7 +288,7 @@ def val_dataloader_labeled(self):
)
return loader

def default_transforms(self):
def _default_transforms(self):
data_transforms = transform_lib.Compose([
transform_lib.ToTensor(),
stl10_normalization()
Expand Down
9 changes: 5 additions & 4 deletions pl_bolts/models/regression/linear_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,9 +131,6 @@ def cli_main():
'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.'
) from err

X, y = load_boston(return_X_y=True) # these are numpy arrays
loaders = SklearnDataModule(X, y)

# args
parser = ArgumentParser()
parser = LinearRegression.add_model_specific_args(parser)
Expand All @@ -144,9 +141,13 @@ def cli_main():
model = LinearRegression(input_dim=13, l1_strength=1, l2_strength=1)
# model = LinearRegression(**vars(args))

# data
X, y = load_boston(return_X_y=True) # these are numpy arrays
loaders = SklearnDataModule(X, y, batch_size=args.batch_size)

# train
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, loaders.train_dataloader(args.batch_size), loaders.val_dataloader(args.batch_size))
trainer.fit(model, loaders.train_dataloader(), loaders.val_dataloader())


if __name__ == '__main__':
Expand Down
9 changes: 5 additions & 4 deletions pl_bolts/models/regression/logistic_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,6 @@ def cli_main():
'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.'
) from err

X, y = load_iris(return_X_y=True)
loaders = SklearnDataModule(X, y)

# args
parser = ArgumentParser()
parser = LogisticRegression.add_model_specific_args(parser)
Expand All @@ -150,9 +147,13 @@ def cli_main():
# model = LogisticRegression(**vars(args))
model = LogisticRegression(input_dim=4, num_classes=3, l1_strength=0.01, learning_rate=0.01)

# data
X, y = load_iris(return_X_y=True)
loaders = SklearnDataModule(X, y, batch_size=args.batch_size)

# train
trainer = pl.Trainer.from_argparse_args(args)
trainer.fit(model, loaders.train_dataloader(args.batch_size), loaders.val_dataloader(args.batch_size))
trainer.fit(model, loaders.train_dataloader(), loaders.val_dataloader())


if __name__ == '__main__':
Expand Down

0 comments on commit cd94b9e

Please sign in to comment.