From c60e142905cf44086d5c3e1ed844953730668f98 Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Tue, 22 Sep 2020 20:25:52 -0600 Subject: [PATCH 01/32] Add lightning argument parser (#246) * :construction: copy code over from lightning PR #3537 * :white_check_mark: tests * :clown_face: add dummy commit to see if it fixes checks * :lipstick: style * :lipstick: style * :lipstick: style * :wrench: config * :wrench: config * :pencil: docs * :wrench: skip utils as docs fail on dataclass * :pencil: add docs * :pencil: docs * :white_check_mark: add test --- docs/source/conf.py | 1 + docs/source/index.rst | 1 - pl_bolts/utils/arguments.py | 135 ++++++++++++++++++++++++++++++++++ setup.cfg | 1 + tests/utils/test_arguments.py | 93 +++++++++++++++++++++++ 5 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 pl_bolts/utils/arguments.py create mode 100644 tests/utils/test_arguments.py diff --git a/docs/source/conf.py b/docs/source/conf.py index dbe5f4b461..f2dc1442d9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -135,6 +135,7 @@ 'api/pl_bolts.rst', 'api/modules.rst', 'api/pl_bolts.submit.rst', + 'api/pl_bolts.utils.*', 'PULL_REQUEST_TEMPLATE.md', ] diff --git a/docs/source/index.rst b/docs/source/index.rst index 26a3192fbe..001990c191 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -96,4 +96,3 @@ Indices and tables api/pl_bolts.losses api/pl_bolts.optimizers api/pl_bolts.transforms - api/pl_bolts.utils diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py new file mode 100644 index 0000000000..fd4205cd5e --- /dev/null +++ b/pl_bolts/utils/arguments.py @@ -0,0 +1,135 @@ +import inspect +from argparse import ArgumentParser, Namespace +from dataclasses import dataclass +from typing import Any, Optional + +import pytorch_lightning as pl + + +@dataclass(frozen=True) +class LitArg: + """Dataclass to represent init args of an object + """ + name: str + types: tuple + default: Any + required: bool = False + context: Optional[str] = None + + +class LightningArgumentParser(ArgumentParser): + def __init__(self, *args, ignore_required_init_args=True, **kwargs): + """Extension of argparse.ArgumentParser that lets you parse arbitrary object init args. + + Example:: + + from pl_bolts.utils.arguments import LightningArgumentParser + + parser.add_object_args("data", MyDataModule) + parser.add_object_args("model", MyModel) + args = parser.parse_lit_args() + + # args.data -> data args + # args.model -> model args + + Args: + ignore_required_init_args (bool, optional): Whether to include positional args when adding + object args. Defaults to True. + """ + super().__init__(*args, **kwargs) + self.ignore_required_init_args = ignore_required_init_args + + self._default_obj_args = dict() + self._added_arg_names = [] + + def add_object_args(self, name, obj): + default_args = gather_lit_args(obj) + self._default_obj_args[name] = default_args + for arg in default_args: + if arg.name in self._added_arg_names: + continue + self._added_arg_names.append(arg.name) + kwargs = dict(type=arg.types[0]) + if arg.required and not self.ignore_required_init_args: + kwargs["required"] = True + else: + kwargs["default"] = arg.default + self.add_argument(f"--{arg.name}", **kwargs) + + def parse_lit_args(self, *args, **kwargs): + parsed_args_dict = vars(self.parse_args(*args, **kwargs)) + lit_args = Namespace() + for name, default_args in self._default_obj_args.items(): + lit_obj_args = dict() + for arg in default_args: + arg_is_member_of_obj = arg.name in parsed_args_dict + arg_should_be_added = not arg.required or (arg.required and not self.ignore_required_init_args) + if arg_is_member_of_obj and arg_should_be_added: + lit_obj_args[arg.name] = parsed_args_dict[arg.name] + lit_args.__dict__.update(**{name: Namespace(**lit_obj_args)}) + return lit_args + + +def gather_lit_args(cls, root_cls=None): + + if root_cls is None: + if issubclass(cls, pl.LightningModule): + root_cls = pl.LightningModule + elif issubclass(cls, pl.LightningDataModule): + root_cls = pl.LightningDataModule + else: + root_cls = cls + + blacklisted_args = ["self", "args", "kwargs"] + arguments = [] + argument_names = [] + for obj in inspect.getmro(cls): + + if obj is root_cls and len(arguments) > 0: + break + + if issubclass(obj, root_cls): + + default_params = inspect.signature(obj.__init__).parameters + + for arg in default_params: + arg_type = default_params[arg].annotation + arg_default = default_params[arg].default + + try: + arg_types = tuple(arg_type.__args__) + except AttributeError: + arg_types = (arg_type,) + + # If type is empty, that means it hasn't been given type hint. We skip these. + arg_is_missing_type_hint = arg_types == (inspect._empty,) + # Some args should be ignored by default (self, kwargs, args) + arg_is_in_blacklist = arg in blacklisted_args and arg_is_missing_type_hint + # We only keep the first arg we see of a given name, as it overrides the parents + arg_is_duplicate = arg in argument_names + # We skip any of the above 3 cases + do_skip_this_arg = arg_is_in_blacklist or arg_is_missing_type_hint or arg_is_duplicate + + # Positional args have no default, but do have a known type or types. + arg_is_positional = arg_default == inspect._empty and not arg_is_missing_type_hint + # Kwargs have both a default + known type or types + arg_is_kwarg = arg_default != inspect._empty and not arg_is_missing_type_hint + + if do_skip_this_arg: + continue + + elif arg_is_positional or arg_is_kwarg: + lit_arg = LitArg( + name=arg, + types=arg_types, + default=arg_default if arg_default != inspect._empty else None, + required=arg_is_positional, + context=obj.__name__, + ) + arguments.append(lit_arg) + argument_names.append(arg) + else: + raise RuntimeError( + f"Could not determine proper grouping of argument '{arg}' while gathering LitArgs" + ) + return arguments diff --git a/setup.cfg b/setup.cfg index 644db29746..15be4224ab 100644 --- a/setup.cfg +++ b/setup.cfg @@ -41,6 +41,7 @@ format = pylint # see: https://www.flake8rules.com/ ignore = E731 # Do not assign a lambda expression, use a def + E231 # Ignore missing space after comma W504 # Line break occurred after a binary operator F401 # Module imported but unused F841 # Local variable name is assigned to but never used diff --git a/tests/utils/test_arguments.py b/tests/utils/test_arguments.py new file mode 100644 index 0000000000..a2d0581a8f --- /dev/null +++ b/tests/utils/test_arguments.py @@ -0,0 +1,93 @@ +from dataclasses import FrozenInstanceError + +import pytest +import pytorch_lightning as pl + +from pl_bolts.utils.arguments import LightningArgumentParser, LitArg, gather_lit_args + + +class DummyParentModel(pl.LightningModule): + + name = "parent-model" + + def __init__(self, a: int, b: str, c: str = "parent_model_c"): + super().__init__() + self.save_hyperparameters() + + def forward(self, x): + pass + + +class DummyParentDataModule(pl.LightningDataModule): + + name = "parent-dm" + + def __init__(self, d: str, c: str = "parent_dm_c"): + super().__init__() + self.d = d + self.c = c + + +def test_lightning_argument_parser(): + parser = LightningArgumentParser(ignore_required_init_args=False) + assert parser.ignore_required_init_args is False + parser = LightningArgumentParser(ignore_required_init_args=True) + assert parser.ignore_required_init_args is True + + +@pytest.mark.xfail() +def test_parser_bad_argument(): + parser = LightningArgumentParser() + parser.add_object_args('dm', DummyParentDataModule) + parser.add_object_args('model', DummyParentModel) + args = parser.parse_lit_args(['--some-bad-arg', 'asdf']) + + +def test_lit_arg_immutable(): + arg = LitArg("some_arg", (int,), 1, False) + with pytest.raises(FrozenInstanceError): + arg.default = 0 + assert arg.default == 1 + + +@pytest.mark.parametrize( + "obj,expected", + [ + pytest.param( + DummyParentModel, + {"a": (int, None, True), "b": (str, None, True), "c": (str, "parent_model_c", False),}, + id="dummy-parent-model", + ), + pytest.param( + DummyParentDataModule, {"d": (str, None, True), "c": (str, "parent_dm_c", False)}, id="dummy-parent-dm", + ), + ], +) +def test_gather_lit_args(obj, expected): + lit_args = gather_lit_args(obj) + assert len(lit_args) == len(expected) + for lit_arg, (k, v) in zip(lit_args, expected.items()): + assert lit_arg.name == k + assert lit_arg.types[0] == v[0] + assert lit_arg.default == v[1] + assert lit_arg.required == v[2] + + +@pytest.mark.parametrize( + "ignore_required_init_args,dm_cls,model_cls,a,b,c,d", + [pytest.param(True, DummyParentDataModule, DummyParentModel, 999, "bbb", "ccc", "ddd", id="base",),], +) +def test_lightning_arguments(ignore_required_init_args, dm_cls, model_cls, a, b, c, d): + parser = LightningArgumentParser(ignore_required_init_args=ignore_required_init_args) + parser.add_object_args("dm", dm_cls) + parser.add_object_args("model", model_cls) + + mocked_args = f""" + --a 1 + --b {b} + --c {c} + --d {d} + """.strip().split() + args = parser.parse_lit_args(mocked_args) + assert vars(args.dm)["c"] == vars(args.model)["c"] == c + assert "d" not in args.dm From 32139f469e250c20a0921f9e877cfa530556af17 Mon Sep 17 00:00:00 2001 From: Annika Brundyn <42869932+annikabrundyn@users.noreply.github.com> Date: Sat, 26 Sep 2020 22:53:48 -0400 Subject: [PATCH 02/32] Kitti Datamodule (#248) * kitti dataset * kitti dataset * kitti dm * kitti dm * imports * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti * kitti --- pl_bolts/datamodules/__init__.py | 3 + pl_bolts/datamodules/kitti_datamodule.py | 99 ++++++++++++++++++++++++ pl_bolts/datamodules/kitti_dataset.py | 92 ++++++++++++++++++++++ 3 files changed, 194 insertions(+) create mode 100644 pl_bolts/datamodules/kitti_datamodule.py create mode 100644 pl_bolts/datamodules/kitti_dataset.py diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 2ac28cd8e6..e8de3eaf8d 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -18,5 +18,8 @@ from pl_bolts.datamodules.ssl_imagenet_datamodule import SSLImagenetDataModule from pl_bolts.datamodules.stl10_datamodule import STL10DataModule from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule + + from pl_bolts.datamodules.kitti_dataset import KittiDataset + from pl_bolts.datamodules.kitti_datamodule import KittiDataModule except ImportError: pass diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py new file mode 100644 index 0000000000..6858af6629 --- /dev/null +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -0,0 +1,99 @@ +import os +import torch + +from pytorch_lightning import LightningDataModule +from pl_bolts.datamodules.kitti_dataset import KittiDataset + +from torch.utils.data import DataLoader +import torchvision.transforms as transforms +from torch.utils.data.dataset import random_split + + +class KittiDataModule(LightningDataModule): + + name = 'kitti' + + def __init__( + self, + data_dir: str, + val_split: float = 0.2, + test_split: float = 0.1, + num_workers: int = 16, + batch_size: int = 32, + seed: int = 42, + *args, + **kwargs, + ): + """ + Kitti train, validation and test dataloaders. + + Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved. + You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 + + Specs: + - 200 samples + - Each image is (3 x 1242 x 376) + + In total there are 34 classes but some of these are not useful so by default we use only 19 of the classes + specified by the `valid_labels` parameter. + + Example:: + + from pl_bolts.datamodules import KittiDataModule + + dm = KittiDataModule(PATH) + model = LitModel() + + Trainer().fit(model, dm) + + Args:: + data_dir: where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' + val_split: size of validation test (default 0.2) + test_split: size of test set (default 0.1) + num_workers: how many workers to use for loading data + batch_size: the batch size + seed: random seed to be used for train/val/test splits + """ + super().__init__(*args, **kwargs) + self.data_dir = data_dir if data_dir is not None else os.getcwd() + self.batch_size = batch_size + self.num_workers = num_workers + self.seed = seed + + self.default_transforms = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize(mean=[0.35675976, 0.37380189, 0.3764753], + std=[0.32064945, 0.32098866, 0.32325324]) + ]) + + # split into train, val, test + kitti_dataset = KittiDataset(self.data_dir, transform=self.default_transforms) + + val_len = round(val_split * len(kitti_dataset)) + test_len = round(test_split * len(kitti_dataset)) + train_len = len(kitti_dataset) - val_len - test_len + + self.trainset, self.valset, self.testset = random_split(kitti_dataset, + lengths=[train_len, val_len, test_len], + generator=torch.Generator().manual_seed(self.seed)) + + def train_dataloader(self): + loader = DataLoader(self.trainset, + batch_size=self.batch_size, + shuffle=True, + num_workers=self.num_workers) + return loader + + def val_dataloader(self): + loader = DataLoader(self.valset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers) + return loader + + def test_dataloader(self): + loader = DataLoader(self.testset, + batch_size=self.batch_size, + shuffle=False, + num_workers=self.num_workers) + return loader diff --git a/pl_bolts/datamodules/kitti_dataset.py b/pl_bolts/datamodules/kitti_dataset.py new file mode 100644 index 0000000000..937f106fa2 --- /dev/null +++ b/pl_bolts/datamodules/kitti_dataset.py @@ -0,0 +1,92 @@ +import os +import numpy as np +from PIL import Image + +from torch.utils.data import Dataset + +DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) +DEFAULT_VALID_LABELS = (7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33) + + +class KittiDataset(Dataset): + """ + Note: You need to have downloaded the Kitti dataset first and provide the path to where it is saved. + You can download the dataset here: http://www.cvlibs.net/datasets/kitti/eval_semseg.php?benchmark=semantics2015 + + There are 34 classes, however not all of them are useful for training (e.g. railings on highways). These + useless classes (the pixel values of these classes) are stored in `void_labels`. Useful classes are stored + in `valid_labels`. + + The `encode_segmap` function sets all pixels with any of the `void_labels` to `ignore_index` + (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and + `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by + the loss function when comparing with the output. + + Args: + data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' + img_size: image dimensions (width, height) + void_labels: useless classes to be excluded from training + valid_labels: useful classes to include + """ + IMAGE_PATH = os.path.join('training', 'image_2') + MASK_PATH = os.path.join('training', 'semantic') + + def __init__( + self, + data_dir: str, + img_size: tuple = (1242, 376), + void_labels: list = DEFAULT_VOID_LABELS, + valid_labels: list = DEFAULT_VALID_LABELS, + transform=None + ): + self.img_size = img_size + self.void_labels = void_labels + self.valid_labels = valid_labels + self.ignore_index = 250 + self.class_map = dict(zip(self.valid_labels, range(len(self.valid_labels)))) + self.transform = transform + + self.data_dir = data_dir + self.img_path = os.path.join(self.data_dir, self.IMAGE_PATH) + self.mask_path = os.path.join(self.data_dir, self.MASK_PATH) + self.img_list = self.get_filenames(self.img_path) + self.mask_list = self.get_filenames(self.mask_path) + + def __len__(self): + return len(self.img_list) + + def __getitem__(self, idx): + img = Image.open(self.img_list[idx]) + img = img.resize(self.img_size) + img = np.array(img) + + mask = Image.open(self.mask_list[idx]).convert('L') + mask = mask.resize(self.img_size) + mask = np.array(mask) + mask = self.encode_segmap(mask) + + if self.transform: + img = self.transform(img) + + return img, mask + + def encode_segmap(self, mask): + """ + Sets void classes to zero so they won't be considered for training + """ + for voidc in self.void_labels: + mask[mask == voidc] = self.ignore_index + for validc in self.valid_labels: + mask[mask == validc] = self.class_map[validc] + # remove extra idxs from updated dataset + mask[mask > 18] = self.ignore_index + return mask + + def get_filenames(self, path): + """ + Returns a list of absolute paths to images inside given `path` + """ + files_list = list() + for filename in os.listdir(path): + files_list.append(os.path.join(path, filename)) + return files_list From 72e8be31df3fc9b63030bf3d15b85d89a916fd03 Mon Sep 17 00:00:00 2001 From: Annika Brundyn <42869932+annikabrundyn@users.noreply.github.com> Date: Sat, 26 Sep 2020 22:54:59 -0400 Subject: [PATCH 03/32] U-net implementation (#247) * unet implementation * unet * clean up * clean up * update init * init * simple unet test * simple unet test --- pl_bolts/datamodules/__init__.py | 2 +- pl_bolts/models/__init__.py | 1 + pl_bolts/models/vision/__init__.py | 1 + pl_bolts/models/vision/unet.py | 129 +++++++++++++++++++++++++++++ tests/models/test_vision_models.py | 11 ++- 5 files changed, 141 insertions(+), 3 deletions(-) create mode 100644 pl_bolts/models/vision/unet.py diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index e8de3eaf8d..1dd2e7c9aa 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -1,5 +1,5 @@ from pl_bolts.datamodules.async_dataloader import AsynchronousLoader -from pl_bolts.datamodules.dummy_dataset import DummyDetectionDataset +from pl_bolts.datamodules.dummy_dataset import DummyDataset, DummyDetectionDataset try: from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule diff --git a/pl_bolts/models/__init__.py b/pl_bolts/models/__init__.py index caf9de1cba..2fae5936c7 100644 --- a/pl_bolts/models/__init__.py +++ b/pl_bolts/models/__init__.py @@ -8,3 +8,4 @@ from pl_bolts.models.regression import LinearRegression, LogisticRegression from pl_bolts.models.vision import PixelCNN from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT +from pl_bolts.models.vision import UNet diff --git a/pl_bolts/models/vision/__init__.py b/pl_bolts/models/vision/__init__.py index 18e0784ea7..8d4ec5084e 100644 --- a/pl_bolts/models/vision/__init__.py +++ b/pl_bolts/models/vision/__init__.py @@ -1 +1,2 @@ from pl_bolts.models.vision.pixel_cnn import PixelCNN +from pl_bolts.models.vision.unet import UNet \ No newline at end of file diff --git a/pl_bolts/models/vision/unet.py b/pl_bolts/models/vision/unet.py new file mode 100644 index 0000000000..1f5bfed343 --- /dev/null +++ b/pl_bolts/models/vision/unet.py @@ -0,0 +1,129 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class UNet(nn.Module): + """ + PyTorch Lightning implementation of `U-Net: Convolutional Networks for Biomedical Image Segmentation + `_ + + Paper authors: Olaf Ronneberger, Philipp Fischer, Thomas Brox + + Model implemented by: + - `Annika Brundyn `_ + - `Akshay Kulkarni `_ + + .. warning:: Work in progress. This implementation is still being verified. + + Args: + num_classes: Number of output classes required + num_layers: Number of layers in each side of U-net (default 5) + features_start: Number of features in first layer (default 64) + bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. + """ + + def __init__( + self, + num_classes: int, + num_layers: int = 5, + features_start: int = 64, + bilinear: bool = False + ): + super().__init__() + self.num_layers = num_layers + + layers = [DoubleConv(3, features_start)] + + feats = features_start + for _ in range(num_layers - 1): + layers.append(Down(feats, feats * 2)) + feats *= 2 + + for _ in range(num_layers - 1): + layers.append(Up(feats, feats // 2, bilinear)) + feats //= 2 + + layers.append(nn.Conv2d(feats, num_classes, kernel_size=1)) + + self.layers = nn.ModuleList(layers) + + def forward(self, x): + xi = [self.layers[0](x)] + # Down path + for layer in self.layers[1:self.num_layers]: + xi.append(layer(xi[-1])) + # Up path + for i, layer in enumerate(self.layers[self.num_layers:-1]): + xi[-1] = layer(xi[-1], xi[-2 - i]) + return self.layers[-1](xi[-1]) + + +class DoubleConv(nn.Module): + """ + [ Conv2d => BatchNorm (optional) => ReLU ] x 2 + """ + + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + self.net = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1), + nn.BatchNorm2d(out_ch), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.net(x) + + +class Down(nn.Module): + """ + Downscale with MaxPool => DoubleConvolution block + """ + + def __init__(self, in_ch: int, out_ch: int): + super().__init__() + self.net = nn.Sequential( + nn.MaxPool2d(kernel_size=2, stride=2), + DoubleConv(in_ch, out_ch) + ) + + def forward(self, x): + return self.net(x) + + +class Up(nn.Module): + """ + Upsampling (by either bilinear interpolation or transpose convolutions) + followed by concatenation of feature map from contracting path, + followed by DoubleConv. + """ + + def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): + super().__init__() + self.upsample = None + if bilinear: + self.upsample = nn.Sequential( + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True), + nn.Conv2d(in_ch, in_ch // 2, kernel_size=1), + ) + else: + self.upsample = nn.ConvTranspose2d(in_ch, in_ch // 2, kernel_size=2, stride=2) + + self.conv = DoubleConv(in_ch, out_ch) + + def forward(self, x1, x2): + x1 = self.upsample(x1) + + # Pad x1 to the size of x2 + diff_h = x2.shape[2] - x1.shape[2] + diff_w = x2.shape[3] - x1.shape[3] + + x1 = F.pad(x1, [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]) + + # Concatenate along the channels axis + x = torch.cat([x2, x1], dim=1) + return self.conv(x) diff --git a/tests/models/test_vision_models.py b/tests/models/test_vision_models.py index 6894fd8456..0455a76320 100644 --- a/tests/models/test_vision_models.py +++ b/tests/models/test_vision_models.py @@ -2,8 +2,7 @@ import torch from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule -from pl_bolts.models import GPT2, ImageGPT - +from pl_bolts.models import GPT2, ImageGPT, UNet def test_igpt(tmpdir): pl.seed_everything(0) @@ -47,3 +46,11 @@ def test_gpt2(tmpdir): num_classes=10, ) model(x) + + +def test_unet(tmpdir): + x = torch.rand(10, 3, 28, 28) + model = UNet(num_classes=2) + y = model(x) + assert y.shape == torch.Size([10, 2, 28, 28]) + From c3e11f1520100076e1418a35149f62dcf28adf49 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 7 Oct 2020 16:46:06 +0200 Subject: [PATCH 04/32] add back RL (#257) * Revert "suspend RL (temporary) (#192)" This reverts commit df4d9284 * Explicit experience source patch (#255) * Updated DQN with implicit experience source * Updated noisy and per dqn * Updated reinforce and vpg args * Update integration tests for new args * Updated unit tests for new args and explicit train batch method * Fixed linting/docs error * space * CI * fix typing err * fix typing err * fix typing err Co-authored-by: Jirka Borovec Co-authored-by: William Falcon * RL requirements * fix imports * fi imports * skip * flake8 Co-authored-by: Donal Byrne Co-authored-by: William Falcon --- .github/workflows/ci_test-base.yml | 2 +- CHANGELOG.md | 5 + docs/source/index.rst | 1 + docs/source/losses.rst | 30 + docs/source/reinforce_learn.rst | 669 ++++++++++++++++++ pl_bolts/datamodules/__init__.py | 5 + pl_bolts/datamodules/cifar10_dataset.py | 3 + pl_bolts/datamodules/experience_source.py | 278 ++++++++ pl_bolts/losses/rl.py | 118 +++ pl_bolts/models/rl/__init__.py | 10 + pl_bolts/models/rl/common/__init__.py | 0 pl_bolts/models/rl/common/agents.py | 131 ++++ pl_bolts/models/rl/common/cli.py | 34 + pl_bolts/models/rl/common/gym_wrappers.py | 207 ++++++ pl_bolts/models/rl/common/memory.py | 313 ++++++++ pl_bolts/models/rl/common/networks.py | 317 +++++++++ pl_bolts/models/rl/double_dqn_model.py | 123 ++++ pl_bolts/models/rl/dqn_model.py | 444 ++++++++++++ pl_bolts/models/rl/dueling_dqn_model.py | 75 ++ pl_bolts/models/rl/noisy_dqn_model.py | 130 ++++ pl_bolts/models/rl/per_dqn_model.py | 197 ++++++ pl_bolts/models/rl/reinforce_model.py | 318 +++++++++ .../rl/vanilla_policy_gradient_model.py | 306 ++++++++ pl_bolts/models/vision/__init__.py | 2 +- requirements.txt | 0 requirements/models.txt | 3 +- requirements/test.txt | 3 +- tests/datamodules/test_experience_sources.py | 321 +++++++++ tests/losses/test_rl_loss.py | 51 ++ tests/models/rl/__init__.py | 0 tests/models/rl/integration/__init__.py | 0 .../rl/integration/test_policy_models.py | 41 ++ .../rl/integration/test_value_models.py | 74 ++ tests/models/rl/test_scripts.py | 104 +++ tests/models/rl/unit/__init__.py | 0 tests/models/rl/unit/test_agents.py | 62 ++ tests/models/rl/unit/test_memory.py | 286 ++++++++ tests/models/rl/unit/test_reinforce.py | 65 ++ tests/models/rl/unit/test_vpg.py | 56 ++ tests/models/rl/unit/test_wrappers.py | 19 + tests/models/test_mnist_templates.py | 6 +- tests/models/test_vision_models.py | 2 +- 42 files changed, 4802 insertions(+), 9 deletions(-) create mode 100644 docs/source/reinforce_learn.rst create mode 100644 pl_bolts/datamodules/experience_source.py create mode 100644 pl_bolts/losses/rl.py create mode 100644 pl_bolts/models/rl/__init__.py create mode 100644 pl_bolts/models/rl/common/__init__.py create mode 100644 pl_bolts/models/rl/common/agents.py create mode 100644 pl_bolts/models/rl/common/cli.py create mode 100644 pl_bolts/models/rl/common/gym_wrappers.py create mode 100644 pl_bolts/models/rl/common/memory.py create mode 100644 pl_bolts/models/rl/common/networks.py create mode 100644 pl_bolts/models/rl/double_dqn_model.py create mode 100644 pl_bolts/models/rl/dqn_model.py create mode 100644 pl_bolts/models/rl/dueling_dqn_model.py create mode 100644 pl_bolts/models/rl/noisy_dqn_model.py create mode 100644 pl_bolts/models/rl/per_dqn_model.py create mode 100644 pl_bolts/models/rl/reinforce_model.py create mode 100644 pl_bolts/models/rl/vanilla_policy_gradient_model.py create mode 100644 requirements.txt create mode 100644 tests/datamodules/test_experience_sources.py create mode 100644 tests/losses/test_rl_loss.py create mode 100644 tests/models/rl/__init__.py create mode 100644 tests/models/rl/integration/__init__.py create mode 100644 tests/models/rl/integration/test_policy_models.py create mode 100644 tests/models/rl/integration/test_value_models.py create mode 100644 tests/models/rl/test_scripts.py create mode 100644 tests/models/rl/unit/__init__.py create mode 100644 tests/models/rl/unit/test_agents.py create mode 100644 tests/models/rl/unit/test_memory.py create mode 100644 tests/models/rl/unit/test_reinforce.py create mode 100644 tests/models/rl/unit/test_vpg.py create mode 100644 tests/models/rl/unit/test_wrappers.py diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 285d97b0e5..5103a13731 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -66,7 +66,7 @@ jobs: - name: Test Package [only] run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pl_bolts -m pytest pl_bolts -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml --ignore=pl_bolts/datamodules --ignore=pl_bolts/models/self_supervised/amdim/transforms.py + coverage run --source pl_bolts -m pytest pl_bolts -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml --ignore=pl_bolts/datamodules --ignore=pl_bolts/models/self_supervised/amdim/transforms.py --ignore=pl_bolts/models/rl - name: Upload pytest test results uses: actions/upload-artifact@master diff --git a/CHANGELOG.md b/CHANGELOG.md index 4d28abc99a..fea424b9d0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -31,6 +31,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Added Linear Regression - Added Moco2g - Added simclr +- Added RL module - Added Loggers - Added Transforms - Added Tiny Datasets @@ -42,12 +43,16 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed +- Device is no longer set in the DQN model init +- Moved RL loss function to the losses module +- Moved rl.common.experience to datamodules - train_batch function to VPG model to generate batch of data at each step (POC) - Experience source no longer gets initialized with a device, instead the device is passed at each step() - Refactored ExperienceSource classes to be handle multiple environments. ### Removed +- Removed N-Step DQN as the latest version of the DQN supports N-Step by setting the `n_step` arg to n - Deprecated common.experience ### Fixed diff --git a/docs/source/index.rst b/docs/source/index.rst index 001990c191..e24bbeb38f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -57,6 +57,7 @@ PyTorch-Lightning-Bolts documentation classic_ml convolutional gans + reinforce_learn self_supervised_models .. toctree:: diff --git a/docs/source/losses.rst b/docs/source/losses.rst index 3f2b120fee..44b401dfcc 100644 --- a/docs/source/losses.rst +++ b/docs/source/losses.rst @@ -10,3 +10,33 @@ This package lists common losses across research domains Your Loss --------- We're cleaning up many of our losses, but in the meantime, submit a PR to add your loss here! + +------------- + +Reinforcement Learning +====================== +These are common losses used in RL. + +--------------- + +DQN Loss +-------- + +.. autofunction:: pl_bolts.losses.rl.dqn_loss + :noindex: + +--------------- + +Double DQN Loss +--------------- + +.. autofunction:: pl_bolts.losses.rl.double_dqn_loss + :noindex: + +--------------- + +Per DQN Loss +------------ + +.. autofunction:: pl_bolts.losses.rl.per_dqn_loss + :noindex: diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst new file mode 100644 index 0000000000..827feb395a --- /dev/null +++ b/docs/source/reinforce_learn.rst @@ -0,0 +1,669 @@ +Reinforcement Learning +====================== + +This module is a collection of common RL approaches implemented in Lightning. + +----------------- + +Module authors +-------------- + +Contributions by: `Donal Byrne `_ + +- DQN +- Double DQN +- Dueling DQN +- Noisy DQN +- NStep DQN +- Prioritized Experience Replay DQN +- Reinforce +- Vanilla Policy Gradient + +------------ + +.. note:: + RL models currently only support CPU and single GPU training with `distributed_backend=dp`. + Full GPU support will be added in later updates. + + +DQN Models +---------- + +The following models are based on DQN. DQN uses Value based learning where it is deciding what action to take based +on the models current learned value (V), or the state action value (Q) of the current state. These Values are defined +as the discounted total reward of the agents state or state action pair. + +--------------- + +Deep-Q-Network (DQN) +^^^^^^^^^^^^^^^^^^^^ + +DQN model introduced in `Playing Atari with Deep Reinforcement Learning `_. +Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller. + +Original implementation by: `Donal Byrne `_ + +The DQN was introduced in `Playing Atari with Deep Reinforcement Learning `_ by +researchers at DeepMind. This took the concept of tabular Q learning and scaled it to much larger problems by +apporximating the Q function using a deep neural network. + +The goal behind DQN was to take the simple control method of Q learning and scale it up in order to solve complicated + tasks. As well as this, the method needed to be stable. The DQN solves these issues with the following additions. + +**Approximated Q Function** + +Storing Q values in a table works well in theory, but is completely unscalable. Instead, the authors apporximate the +Q function using a deep neural network. This allows the DQN to be used for much more complicated tasks + +**Replay Buffer** + +Similar to supervised learning, the DQN learns on randomly sampled batches of previous data stored in an +Experience Replay Buffer. The 'target' is calculated using the Bellman equation + +.. math:: + + Q(s,a)<-(r+{\gamma}\max_{a'{\in}A}Q(s',a'))^2 + +and then we optimize using SGD just like a standard supervised learning problem. + +.. math:: + + L=(Q(s,a)-(r+{\gamma}\max_{a'{\in}A}Q(s',a'))^2 + +DQN Results +~~~~~~~~~~~ + +**DQN: Pong** + +.. image:: _images/rl_benchmark/pong_dqn_baseline_results.jpg + :width: 800 + :alt: DQN Baseline Results + +Example:: + + from pl_bolts.models.rl import DQN + dqn = DQN("PongNoFrameskip-v4") + trainer = Trainer() + trainer.fit(dqn) + +.. autoclass:: pl_bolts.models.rl.dqn_model.DQN + :noindex: + +--------------- + +Double DQN +^^^^^^^^^^ + +Double DQN model introduced in `Deep Reinforcement Learning with Double Q-learning `_ +Paper authors: Hado van Hasselt, Arthur Guez, David Silver + +Original implementation by: `Donal Byrne `_ + +The original DQN tends to overestimate Q values during the Bellman update, leading to instability and is harmful to +training. This is due to the max operation in the Bellman equation. + +We are constantly taking the max of our agents estimates +during our update. This may seem reasonable, if we could trust these estimates. However during the early stages of +training, the estimates for these values will be off center and can lead to instability in training until +our estimates become more reliable + +The Double DQN fixes this overestimation by choosing actions for the next state using the main trained network +but uses the values of these actions from the more stable target network. So we are still going to take the greedy +action, but the value will be less "optimisitc" because it is chosen by the target network. + +**DQN expected return** + + +.. math:: + + Q(s_t, a_t) = r_t + \gamma * \max_{Q'}(S_{t+1}, a) + +**Double DQN expected return** + +.. math:: + + Q(s_t, a_t) = r_t + \gamma * \max{Q'}(S_{t+1}, \arg\max_Q(S_{t+1}, a)) + +Double DQN Results +~~~~~~~~~~~~~~~~~~ + +**Double DQN: Pong** + +.. image:: _images/rl_benchmark/pong_double_dqn_baseline_results.jpg + :width: 800 + :alt: Double DQN Result + +**DQN vs Double DQN: Pong** + +orange: DQN + +blue: Double DQN + +.. image:: _images/rl_benchmark/dqn_ddqn_comparison.jpg + :width: 800 + :alt: Double DQN Comparison Result + +Example:: + + from pl_bolts.models.rl import DoubleDQN + ddqn = DoubleDQN("PongNoFrameskip-v4") + trainer = Trainer() + trainer.fit(ddqn) + +.. autoclass:: pl_bolts.models.rl.double_dqn_model.DoubleDQN + :noindex: + +--------------- + +Dueling DQN +^^^^^^^^^^^ + +Dueling DQN model introduced in `Dueling Network Architectures for Deep Reinforcement Learning `_ +Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas + +Original implementation by: `Donal Byrne `_ + +The Q value that we are trying to approximate can be divided into two parts, the value state V(s) and the 'advantage' +of actions in that state A(s, a). Instead of having one full network estimate the entire Q value, Dueling DQN uses two +estimator heads in order to separate the estimation of the two parts. + +The value is the same as in value iteration. It is the discounted expected reward achieved from state s. Think of the +value as the 'base reward' from being in state s. + +The advantage tells us how much 'extra' reward we get from taking action a while in state s. The advantage bridges the +gap between Q(s, a) and V(s) as Q(s, a) = V(s) + A(s, a). + +In the paper [Dueling Network Architectures for Deep Reinforcement Learning](https://arxiv.org/abs/1511.06581) the +network uses two heads, one outputs the value state and the other outputs the advantage. This leads to better +training stability, faster convergence and overall better results. The V head outputs a single scalar +(the state value), while the advantage head outputs a tensor equal to the size of the action space, containing +an advantage value for each action in state s. + +Changing the network architecture is not enough, we also need to ensure that the advantage mean is 0. This is done +by subtracting the mean advantage from the Q value. This essentially pulls the mean advantage to 0. + +.. math:: + + Q(s, a) = V(s) + A(s, a) - 1/N * \sum_k(A(s, k) + +Dueling DQN Benefits +~~~~~~~~~~~~~~~~~~~~ + +- Ability to efficiently learn the state value function. In the dueling network, every Q update also updates the Value + stream, where as in DQN only the value of the chosen action is updated. This provides a better approximation of the + values +- The differences between total Q values for a given state are quite small in relation to the magnitude of Q. The + difference in the Q values between the best action and the second best action can be very small, while the average + state value can be much larger. The differences in scale can introduce noise, which may lead to the greedy policy + switching the priority of these actions. The seperate estimators for state value and advantage makes the Dueling + DQN robust to this type of scenario + +Dueling DQN Results +~~~~~~~~~~~~~~~~~~~ + +The results below a noticeable improvement from the original DQN network. + + +**Dueling DQN baseline: Pong** + +Similar to the results of the DQN baseline, the agent has a period where the number of steps per episodes increase as +it begins to hold its own against the heuristic oppoent, but then the steps per episode quickly begins to drop +as it gets better and starts to beat its opponent faster and faster. There is a noticable point at step ~250k +where the agent goes from losing to winning. + +As you can see by the total rewards, the dueling network's training progression is very stable and continues to trend +upward until it finally plateus. + +.. image:: _images/rl_benchmark/pong_dueling_dqn_results.jpg + :width: 800 + :alt: Dueling DQN Result + +**DQN vs Dueling DQN: Pong** + +In comparison to the base DQN, we see that the Dueling network's training is much more stable and is able to reach a +score in the high teens faster than the DQN agent. Even though the Dueling network is more stable and out performs DQN +early in training, by the end of training the two networks end up at the same point. + +This could very well be due to the simplicity of the Pong environment. + + - Orange: DQN + - Red: Dueling DQN + +.. image:: _images/rl_benchmark/pong_dueling_dqn_comparison.jpg + :width: 800 + :alt: Dueling DQN Comparison Result + +Example:: + + from pl_bolts.models.rl import DuelingDQN + dueling_dqn = DuelingDQN("PongNoFrameskip-v4") + trainer = Trainer() + trainer.fit(dueling_dqn) + +.. autoclass:: pl_bolts.models.rl.dueling_dqn_model.DuelingDQN + :noindex: + +-------------- + +Noisy DQN +^^^^^^^^^ + +Noisy DQN model introduced in `Noisy Networks for Exploration `_ +Paper authors: Meire Fortunato, Mohammad Gheshlaghi Azar, Bilal Piot, Jacob Menick, Ian Osband, Alex Graves, +Vlad Mnih, Remi Munos, Demis Hassabis, Olivier Pietquin, Charles Blundell, Shane Legg + +Original implementation by: `Donal Byrne `_ + +Up until now the DQN agent uses a seperate exploration policy, generally epsilon-greedy where start and end values +are set for its exploration. [Noisy Networks For Exploration](https://arxiv.org/abs/1706.10295) introduces +a new exploration strategy by adding noise parameters to the weightsof the fully connect layers which get updated +during backpropagation of the network. The noise parameters drive +the exploration of the network instead of simply taking random actions more frequently at the start of training and +less frequently towards the end.The of authors of +propose two ways of doing this. + +During the optimization step a new set of noisy parameters are sampled. During training the agent acts according to +the fixed set of parameters. At the next optimization step, the parameters are updated with a new sample. This ensures +the agent always acts based on the parameters that are drawn from the current noise +distribution. + +The authors propose two methods of injecting noise to the network. + +1) Independent Gaussian Noise: This injects noise per weight. For each weight a random value is taken from + the distribution. Noise parameters are stored inside the layer and are updated during backpropagation. + The output of the layer is calculated as normal. +2) Factorized Gaussian Noise: This injects nosier per input/ouput. In order to minimize the number of random values + this method stores two random vectors, one with the size of the input and the other with the size of the output. + Using these two vectors, a random matrix is generated for the layer by calculating the outer products of the vector + + +Noisy DQN Benefits +~~~~~~~~~~~~~~~~~~ + +- Improved exploration function. Instead of just performing completely random actions, we add decreasing amount of noise + and uncertainty to our policy allowing to explore while still utilising its policy +- The fact that this method is automatically tuned means that we do not have to tune hyper parameters for + epsilon-greedy! + +.. note:: + for now I have just implemented the Independant Gaussian as it has been reported there isn't much difference + in results for these benchmark environments. + +In order to update the basic DQN to a Noisy DQN we need to do the following + +Noisy DQN Results +~~~~~~~~~~~~~~~~~ + +The results below improved stability and faster performance growth. + +**Noisy DQN baseline: Pong** + + +Similar to the other improvements, the average score of the agent reaches positive numbers around the 250k mark and +steadily increases till convergence. + +.. image:: _images/rl_benchmark/pong_noisy_dqn_results.jpg + :width: 800 + :alt: Noisy DQN Result + +**DQN vs Dueling DQN: Pong** + +In comparison to the base DQN, the Noisy DQN is more stable and is able to converge on an optimal policy much faster +than the original. It seems that the replacement of the epsilon-greedy strategy with network noise provides a better +form of exploration. + +- Orange: DQN +- Red: Noisy DQN + +.. image:: _images/rl_benchmark/pong_noisy_dqn_comparison.jpg + :width: 800 + :alt: Noisy DQN Comparison Result + +Example:: + + from pl_bolts.models.rl import NoisyDQN + noisy_dqn = NoisyDQN("PongNoFrameskip-v4") + trainer = Trainer() + trainer.fit(noisy_dqn) + +.. autoclass:: pl_bolts.models.rl.noisy_dqn_model.NoisyDQN + :noindex: + +-------------- + +N-Step DQN +^^^^^^^^^^ + +N-Step DQN model introduced in `Learning to Predict by the Methods of Temporal Differences `_ +Paper authors: Richard S. Sutton + +Original implementation by: `Donal Byrne `_ + +N Step DQN was introduced in `Learning to Predict by the Methods of Temporal Differences +`_. +This method improves upon the original DQN by updating our Q values with the expected reward from multiple steps in the +future as opposed to the expected reward from the immediate next state. When getting the Q values for a state action +pair using a single step which looks like this + +.. math:: + + Q(s_t,a_t)=r_t+{\gamma}\max_aQ(s_{t+1},a_{t+1}) + +but because the Q function is recursive we can continue to roll this out into multiple steps, looking at the expected + return for each step into the future. + +.. math:: + + Q(s_t,a_t)=r_t+{\gamma}r_{t+1}+{\gamma}^2\max_{a'}Q(s_{t+2},a') + +The above example shows a 2-Step look ahead, but this could be rolled out to the end of the episode, which is just +Monte Carlo learning. Although we could just do a monte carlo update and look forward to the end of the episode, it +wouldn't be a good idea. Every time we take another step into the future, we are basing our approximation off our +current policy. For a large portion of training, our policy is going to be less than optimal. For example, at the start +of training, our policy will be in a state of high exploration, and will be little better than random. + +.. note:: + For each rollout step you must scale the discount factor accordingly by the number of steps. As you can see from the + equation above, the second gamma value is to the power of 2. If we rolled this out one step further, we would use + gamma to the power of 3 and so. + +So if we are aproximating future rewards off a bad policy, chances are those approximations are going to be pretty +bad and every time we unroll our update equation, the worse it will get. The fact that we are using an off policy +method like DQN with a large replay buffer will make this even worse, as there is a high chance that we will be +training on experiences using an old policy that was worse than our current policy. + +So we need to strike a balance between looking far enough ahead to improve the convergence of our agent, but not so far + that are updates become unstable. In general, small values of 2-4 work best. + +N-Step Benefits +~~~~~~~~~~~~~~~ + +- Multi-Step learning is capable of learning faster than typical 1 step learning methods. +- Note that this method introduces a new hyperparameter n. Although n=4 is generally a good starting point and provides + good results across the board. + +N-Step Results +~~~~~~~~~~~~~~ + +As expected, the N-Step DQN converges much faster than the standard DQN, however it also adds more instability to the +loss of the agent. This can be seen in the following experiments. + + +**N-Step DQN: Pong** + +The N-Step DQN shows the greatest increase in performance with respect to the other DQN variations. +After less than 150k steps the agent begins to consistently win games and achieves the top score after ~170K steps. +This is reflected in the sharp peak of the total episode steps and of course, the total episode rewards. + +.. image:: _images/rl_benchmark/pong_nstep_dqn_1.jpg + :width: 800 + :alt: N-Step DQN Result + +**DQN vs N-Step DQN: Pong** + +This improvement is shown in stark contrast to the base DQN, which only begins to win games after 250k steps and +requires over twice as many steps (450k) as the N-Step agent to achieve the high score of 21. One important thing to +notice is the large increase in the loss of the N-Step agent. This is expected as the agent is building +its expected reward off approximations of the future states. The large the size of N, the greater the instability. +Previous literature, listed below, shows the best results for the Pong environment with an N step between 3-5. +For these experiments I opted with an N step of 4. + + +.. image:: _images/rl_benchmark/pong_nstep_dqn_2.jpg + :width: 800 + :alt: N-Step DQN Comparison Results + +Example:: + + from pl_bolts.models.rl import DQN + n_step_dqn = DQN("PongNoFrameskip-v4", n_steps=4) + trainer = Trainer() + trainer.fit(n_step_dqn) + +-------------- + +Prioritized Experience Replay DQN +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Double DQN model introduced in `Prioritized Experience Replay `_ +Paper authors: Tom Schaul, John Quan, Ioannis Antonoglou, David Silver + +Original implementation by: `Donal Byrne `_ + +The standard DQN uses a buffer to break up the correlation between experiences and uniform random samples for each +batch. Instead of just randomly sampling from the buffer prioritized experience replay (PER) prioritizes these samples +based on training loss. This concept was introduced in the paper +`Prioritized Experience Replay `__ + +Essentially we want to train more on the samples that sunrise the agent. + +The priority of each sample is defined below where + + +.. math:: + + P(i) = P^\alpha_i / \sum_k P_k^\alpha + + +where pi is the priority of the ith sample in the buffer and +𝛼 is the number that shows how much emphasis we give to the priority. If 𝛼 = 0 , our +sampling will become uniform as in the classic DQN method. Larger values for 𝛼 put +more stress on samples with higher priority + +Its important that new samples are set to the highest priority so that they are sampled soon. This however introduces +bias to new samples in our dataset. In order to compensate for this bias, the value of the weight is defined as + +.. math:: + + w_i=(N . P(i))^{-\beta} + +Where beta is a hyper parameter between 0-1. When beta is 1 the bias is fully compensated. However authors noted that +in practice it is better to start beta with a small value near 0 and slowly increase it to 1. + +PER Benefits +~~~~~~~~~~~~ + +- The benefits of this technique are that the agent sees more samples that it struggled with and gets more + chances to improve upon it. + +**Memory Buffer** + + +First step is to replace the standard experience replay buffer with the prioritized experience replay buffer. This +is pretty large (100+ lines) so I wont go through it here. There are two buffers implemented. The first is a naive +list based buffer found in memory.PERBuffer and the second is more efficient buffer using a Sum Tree datastructure. + +The list based version is simpler, but has a sample complexity of O(N). The Sum Tree in comparison has a complexity +of O(1) for sampling and O(logN) for updating priorities. + +**Update loss function** + +The next thing we do is to use the sample weights that we get from PER. Add the following code to the end of the +loss function. This applies the weights of our sample to the batch loss. Then we return the mean loss and weighted loss +for each datum, with the addition of a small epsilon value. + + +PER Results +~~~~~~~~~~~ + +The results below show improved stability and faster performance growth. + +**PER DQN: Pong** + +Similar to the other improvements, we see that PER improves the stability of the agents training and shows to converged +on an optimal policy faster. + +.. image:: _images/rl_benchmark/pong_per_dqn_baseline_v1_results.jpg + :width: 800 + :alt: PER DQN Results + +**DQN vs PER DQN: Pong** + +In comparison to the base DQN, the PER DQN does show improved stability and performance. As expected, the loss + of the PER DQN is siginificantly lower. This is the main objective of PER by focusing on experiences with high loss. + +It is important to note that loss is not the only metric we should be looking at. Although the agent may have very + low loss during training, it may still perform poorly due to lack of exploration. + +.. image:: _images/rl_benchmark/pong_per_dqn_baseline_v1_results_comp.jpg + :width: 800 + :alt: PER DQN Results + +- Orange: DQN +- Pink: PER DQN + +Example:: + + from pl_bolts.models.rl import PERDQN + per_dqn = PERDQN("PongNoFrameskip-v4") + trainer = Trainer() + trainer.fit(per_dqn) + +.. autoclass:: pl_bolts.models.rl.per_dqn_model.PERDQN + :noindex: + + +-------------- + +Policy Gradient Models +---------------------- +The following models are based on Policy Gradients. Unlike the Q learning models shown before, Policy based models +do not try and learn the specifc values of state or state action pairs. Instead it cuts out the middle man and +directly learns the policy distribution. In Policy Gradient models we update our network parameters in the direction +suggested by our policy gradient in order to find a policy that produces the highest results. + +Policy Gradient Key Points: + - Outputs a distribution of actions instead of discrete Q values + - Optimizes the policy directly, instead of indirectly through the optimization of Q values + - The policy distribution of actions allows the model to handle more complex action spaces, such as continuos actions + - The policy distribution introduces stochasticity, providing natural exploration to the model + - The policy distribution provides a more stable update as a change in weights will only change the total distribution + slightly, as opposed to changing weights based on the Q value of state S will change all Q values with similar states. + - Policy gradients tend to converge faste, however they are not as sample efficient and generally require more + interactions with the environment. + + +-------------- + +REINFORCE +^^^^^^^^^ + +REINFORCE model introduced in `Policy Gradient Methods For Reinforcement Learning With Function Approximation `_ +Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour + +Original implementation by: `Donal Byrne `_ + +REINFORCE is one of the simplest forms of the Policy Gradient method of RL. This method uses a Monte Carlo rollout, +where its steps through entire episodes of the environment to build up trajectories computing the total rewards. The +algorithm is as follows: + +1. Initialize our network. +2. Play N full episodes saving the transitions through the environment. +3. For every step `t` in each episode `k` we calculate the discounted reward of the subsequent steps. + +.. math:: + + Q_{k,t} = \sum_{i=0}\gamma^i r_i + +4. Calculate the loss for all transitions. + +.. math:: + + L = - \sum_{k,t} Q_{k,t} \log(\pi(S_{k,t}, A_{k,t})) + +5. Perform SGD on the loss and repeat. + + +What this loss function is saying is simply that we want to take the log probability of action A at state S given +our policy (network output). This is then scaled by the discounted reward that we calculated in the previous step. +We then take the negative of our sum. This is because the loss is minimized during SGD, but we want to +maximize our policy. + +.. note:: + the current implementation does not actually wait for the batch episodes the complete every time as we pass in a + fixed batch size. For the time being we simply use a large batch size to accomodate this. This approach still works + well for simple tasks as it still manages to get an accurate Q value by using a large batch size, but it is not + as accurate or completely correct. This will be updated in a later version. + + +REINFORCE Benefits +~~~~~~~~~~~~~~~~~~~~~~~~ + +- Simple and straightforward + +- Computationally more efficient for simple tasks such as Cartpole than the Value Based methods. + +REINFORCE Results +~~~~~~~~~~~~~~~~~~~~~ + +Hyperparameters: + +- Batch Size: 800 +- Learning Rate: 0.01 +- Episodes Per Batch: 4 +- Gamma: 0.99 + +TODO: Add results graph + +Example:: + + from pl_bolts.models.rl import Reinforce + reinforce = Reinforce("CartPole-v0") + trainer = Trainer() + trainer.fit(reinforce) + +.. autoclass:: pl_bolts.models.rl.reinforce_model.Reinforce + :noindex: + +-------------- + +Vanilla Policy Gradient +^^^^^^^^^^^^^^^^^^^^^^^ + +Vanilla Policy Gradient model introduced in `Policy Gradient Methods For Reinforcement Learning With Function Approximation `_ +Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour + +Original implementation by: `Donal Byrne `_ + +Vanilla Policy Gradient (VPG) expands upon the REINFORCE algorithm and improves some of its major issues. The major +issue with REINFORCE is that it has high variance. This can be improved by subtracting a baseline value from the +Q values. For this implementation we use the average reward as our baseline. + +Although Policy Gradients are able to explore naturally due to the stochastic nature of the model, the agent can still +frequently be stuck in a local optima. In order to improve this, VPG adds an entropy term to improve exploration. + +.. math:: + + H(\pi) = - \sum \pi (a | s) \log \pi (a | s) + +To further control the amount of additional entropy in our model we scale the entropy term by a small beta value. The +scaled entropy is then subtracted from the policy loss. + +VPG Benefits +~~~~~~~~~~~~~~~ + +- Addition of the baseline reduces variance in the model + +- Improved exploration due to entropy bonus + +VPG Results +~~~~~~~~~~~~~~~~ + +Hyperparameters: + +- Batch Size: 8 +- Learning Rate: 0.001 +- N Steps: 10 +- N environments: 4 +- Entropy Beta: 0.01 +- Gamma: 0.99 + +Example:: + + from pl_bolts.models.rl import VanillaPolicyGradient + vpg = VanillaPolicyGradient("CartPole-v0") + trainer = Trainer() + trainer.fit(vpg) + +.. autoclass:: pl_bolts.models.rl.vanilla_policy_gradient_model.VanillaPolicyGradient + :noindex: diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 1dd2e7c9aa..d3523ff41a 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -7,6 +7,11 @@ CIFAR10DataModule, TinyCIFAR10DataModule, ) + from pl_bolts.datamodules.experience_source import ( + ExperienceSourceDataset, + ExperienceSource, + DiscountedExperienceSource, + ) from pl_bolts.datamodules.fashion_mnist_datamodule import FashionMNISTDataModule from pl_bolts.datamodules.imagenet_datamodule import ImagenetDataModule from pl_bolts.datamodules.mnist_datamodule import MNISTDataModule diff --git a/pl_bolts/datamodules/cifar10_dataset.py b/pl_bolts/datamodules/cifar10_dataset.py index 63d2f1f744..5ddb44ab36 100644 --- a/pl_bolts/datamodules/cifar10_dataset.py +++ b/pl_bolts/datamodules/cifar10_dataset.py @@ -87,6 +87,9 @@ def __init__( self.train = train # training set or test set self.transform = transform + if not _PIL_AVAILABLE: + raise ImportError('You want to use PIL.Image for loading but it is not installed yet.') + os.makedirs(self.cached_folder_path, exist_ok=True) self.prepare_data(download) diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py new file mode 100644 index 0000000000..6a4671234f --- /dev/null +++ b/pl_bolts/datamodules/experience_source.py @@ -0,0 +1,278 @@ +""" +Datamodules for RL models that rely on experiences generated during training +Based on implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/experience.py +""" +from abc import ABC +from collections import deque, namedtuple +from typing import Iterable, Callable, Tuple, List + +import torch +from gym import Env +from torch.utils.data import IterableDataset + +# Datasets + +Experience = namedtuple( + "Experience", field_names=["state", "action", "reward", "done", "new_state"] +) + + +class ExperienceSourceDataset(IterableDataset): + """ + Basic experience source dataset. Takes a generate_batch function that returns an iterator. + The logic for the experience source and how the batch is generated is defined the Lightning model itself + """ + + def __init__(self, generate_batch: Callable): + self.generate_batch = generate_batch + + def __iter__(self) -> Iterable: + iterator = self.generate_batch() + return iterator + + +# Experience Sources +class BaseExperienceSource(ABC): + """ + Simplest form of the experience source + Args: + env: Environment that is being used + agent: Agent being used to make decisions + """ + + def __init__(self, env, agent) -> None: + self.env = env + self.agent = agent + + def runner(self) -> Experience: + """Iterable method that yields steps from the experience source""" + raise NotImplementedError("ExperienceSource has no stepper method implemented") + + +class ExperienceSource(BaseExperienceSource): + """ + Experience source class handling single and multiple environment steps + Args: + env: Environment that is being used + agent: Agent being used to make decisions + n_steps: Number of steps to return from each environment at once + """ + + def __init__(self, env, agent, n_steps: int = 1) -> None: + super().__init__(env, agent) + + self.pool = env if isinstance(env, (list, tuple)) else [env] + self.exp_history_queue = deque() + + self.n_steps = n_steps + self.total_steps = [] + self.states = [] + self.histories = [] + self.cur_rewards = [] + self.cur_steps = [] + self.iter_idx = 0 + + self._total_rewards = [] + + self.init_environments() + + def runner(self, device: torch.device) -> Tuple[Experience]: + """Experience Source iterator yielding Tuple of experiences for n_steps. These come from the pool + of environments provided by the user. + Args: + device: current device to be used for executing experience steps + Returns: + Tuple of Experiences + """ + while True: + # get actions for all envs + actions = self.env_actions(device) + + # step through each env + for env_idx, (env, action) in enumerate(zip(self.pool, actions)): + + exp = self.env_step(env_idx, env, action) + history = self.histories[env_idx] + history.append(exp) + self.states[env_idx] = exp.new_state + + self.update_history_queue(env_idx, exp, history) + + # Yield all accumulated history tuples to model + while self.exp_history_queue: + yield self.exp_history_queue.popleft() + + self.iter_idx += 1 + + def update_history_queue(self, env_idx, exp, history) -> None: + """ + Updates the experience history queue with the lastest experiences. In the event of an experience step is in + the done state, the history will be incrementally appended to the queue, removing the tail of the history + each time. + Args: + env_idx: index of the environment + exp: the current experience + history: history of experience steps for this environment + """ + # If there is a full history of step, append history to queue + if len(history) == self.n_steps: + self.exp_history_queue.append(tuple(history)) + + if exp.done: + if 0 < len(history) < self.n_steps: + self.exp_history_queue.append(tuple(history)) + + # generate tail of history, incrementally append history to queue + while len(history) > 2: + history.popleft() + self.exp_history_queue.append(tuple(history)) + + # when there are only 2 experiences left in the history, + # append to the queue then update the env stats and reset the environment + if len(history) > 1: + self.update_env_stats(env_idx) + + history.popleft() + self.exp_history_queue.append(tuple(history)) + + # Clear that last tail in the history once all others have been added to the queue + history.clear() + + def init_environments(self) -> None: + """ + For each environment in the pool setups lists for tracking history of size n, state, current reward and + current step + """ + for env in self.pool: + self.states.append(env.reset()) + self.histories.append(deque(maxlen=self.n_steps)) + self.cur_rewards.append(0.0) + self.cur_steps.append(0) + + def env_actions(self, device) -> List[List[int]]: + """ + For each environment in the pool, get the correct action + Returns: + List of actions for each env, with size (num_envs, action_size) + """ + actions = [] + states_actions = self.agent(self.states, device) + + assert len(self.states) == len(states_actions) + + for idx, action in enumerate(states_actions): + actions.append(action if isinstance(action, list) else [action]) + + return actions + + def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience: + """ + Carries out a step through the given environment using the given action + Args: + env_idx: index of the current environment + env: env at index env_idx + action: action for this environment step + Returns: + Experience tuple + """ + next_state, r, is_done, _ = env.step(action[0]) + + self.cur_rewards[env_idx] += r + self.cur_steps[env_idx] += 1 + + exp = Experience(state=self.states[env_idx], action=action[0], reward=r, done=is_done, new_state=next_state) + + return exp + + def update_env_stats(self, env_idx: int) -> None: + """ + To be called at the end of the history tail generation during the termination state. Updates the stats + tracked for all environments + Args: + env_idx: index of the environment used to update stats + """ + self._total_rewards.append(self.cur_rewards[env_idx]) + self.total_steps.append(self.cur_steps[env_idx]) + self.cur_rewards[env_idx] = 0 + self.cur_steps[env_idx] = 0 + self.states[env_idx] = self.pool[env_idx].reset() + + def pop_total_rewards(self) -> List[float]: + """ + Returns the list of the current total rewards collected + Returns: + list of total rewards for all completed episodes for each environment since last pop + """ + rewards = self._total_rewards + + if rewards: + self._total_rewards = [] + self.total_steps = [] + + return rewards + + def pop_rewards_steps(self): + """ + Returns the list of the current total rewards and steps collected + Returns: + list of total rewards and steps for all completed episodes for each environment since last pop + """ + res = list(zip(self._total_rewards, self.total_steps)) + if res: + self._total_rewards, self.total_steps = [], [] + return res + + +class DiscountedExperienceSource(ExperienceSource): + """Outputs experiences with a discounted reward over N steps""" + + def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): + super().__init__(env, agent, (n_steps + 1)) + self.gamma = gamma + self.steps = n_steps + + def runner(self, device: torch.device) -> Experience: + """ + Iterates through experience tuple and calculate discounted experience + Args: + device: current device to be used for executing experience steps + Yields: + Discounted Experience + """ + for experiences in super().runner(device): + last_exp_state, tail_experiences = self.split_head_tail_exp(experiences) + + total_reward = self.discount_rewards(tail_experiences) + + yield Experience(state=experiences[0].state, action=experiences[0].action, + reward=total_reward, done=experiences[0].done, new_state=last_exp_state) + + def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]: + """ + Takes in a tuple of experiences and returns the last state and tail experiences based on + if the last state is the end of an episode + Args: + experiences: Tuple of N Experience + Returns: + last state (Array or None) and remaining Experience + """ + if experiences[-1].done and len(experiences) <= self.steps: + last_exp_state = experiences[-1].new_state + tail_experiences = experiences + else: + last_exp_state = experiences[-1].state + tail_experiences = experiences[:-1] + return last_exp_state, tail_experiences + + def discount_rewards(self, experiences: Tuple[Experience]) -> float: + """ + Calculates the discounted reward over N experiences + Args: + experiences: Tuple of Experience + Returns: + total discounted reward + """ + total_reward = 0.0 + for exp in reversed(experiences): + total_reward = (self.gamma * total_reward) + exp.reward + return total_reward diff --git a/pl_bolts/losses/rl.py b/pl_bolts/losses/rl.py new file mode 100644 index 0000000000..a4a974f7c6 --- /dev/null +++ b/pl_bolts/losses/rl.py @@ -0,0 +1,118 @@ +""" +Loss functions for the RL models +""" + +from typing import Tuple, List + +import numpy as np +import torch +from torch import nn + + +def dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], net: nn.Module, + target_net: nn.Module, gamma: float = 0.99) -> torch.Tensor: + """ + Calculates the mse loss using a mini batch from the replay buffer + Args: + batch: current mini batch of replay data + net: main training network + target_net: target network of the main training network + gamma: discount factor + Returns: + loss + """ + states, actions, rewards, dones, next_states = batch + + actions = actions.long().squeeze(-1) + + state_action_values = ( + net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1) + ) + + with torch.no_grad(): + next_state_values = target_net(next_states).max(1)[0] + next_state_values[dones] = 0.0 + next_state_values = next_state_values.detach() + + expected_state_action_values = next_state_values * gamma + rewards + + return nn.MSELoss()(state_action_values, expected_state_action_values) + + +def double_dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], net: nn.Module, + target_net: nn.Module, gamma: float = 0.99) -> torch.Tensor: + """ + Calculates the mse loss using a mini batch from the replay buffer. This uses an improvement to the original + DQN loss by using the double dqn. This is shown by using the actions of the train network to pick the + value from the target network. This code is heavily commented in order to explain the process clearly + Args: + batch: current mini batch of replay data + net: main training network + target_net: target network of the main training network + gamma: discount factor + Returns: + loss + """ + states, actions, rewards, dones, next_states = batch # batch of experiences, batch_size = 16 + + actions = actions.long().squeeze(-1) + + state_action_values = ( + net(states).gather(1, actions.unsqueeze(-1)).squeeze(-1) + ) + + # dont want to mess with gradients when using the target network + with torch.no_grad(): + next_outputs = net(next_states) # [16, 2], [batch, action_space] + + next_state_acts = next_outputs.max(1)[1].unsqueeze( + -1 + ) # take action at the index with the highest value + next_tgt_out = target_net(next_states) + + # Take the value of the action chosen by the train network + next_state_values = next_tgt_out.gather(1, next_state_acts).squeeze(-1) + next_state_values[dones] = 0.0 # any steps flagged as done get a 0 value + next_state_values = ( + next_state_values.detach() + ) # remove values from the graph, no grads needed + + # calc expected discounted return of next_state_values + expected_state_action_values = next_state_values * gamma + rewards + + # Standard MSE loss between the state action values of the current state and the + # expected state action values of the next state + return nn.MSELoss()(state_action_values, expected_state_action_values) + + +def per_dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], batch_weights: List, net: nn.Module, + target_net: nn.Module, gamma: float = 0.99) -> Tuple[torch.Tensor, np.ndarray]: + """ + Calculates the mse loss with the priority weights of the batch from the PER buffer + Args: + batch: current mini batch of replay data + batch_weights: how each of these samples are weighted in terms of priority + net: main training network + target_net: target network of the main training network + gamma: discount factor + Returns: + loss and batch_weights + """ + states, actions, rewards, dones, next_states = batch + + actions = actions.long() + + batch_weights = torch.tensor(batch_weights) + + actions_v = actions.unsqueeze(-1) + outputs = net(states) + state_action_vals = outputs.gather(1, actions_v) + state_action_vals = state_action_vals.squeeze(-1) + + with torch.no_grad(): + next_s_vals = target_net(next_states).max(1)[0] + next_s_vals[dones] = 0.0 + exp_sa_vals = next_s_vals.detach() * gamma + rewards + loss = (state_action_vals - exp_sa_vals) ** 2 + losses_v = batch_weights * loss + return losses_v.mean(), (losses_v + 1e-5).data.cpu().numpy() diff --git a/pl_bolts/models/rl/__init__.py b/pl_bolts/models/rl/__init__.py new file mode 100644 index 0000000000..cec3f871c8 --- /dev/null +++ b/pl_bolts/models/rl/__init__.py @@ -0,0 +1,10 @@ +try: + from pl_bolts.models.rl.double_dqn_model import DoubleDQN + from pl_bolts.models.rl.dqn_model import DQN + from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN + from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN + from pl_bolts.models.rl.per_dqn_model import PERDQN + from pl_bolts.models.rl.reinforce_model import Reinforce + from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient +except ModuleNotFoundError: + pass diff --git a/pl_bolts/models/rl/common/__init__.py b/pl_bolts/models/rl/common/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py new file mode 100644 index 0000000000..92c5fbb8fa --- /dev/null +++ b/pl_bolts/models/rl/common/agents.py @@ -0,0 +1,131 @@ +""" +Agent module containing classes for Agent logic +Based on the implementations found here: https://github.com/Shmuma/ptan/blob/master/ptan/agent.py +""" +from abc import ABC +from typing import List + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + + +class Agent(ABC): + """Basic agent that always returns 0""" + + def __init__(self, net: nn.Module): + self.net = net + + def __call__(self, state: torch.Tensor, device: str, *args, **kwargs) -> List[int]: + """ + Using the given network, decide what action to carry + Args: + state: current state of the environment + device: device used for current batch + Returns: + action + """ + return [0] + + +class ValueAgent(Agent): + """Value based agent that returns an action based on the Q values from the network""" + + def __init__( + self, + net: nn.Module, + action_space: int, + eps_start: float = 1.0, + eps_end: float = 0.2, + eps_frames: float = 1000, + ): + super().__init__(net) + self.action_space = action_space + self.eps_start = eps_start + self.epsilon = eps_start + self.eps_end = eps_end + self.eps_frames = eps_frames + + @torch.no_grad() + def __call__(self, state: torch.Tensor, device: str) -> List[int]: + """ + Takes in the current state and returns the action based on the agents policy + Args: + state: current state of the environment + device: the device used for the current batch + Returns: + action defined by policy + """ + if not isinstance(state, list): + state = [state] + + if np.random.random() < self.epsilon: + action = self.get_random_action(state) + else: + action = self.get_action(state, device) + + return action + + def get_random_action(self, state: torch.Tensor) -> int: + """returns a random action""" + actions = [] + + for i in range(len(state)): + action = np.random.randint(0, self.action_space) + actions.append(action) + + return actions + + def get_action(self, state: torch.Tensor, device: torch.device): + """ + Returns the best action based on the Q values of the network + Args: + state: current state of the environment + device: the device used for the current batch + Returns: + action defined by Q values + """ + if not isinstance(state, torch.Tensor): + state = torch.tensor(state, device=device) + + q_values = self.net(state) + _, actions = torch.max(q_values, dim=1) + return actions.detach().cpu().numpy() + + def update_epsilon(self, step: int) -> None: + """ + Updates the epsilon value based on the current step + Args: + step: current global step + """ + self.epsilon = max(self.eps_end, self.eps_start - (step + 1) / self.eps_frames) + + +class PolicyAgent(Agent): + """Policy based agent that returns an action based on the networks policy""" + + @torch.no_grad() + def __call__(self, states: torch.Tensor, device: str) -> List[int]: + """ + Takes in the current state and returns the action based on the agents policy + Args: + states: current state of the environment + device: the device used for the current batch + Returns: + action defined by policy + """ + if not isinstance(states, list): + states = [states] + + if not isinstance(states, torch.Tensor): + states = torch.tensor(states, device=device) + + # get the logits and pass through softmax for probability distribution + probabilities = F.softmax(self.net(states)).squeeze(dim=-1) + prob_np = probabilities.data.cpu().numpy() + + # take the numpy values and randomly select action based on prob distribution + actions = [np.random.choice(len(prob), p=prob) for prob in prob_np] + + return actions diff --git a/pl_bolts/models/rl/common/cli.py b/pl_bolts/models/rl/common/cli.py new file mode 100644 index 0000000000..a663c8acd8 --- /dev/null +++ b/pl_bolts/models/rl/common/cli.py @@ -0,0 +1,34 @@ +"""Contains generic arguments used for all models""" + +import argparse + + +def add_base_args(parent) -> argparse.ArgumentParser: + """ + Adds arguments for DQN model + + Note: these params are fine tuned for Pong env + + Args: + parent + """ + arg_parser = argparse.ArgumentParser(parents=[parent]) + + arg_parser.add_argument("--algo", type=str, default="dqn", help="algorithm to use for training") + arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") + arg_parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") + + arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") + arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + + arg_parser.add_argument("--episode_length", type=int, default=500, help="max length of an episode") + arg_parser.add_argument("--max_episode_reward", type=int, default=18, help="max episode reward in the environment") + arg_parser.add_argument("--n_steps", type=int, default=4, help="how many steps to unroll for each update",) + arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run") + arg_parser.add_argument("--epoch_len", type=int, default=1000, help="how many batches per epoch") + arg_parser.add_argument("--num_envs", type=int, default=1, help="number of environments to run at once") + arg_parser.add_argument("--avg_reward_len", type=int, default=100, + help="how many episodes to include in avg reward") + + arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run") + return arg_parser diff --git a/pl_bolts/models/rl/common/gym_wrappers.py b/pl_bolts/models/rl/common/gym_wrappers.py new file mode 100644 index 0000000000..8f492a27c1 --- /dev/null +++ b/pl_bolts/models/rl/common/gym_wrappers.py @@ -0,0 +1,207 @@ +""" +Set of wrapper functions for gym environments taken from +https://github.com/Shmuma/ptan/blob/master/ptan/common/wrappers.py +""" +import collections +from warnings import warn + +import gym +import gym.spaces +import numpy as np +import torch +try: + import cv2 +except ModuleNotFoundError: + warn('You want to use `openCV` which is not installed yet,' # pragma: no-cover + ' install it with `pip install opencv-python`.') + _OPENCV_AVAILABLE = False +else: + _OPENCV_AVAILABLE = True + + +class ToTensor(gym.Wrapper): + """For environments where the user need to press FIRE for the game to start.""" + + def __init__(self, env=None): + super(ToTensor, self).__init__(env) + + def step(self, action): + """Take 1 step and cast to tensor""" + state, reward, done, info = self.env.step(action) + return torch.tensor(state), torch.tensor(reward), done, info + + def reset(self): + """reset the env and cast to tensor""" + return torch.tensor(self.env.reset()) + + +class FireResetEnv(gym.Wrapper): + """For environments where the user need to press FIRE for the game to start.""" + + def __init__(self, env=None): + super(FireResetEnv, self).__init__(env) + assert env.unwrapped.get_action_meanings()[1] == "FIRE" + assert len(env.unwrapped.get_action_meanings()) >= 3 + + def step(self, action): + """Take 1 step""" + return self.env.step(action) + + def reset(self): + """reset the env""" + self.env.reset() + obs, _, done, _ = self.env.step(1) + if done: + self.env.reset() + obs, _, done, _ = self.env.step(2) + if done: + self.env.reset() + return obs + + +class MaxAndSkipEnv(gym.Wrapper): + """Return only every `skip`-th frame""" + + def __init__(self, env=None, skip=4): + super(MaxAndSkipEnv, self).__init__(env) + # most recent raw observations (for max pooling across time steps) + self._obs_buffer = collections.deque(maxlen=2) + self._skip = skip + + def step(self, action): + """take 1 step""" + total_reward = 0.0 + done = None + for _ in range(self._skip): + obs, reward, done, info = self.env.step(action) + self._obs_buffer.append(obs) + total_reward += reward + if done: + break + max_frame = np.max(np.stack(self._obs_buffer), axis=0) + return max_frame, total_reward, done, info + + def reset(self): + """Clear past frame buffer and init. to first obs. from inner env.""" + self._obs_buffer.clear() + obs = self.env.reset() + self._obs_buffer.append(obs) + return obs + + +class ProcessFrame84(gym.ObservationWrapper): + """preprocessing images from env""" + + def __init__(self, env=None): + + if not _OPENCV_AVAILABLE: + raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.') + + super(ProcessFrame84, self).__init__(env) + self.observation_space = gym.spaces.Box( + low=0, high=255, shape=(84, 84, 1), dtype=np.uint8 + ) + + def observation(self, obs): + """preprocess the obs""" + return ProcessFrame84.process(obs) + + @staticmethod + def process(frame): + """image preprocessing, formats to 84x84""" + if frame.size == 210 * 160 * 3: + img = np.reshape(frame, [210, 160, 3]).astype(np.float32) + elif frame.size == 250 * 160 * 3: + img = np.reshape(frame, [250, 160, 3]).astype(np.float32) + else: + assert False, "Unknown resolution." + img = img[:, :, 0] * 0.299 + img[:, :, 1] * 0.587 + img[:, :, 2] * 0.114 + resized_screen = cv2.resize(img, (84, 110), interpolation=cv2.INTER_AREA) + x_t = resized_screen[18:102, :] + x_t = np.reshape(x_t, [84, 84, 1]) + return x_t.astype(np.uint8) + + +class ImageToPyTorch(gym.ObservationWrapper): + """converts image to pytorch format""" + + def __init__(self, env): + + if not _OPENCV_AVAILABLE: + raise ModuleNotFoundError('This class uses OpenCV which it is not installed yet.') + + super(ImageToPyTorch, self).__init__(env) + old_shape = self.observation_space.shape + new_shape = (old_shape[-1], old_shape[0], old_shape[1]) + self.observation_space = gym.spaces.Box( + low=0.0, high=1.0, shape=new_shape, dtype=np.float32 + ) + + @staticmethod + def observation(observation): + """convert observation""" + return np.moveaxis(observation, 2, 0) + + +class ScaledFloatFrame(gym.ObservationWrapper): + """scales the pixels""" + + @staticmethod + def observation(obs): + return np.array(obs).astype(np.float32) / 255.0 + + +class BufferWrapper(gym.ObservationWrapper): + """"Wrapper for image stacking""" + + def __init__(self, env, n_steps, dtype=np.float32): + super(BufferWrapper, self).__init__(env) + self.dtype = dtype + self.buffer = None + old_space = env.observation_space + self.observation_space = gym.spaces.Box( + old_space.low.repeat(n_steps, axis=0), + old_space.high.repeat(n_steps, axis=0), + dtype=dtype, + ) + + def reset(self): + """reset env""" + self.buffer = np.zeros_like(self.observation_space.low, dtype=self.dtype) + return self.observation(self.env.reset()) + + def observation(self, observation): + """convert observation""" + self.buffer[:-1] = self.buffer[1:] + self.buffer[-1] = observation + return self.buffer + + +class DataAugmentation(gym.ObservationWrapper): + """ + Carries out basic data augmentation on the env observations + - ToTensor + - GrayScale + - RandomCrop + """ + + def __init__(self, env=None): + super().__init__(env) + self.observation_space = gym.spaces.Box( + low=0, high=255, shape=(84, 84, 1), dtype=np.uint8 + ) + + def observation(self, obs): + """preprocess the obs""" + return ProcessFrame84.process(obs) + + +def make_environment(env_name): + """Convert environment with wrappers""" + env = gym.make(env_name) + env = MaxAndSkipEnv(env) + env = FireResetEnv(env) + env = ProcessFrame84(env) + env = ImageToPyTorch(env) + env = BufferWrapper(env, 4) + return ScaledFloatFrame(env) diff --git a/pl_bolts/models/rl/common/memory.py b/pl_bolts/models/rl/common/memory.py new file mode 100644 index 0000000000..0cd058ee43 --- /dev/null +++ b/pl_bolts/models/rl/common/memory.py @@ -0,0 +1,313 @@ +"""Series of memory buffers sued""" + +# Named tuple for storing experience steps gathered in training +import collections +from collections import deque, namedtuple +from typing import Tuple, List, Union + +import numpy as np + +Experience = namedtuple( + "Experience", field_names=["state", "action", "reward", "done", "new_state"] +) + + +class Buffer: + """ + Basic Buffer for storing a single experience at a time + Args: + capacity: size of the buffer + """ + + def __init__(self, capacity: int) -> None: + self.buffer = deque(maxlen=capacity) + + def __len__(self) -> None: + return len(self.buffer) + + def append(self, experience: Experience) -> None: + """ + Add experience to the buffer + Args: + experience: tuple (state, action, reward, done, new_state) + """ + self.buffer.append(experience) + + # pylint: disable=unused-argument + def sample(self, *args) -> Union[Tuple, List[Tuple]]: + """ + returns everything in the buffer so far it is then reset + Returns: + a batch of tuple np arrays of state, action, reward, done, next_state + """ + states, actions, rewards, dones, next_states = zip( + *[self.buffer[idx] for idx in range(self.__len__())] + ) + + self.buffer.clear() + + return ( + np.array(states), + np.array(actions), + np.array(rewards, dtype=np.float32), + np.array(dones, dtype=np.bool), + np.array(next_states), + ) + + +class ReplayBuffer(Buffer): + """ + Replay Buffer for storing past experiences allowing the agent to learn from them + """ + + def sample(self, batch_size: int) -> Tuple: + """ + Takes a sample of the buffer + Args: + batch_size: current batch_size + Returns: + a batch of tuple np arrays of state, action, reward, done, next_state + """ + + indices = np.random.choice(len(self.buffer), batch_size, replace=False) + states, actions, rewards, dones, next_states = zip( + *[self.buffer[idx] for idx in indices] + ) + + return ( + np.array(states), + np.array(actions), + np.array(rewards, dtype=np.float32), + np.array(dones, dtype=np.bool), + np.array(next_states), + ) + + +class MultiStepBuffer(ReplayBuffer): + """ + N Step Replay Buffer + + Args: + capacity: max number of experiences that will be stored in the buffer + n_steps: number of steps used for calculating discounted reward/experience + gamma: discount factor when calculating n_step discounted reward of the experience being stored in buffer + """ + + def __init__(self, capacity: int, n_steps: int = 1, gamma: float = 0.99) -> None: + super().__init__(capacity) + + self.n_steps = n_steps + self.gamma = gamma + self.history = deque(maxlen=self.n_steps) + self.exp_history_queue = deque() + + def append(self, exp: Experience) -> None: + """ + Add experience to the buffer + Args: + exp: tuple (state, action, reward, done, new_state) + """ + self.update_history_queue(exp) # add single step experience to history + while self.exp_history_queue: # go through all the n_steps that have been queued + experiences = self.exp_history_queue.popleft() # get the latest n_step experience from queue + + last_exp_state, tail_experiences = self.split_head_tail_exp(experiences) + + total_reward = self.discount_rewards(tail_experiences) + + n_step_exp = Experience(state=experiences[0].state, action=experiences[0].action, reward=total_reward, + done=experiences[0].done, new_state=last_exp_state) + + self.buffer.append(n_step_exp) # add n_step experience to buffer + + def update_history_queue(self, exp) -> None: + """ + Updates the experience history queue with the lastest experiences. In the event of an experience step is in + the done state, the history will be incrementally appended to the queue, removing the tail of the history + each time. + Args: + env_idx: index of the environment + exp: the current experience + history: history of experience steps for this environment + """ + self.history.append(exp) + + # If there is a full history of step, append history to queue + if len(self.history) == self.n_steps: + self.exp_history_queue.append(list(self.history)) + + if exp.done: + if 0 < len(self.history) < self.n_steps: + self.exp_history_queue.append(list(self.history)) + + # generate tail of history, incrementally append history to queue + while len(self.history) > 2: + self.history.popleft() + self.exp_history_queue.append(list(self.history)) + + # when there are only 2 experiences left in the history, + # append to the queue then update the env stats and reset the environment + if len(self.history) > 1: + self.history.popleft() + self.exp_history_queue.append(list(self.history)) + + # Clear that last tail in the history once all others have been added to the queue + self.history.clear() + + def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tuple[Experience]]: + """ + Takes in a tuple of experiences and returns the last state and tail experiences based on + if the last state is the end of an episode + Args: + experiences: Tuple of N Experience + Returns: + last state (Array or None) and remaining Experience + """ + last_exp_state = experiences[-1].new_state + tail_experiences = experiences + + if experiences[-1].done and len(experiences) <= self.n_steps: + tail_experiences = experiences + + return last_exp_state, tail_experiences + + def discount_rewards(self, experiences: Tuple[Experience]) -> float: + """ + Calculates the discounted reward over N experiences + Args: + experiences: Tuple of Experience + Returns: + total discounted reward + """ + total_reward = 0.0 + for exp in reversed(experiences): + total_reward = (self.gamma * total_reward) + exp.reward + return total_reward + + +class MeanBuffer: + """ + Stores a deque of items and calculates the mean + """ + + def __init__(self, capacity): + self.capacity = capacity + self.deque = collections.deque(maxlen=capacity) + self.sum = 0.0 + + def add(self, val: float) -> None: + """Add to the buffer""" + if len(self.deque) == self.capacity: + self.sum -= self.deque[0] + self.deque.append(val) + self.sum += val + + def mean(self) -> float: + """Retrieve the mean""" + if not self.deque: + return 0.0 + return self.sum / len(self.deque) + + +class PERBuffer(ReplayBuffer): + """ + simple list based Prioritized Experience Replay Buffer + Based on implementation found here: + https://github.com/Shmuma/ptan/blob/master/ptan/experience.py#L371 + """ + + def __init__(self, buffer_size, prob_alpha=0.6, beta_start=0.4, beta_frames=100000): + super().__init__(capacity=buffer_size) + self.beta_start = beta_start + self.beta = beta_start + self.beta_frames = beta_frames + self.prob_alpha = prob_alpha + self.capacity = buffer_size + self.pos = 0 + self.buffer = [] + self.priorities = np.zeros((buffer_size,), dtype=np.float32) + + def update_beta(self, step) -> float: + """ + Update the beta value which accounts for the bias in the PER + Args: + step: current global step + Returns: + beta value for this indexed experience + """ + beta_val = self.beta_start + step * (1.0 - self.beta_start) / self.beta_frames + self.beta = min(1.0, beta_val) + + return self.beta + + def append(self, exp) -> None: + """ + Adds experiences from exp_source to the PER buffer + Args: + exp: experience tuple being added to the buffer + """ + # what is the max priority for new sample + max_prio = self.priorities.max() if self.buffer else 1.0 + + if len(self.buffer) < self.capacity: + self.buffer.append(exp) + else: + self.buffer[self.pos] = exp + + # the priority for the latest sample is set to max priority so it will be resampled soon + self.priorities[self.pos] = max_prio + + # update position, loop back if it reaches the end + self.pos = (self.pos + 1) % self.capacity + + def sample(self, batch_size=32) -> Tuple: + """ + Takes a prioritized sample from the buffer + Args: + batch_size: size of sample + Returns: + sample of experiences chosen with ranked probability + """ + # get list of priority rankings + if len(self.buffer) == self.capacity: + prios = self.priorities + else: + prios = self.priorities[: self.pos] + + # probability to the power of alpha to weight how important that probability it, 0 = normal distirbution + probs = prios ** self.prob_alpha + probs /= probs.sum() + + # choise sample of indices based on the priority prob distribution + indices = np.random.choice(len(self.buffer), batch_size, p=probs) + # samples = [self.buffer[idx] for idx in indices] + states, actions, rewards, dones, next_states = zip( + *[self.buffer[idx] for idx in indices] + ) + + samples = ( + np.array(states), + np.array(actions), + np.array(rewards, dtype=np.float32), + np.array(dones, dtype=np.bool), + np.array(next_states), + ) + total = len(self.buffer) + + # weight of each sample datum to compensate for the bias added in with prioritising samples + weights = (total * probs[indices]) ** (-self.beta) + weights /= weights.max() + + # return the samples, the indices chosen and the weight of each datum in the sample + return samples, indices, np.array(weights, dtype=np.float32) + + def update_priorities(self, batch_indices: List, batch_priorities: List) -> None: + """ + Update the priorities from the last batch, this should be called after the loss for this batch has been + calculated. + Args: + batch_indices: index of each datum in the batch + batch_priorities: priority of each datum in the batch + """ + for idx, prio in zip(batch_indices, batch_priorities): + self.priorities[idx] = prio diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py new file mode 100644 index 0000000000..4776424d39 --- /dev/null +++ b/pl_bolts/models/rl/common/networks.py @@ -0,0 +1,317 @@ +"""Series of networks used +Based on implementations found here: +""" +import math +from typing import Tuple + +import numpy as np +import torch +from torch import Tensor +from torch import nn +from torch.nn import functional as F + + +class CNN(nn.Module): + """ + Simple MLP network + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + """ + + def __init__(self, input_shape, n_actions): + super(CNN, self).__init__() + + self.conv = nn.Sequential( + nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=4, stride=2), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1), + nn.ReLU(), + ) + + conv_out_size = self._get_conv_out(input_shape) + self.head = nn.Sequential( + nn.Linear(conv_out_size, 512), nn.ReLU(), nn.Linear(512, n_actions) + ) + + def _get_conv_out(self, shape) -> int: + """ + Calculates the output size of the last conv layer + Args: + shape: input dimensions + Returns: + size of the conv output + """ + conv_out = self.conv(torch.zeros(1, *shape)) + return int(np.prod(conv_out.size())) + + def forward(self, input_x) -> Tensor: + """ + Forward pass through network + Args: + x: input to network + Returns: + output of network + """ + conv_out = self.conv(input_x).view(input_x.size()[0], -1) + return self.head(conv_out) + + +class MLP(nn.Module): + """ + Simple MLP network + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ + + def __init__(self, input_shape: Tuple, n_actions: int, hidden_size: int = 128): + super(MLP, self).__init__() + self.net = nn.Sequential( + nn.Linear(input_shape[0], hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, n_actions), + ) + + def forward(self, input_x): + """ + Forward pass through network + Args: + x: input to network + Returns: + output of network + """ + return self.net(input_x.float()) + + +class DuelingMLP(nn.Module): + """ + MLP network with duel heads for val and advantage + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ + + def __init__(self, input_shape: Tuple, n_actions: int, hidden_size: int = 128): + super(DuelingMLP, self).__init__() + + self.net = nn.Sequential( + nn.Linear(input_shape[0], hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, hidden_size), + ) + + self.head_adv = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.ReLU(), + nn.Linear(hidden_size, n_actions), + ) + self.head_val = nn.Sequential( + nn.Linear(hidden_size, 256), nn.ReLU(), nn.Linear(256, 1) + ) + + def forward(self, input_x): + """ + Forward pass through network. Calculates the Q using the value and advantage + Args: + x: input to network + Returns: + Q value + """ + adv, val = self.adv_val(input_x) + q_val = val + (adv - adv.mean(dim=1, keepdim=True)) + return q_val + + def adv_val(self, input_x) -> Tuple[Tensor, Tensor]: + """ + Gets the advantage and value by passing out of the base network through the + value and advantage heads + Args: + input_x: input to network + Returns: + advantage, value + """ + float_x = input_x.float() + base_out = self.net(float_x) + return self.fc_adv(base_out), self.fc_val(base_out) + + +class DuelingCNN(nn.Module): + """ + CNN network with duel heads for val and advantage + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ + + def __init__(self, input_shape: Tuple, n_actions: int, _: int = 128): + + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=4, stride=2), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1), + nn.ReLU(), + ) + + conv_out_size = self._get_conv_out(input_shape) + + # advantage head + self.head_adv = nn.Sequential( + nn.Linear(conv_out_size, 256), nn.ReLU(), nn.Linear(256, n_actions) + ) + + # value head + self.head_val = nn.Sequential( + nn.Linear(conv_out_size, 256), nn.ReLU(), nn.Linear(256, 1) + ) + + def _get_conv_out(self, shape) -> int: + """ + Calculates the output size of the last conv layer + Args: + shape: input dimensions + Returns: + size of the conv output + """ + conv_out = self.conv(torch.zeros(1, *shape)) + return int(np.prod(conv_out.size())) + + def forward(self, input_x): + """ + Forward pass through network. Calculates the Q using the value and advantage + Args: + input_x: input to network + Returns: + Q value + """ + adv, val = self.adv_val(input_x) + q_val = val + (adv - adv.mean(dim=1, keepdim=True)) + return q_val + + def adv_val(self, input_x): + """ + Gets the advantage and value by passing out of the base network through the + value and advantage heads + Args: + input_x: input to network + Returns: + advantage, value + """ + float_x = input_x.float() + base_out = self.conv(input_x).view(float_x.size()[0], -1) + return self.head_adv(base_out), self.head_val(base_out) + + +class NoisyCNN(nn.Module): + """ + CNN with Noisy Linear layers for exploration + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + """ + + def __init__(self, input_shape, n_actions): + super().__init__() + + self.conv = nn.Sequential( + nn.Conv2d(input_shape[0], 32, kernel_size=8, stride=4), + nn.ReLU(), + nn.Conv2d(32, 64, kernel_size=4, stride=2), + nn.ReLU(), + nn.Conv2d(64, 64, kernel_size=3, stride=1), + nn.ReLU(), + ) + + conv_out_size = self._get_conv_out(input_shape) + self.head = nn.Sequential( + NoisyLinear(conv_out_size, 512), nn.ReLU(), NoisyLinear(512, n_actions) + ) + + def _get_conv_out(self, shape) -> int: + """ + Calculates the output size of the last conv layer + Args: + shape: input dimensions + Returns: + size of the conv output + """ + conv_out = self.conv(torch.zeros(1, *shape)) + return int(np.prod(conv_out.size())) + + def forward(self, input_x) -> Tensor: + """ + Forward pass through network + Args: + x: input to network + Returns: + output of network + """ + conv_out = self.conv(input_x).view(input_x.size()[0], -1) + return self.head(conv_out) + + +################### +# Custom Layers # +################### + + +class NoisyLinear(nn.Linear): + """ + Noisy Layer using Independent Gaussian Noise. + based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/ + Chapter08/lib/dqn_extra.py#L19 + Args: + in_features: number of inputs + out_features: number of outputs + sigma_init: initial fill value of noisy weights + bias: flag to include bias to linear layer + """ + + def __init__(self, in_features, out_features, sigma_init=0.017, bias=True): + super(NoisyLinear, self).__init__(in_features, out_features, bias=bias) + + weights = torch.full((out_features, in_features), sigma_init) + self.sigma_weight = nn.Parameter(weights) + epsilon_weight = torch.zeros(out_features, in_features) + self.register_buffer("epsilon_weight", epsilon_weight) + + if bias: + bias = torch.full((out_features,), sigma_init) + self.sigma_bias = nn.Parameter(bias) + epsilon_bias = torch.zeros(out_features) + self.register_buffer("epsilon_bias", epsilon_bias) + + self.reset_parameters() + + def reset_parameters(self) -> None: + """initializes or resets the paramseter of the layer""" + std = math.sqrt(3 / self.in_features) + self.weight.data.uniform_(-std, std) + self.bias.data.uniform_(-std, std) + + def forward(self, input_x: Tensor) -> Tensor: + """ + Forward pass of the layer + Args: + input_x: input tensor + Returns: + output of the layer + """ + self.epsilon_weight.normal_() + bias = self.bias + if bias is not None: + self.epsilon_bias.normal_() + bias = bias + self.sigma_bias * self.epsilon_bias.data + + noisy_weights = self.sigma_weight * self.epsilon_weight.data + self.weight + + return F.linear(input_x, noisy_weights, bias) diff --git a/pl_bolts/models/rl/double_dqn_model.py b/pl_bolts/models/rl/double_dqn_model.py new file mode 100644 index 0000000000..f31ae16c6d --- /dev/null +++ b/pl_bolts/models/rl/double_dqn_model.py @@ -0,0 +1,123 @@ +""" +Double DQN +""" +import argparse +from collections import OrderedDict +from typing import Tuple + +import pytorch_lightning as pl +import torch + +from pl_bolts.losses.rl import double_dqn_loss +from pl_bolts.models.rl.dqn_model import DQN + + +class DoubleDQN(DQN): + """ + Double Deep Q-network (DDQN) + PyTorch Lightning implementation of `Double DQN `_ + + Paper authors: Hado van Hasselt, Arthur Guez, David Silver + + Model implemented by: + + - `Donal Byrne ` + + Example: + + >>> from pl_bolts.models.rl.double_dqn_model import DoubleDQN + ... + >>> model = DoubleDQN("PongNoFrameskip-v4") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Args: + env: gym environment tag + gpus: number of gpus being used + eps_start: starting value of epsilon for the epsilon-greedy exploration + eps_end: final value of epsilon for the epsilon-greedy exploration + eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end + sync_rate: the number of iterations between syncing up the target network with the train network + gamma: discount factor + lr: learning rate + batch_size: size of minibatch pulled from the DataLoader + replay_size: total capacity of the replay buffer + warm_start_size: how many random steps through the environment to be carried out at the start of + training to fill the buffer with a starting point + sample_len: the number of samples to pull from the dataset iterator and feed to the DataLoader + + Note: + This example is based on + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter08/03_dqn_double.py + + Note: + Currently only supports CPU and single GPU training with `distributed_backend=dp` + """ + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict: + """ + Carries out a single step through the environment to update the replay buffer. + Then calculates loss based on the minibatch recieved + Args: + batch: current mini batch of replay data + _: batch number, not used + Returns: + Training loss and log metrics + """ + + # calculates training loss + loss = double_dqn_loss(batch, self.net, self.target_net) + + if self.trainer.use_dp or self.trainer.use_ddp2: + loss = loss.unsqueeze(0) + + # Soft update of target network + if self.global_step % self.sync_rate == 0: + self.target_net.load_state_dict(self.net.state_dict()) + + log = { + "total_reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + "train_loss": loss, + # "episodes": self.total_episode_steps, + } + status = { + "steps": self.global_step, + "avg_reward": self.avg_rewards, + "total_reward": self.total_rewards[-1], + "episodes": self.done_episodes, + # "episode_steps": self.episode_steps, + "epsilon": self.agent.epsilon, + } + + return OrderedDict( + { + "loss": loss, + "avg_reward": self.avg_rewards, + "log": log, + "progress_bar": status, + } + ) + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = DoubleDQN.add_model_specific_args(parser) + args = parser.parse_args() + + model = DoubleDQN(**args.__dict__) + + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model) + + +if __name__ == '__main__': + cli_main() diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py new file mode 100644 index 0000000000..52de30e975 --- /dev/null +++ b/pl_bolts/models/rl/dqn_model.py @@ -0,0 +1,444 @@ +""" +Deep Q Network +""" + +import argparse +from collections import OrderedDict +from typing import Tuple, List, Dict +from warnings import warn + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.optim as optim +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from pl_bolts.datamodules.experience_source import ExperienceSourceDataset, Experience +from pl_bolts.losses.rl import dqn_loss +from pl_bolts.models.rl.common.agents import ValueAgent +from pl_bolts.models.rl.common.memory import MultiStepBuffer +from pl_bolts.models.rl.common.networks import CNN +try: + from pl_bolts.models.rl.common.gym_wrappers import gym, make_environment +except ModuleNotFoundError: + warn('You want to use `gym` which is not installed yet,' # pragma: no-cover + ' install it with `pip install gym`.') + _GYM_AVAILABLE = False +else: + _GYM_AVAILABLE = True + + +class DQN(pl.LightningModule): + """ Basic DQN Model """ + + def __init__( + self, + env: str, + eps_start: float = 1.0, + eps_end: float = 0.02, + eps_last_frame: int = 150000, + sync_rate: int = 1000, + gamma: float = 0.99, + learning_rate: float = 1e-4, + batch_size: int = 32, + replay_size: int = 100000, + warm_start_size: int = 10000, + avg_reward_len: int = 100, + min_episode_reward: int = -21, + seed: int = 123, + batches_per_epoch: int = 1000, + n_steps: int = 1, + **kwargs, + ): + """ + PyTorch Lightning implementation of `DQN `_ + Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, + Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller. + Model implemented by: + + - `Donal Byrne ` + + Example: + >>> from pl_bolts.models.rl.dqn_model import DQN + ... + >>> model = DQN("PongNoFrameskip-v4") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Args: + env: gym environment tag + eps_start: starting value of epsilon for the epsilon-greedy exploration + eps_end: final value of epsilon for the epsilon-greedy exploration + eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end + sync_rate: the number of iterations between syncing up the target network with the train network + gamma: discount factor + learning_rate: learning rate + batch_size: size of minibatch pulled from the DataLoader + replay_size: total capacity of the replay buffer + warm_start_size: how many random steps through the environment to be carried out at the start of + training to fill the buffer with a starting point + avg_reward_len: how many episodes to take into account when calculating the avg reward + min_episode_reward: the minimum score that can be achieved in an episode. Used for filling the avg buffer + before training begins + seed: seed value for all RNG used + batches_per_epoch: number of batches per epoch + n_steps: size of n step look ahead + + Note: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition\ + /blob/master/Chapter06/02_dqn_pong.py + + Note: + Currently only supports CPU and single GPU training with `distributed_backend=dp` + """ + super().__init__() + + # Environment + self.exp = None + self.env = self.make_environment(env, seed) + self.test_env = self.make_environment(env) + + self.obs_shape = self.env.observation_space.shape + self.n_actions = self.env.action_space.n + + # Model Attributes + self.buffer = None + self.dataset = None + + self.net = None + self.target_net = None + self.build_networks() + + self.agent = ValueAgent( + self.net, + self.n_actions, + eps_start=eps_start, + eps_end=eps_end, + eps_frames=eps_last_frame, + ) + + # Hyperparameters + self.sync_rate = sync_rate + self.gamma = gamma + self.lr = learning_rate + self.batch_size = batch_size + self.replay_size = replay_size + self.warm_start_size = warm_start_size + self.batches_per_epoch = batches_per_epoch + self.n_steps = n_steps + + self.save_hyperparameters() + + # Metrics + self.total_episode_steps = [0] + self.total_rewards = [0] + self.done_episodes = 0 + self.total_steps = 0 + + # Average Rewards + self.avg_reward_len = avg_reward_len + + for _ in range(avg_reward_len): + self.total_rewards.append( + torch.tensor(min_episode_reward, device=self.device) + ) + + self.avg_rewards = float( + np.mean(self.total_rewards[-self.avg_reward_len:]) + ) + + self.state = self.env.reset() + + def run_n_episodes(self, env, n_epsiodes: int = 1, epsilon: float = 1.0) -> List[int]: + """ + Carries out N episodes of the environment with the current agent + Args: + env: environment to use, either train environment or test environment + n_epsiodes: number of episodes to run + epsilon: epsilon value for DQN agent + """ + total_rewards = [] + + for _ in range(n_epsiodes): + episode_state = env.reset() + done = False + episode_reward = 0 + + while not done: + self.agent.epsilon = epsilon + action = self.agent(episode_state, self.device) + next_state, reward, done, _ = self.env.step(action[0]) + episode_state = next_state + episode_reward += reward + + total_rewards.append(episode_reward) + + return total_rewards + + def populate(self, warm_start: int) -> None: + """Populates the buffer with initial experience""" + if warm_start > 0: + self.state = self.env.reset() + + for _ in range(warm_start): + self.agent.epsilon = 1.0 + action = self.agent(self.state, self.device) + next_state, reward, done, _ = self.env.step(action[0]) + exp = Experience(state=self.state, action=action[0], reward=reward, done=done, new_state=next_state) + self.buffer.append(exp) + self.state = next_state + + if done: + self.state = self.env.reset() + + def build_networks(self) -> None: + """Initializes the DQN train and target networks""" + self.net = CNN(self.obs_shape, self.n_actions) + self.target_net = CNN(self.obs_shape, self.n_actions) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Passes in a state x through the network and gets the q_values of each action as an output + Args: + x: environment state + Returns: + q values + """ + output = self.net(x) + return output + + def train_batch( + self, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Contains the logic for generating a new batch of data to be passed to the DataLoader + Returns: + yields a Experience tuple containing the state, action, reward, done and next_state. + """ + episode_reward = 0 + episode_steps = 0 + + while True: + self.total_steps += 1 + action = self.agent(self.state, self.device) + + next_state, r, is_done, _ = self.env.step(action[0]) + + episode_reward += r + episode_steps += 1 + + exp = Experience(state=self.state, action=action[0], reward=r, done=is_done, new_state=next_state) + + self.agent.update_epsilon(self.global_step) + self.buffer.append(exp) + self.state = next_state + + if is_done: + self.done_episodes += 1 + self.total_rewards.append(episode_reward) + self.total_episode_steps.append(episode_steps) + self.avg_rewards = float( + np.mean(self.total_rewards[-self.avg_reward_len:]) + ) + self.state = self.env.reset() + episode_steps = 0 + episode_reward = 0 + + states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size) + + for idx, _ in enumerate(dones): + yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx] + + # Simulates epochs + if self.total_steps % self.batches_per_epoch == 0: + break + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict: + """ + Carries out a single step through the environment to update the replay buffer. + Then calculates loss based on the minibatch recieved + Args: + batch: current mini batch of replay data + _: batch number, not used + Returns: + Training loss and log metrics + """ + + # calculates training loss + loss = dqn_loss(batch, self.net, self.target_net) + + if self.trainer.use_dp or self.trainer.use_ddp2: + loss = loss.unsqueeze(0) + + # Soft update of target network + if self.global_step % self.sync_rate == 0: + self.target_net.load_state_dict(self.net.state_dict()) + + log = { + "total_reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + "train_loss": loss, + "episodes": self.done_episodes, + "episode_steps": self.total_episode_steps[-1] + } + status = { + "steps": self.global_step, + "avg_reward": self.avg_rewards, + "total_reward": self.total_rewards[-1], + "episodes": self.done_episodes, + "episode_steps": self.total_episode_steps[-1], + "epsilon": self.agent.epsilon, + } + + return OrderedDict( + { + "loss": loss, + "avg_reward": self.avg_rewards, + "log": log, + "progress_bar": status, + } + ) + + def test_step(self, *args, **kwargs) -> Dict[str, torch.Tensor]: + """Evaluate the agent for 10 episodes""" + test_reward = self.run_n_episodes(self.test_env, 1, 0) + avg_reward = sum(test_reward) / len(test_reward) + return {"test_reward": avg_reward} + + def test_epoch_end(self, outputs) -> Dict[str, torch.Tensor]: + """Log the avg of the test results""" + rewards = [x["test_reward"] for x in outputs] + avg_reward = sum(rewards) / len(rewards) + tensorboard_logs = {"avg_test_reward": avg_reward} + return {"avg_test_reward": avg_reward, "log": tensorboard_logs} + + def configure_optimizers(self) -> List[Optimizer]: + """ Initialize Adam optimizer""" + optimizer = optim.Adam(self.net.parameters(), lr=self.lr) + return [optimizer] + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + self.buffer = MultiStepBuffer(self.replay_size, self.n_steps) + self.populate(self.warm_start_size) + + self.dataset = ExperienceSourceDataset(self.train_batch) + return DataLoader(dataset=self.dataset, batch_size=self.batch_size) + + def train_dataloader(self) -> DataLoader: + """Get train loader""" + return self._dataloader() + + def test_dataloader(self) -> DataLoader: + """Get test loader""" + return self._dataloader() + + @staticmethod + def make_environment(env_name: str, seed: int = None) -> gym.Env: + """ + Initialise gym environment + Args: + env_name: environment name or tag + seed: value to seed the environment RNG for reproducibility + Returns: + gym environment + """ + env = make_environment(env_name) + + if seed: + env.seed(seed) + + return env + + @staticmethod + def add_model_specific_args( + arg_parser: argparse.ArgumentParser, + ) -> argparse.ArgumentParser: + """ + Adds arguments for DQN model + Note: these params are fine tuned for Pong env + Args: + arg_parser: parent parser + """ + arg_parser.add_argument( + "--sync_rate", + type=int, + default=1000, + help="how many frames do we update the target network", + ) + arg_parser.add_argument( + "--replay_size", + type=int, + default=100000, + help="capacity of the replay buffer", + ) + arg_parser.add_argument( + "--warm_start_size", + type=int, + default=10000, + help="how many samples do we use to fill our buffer at the start of training", + ) + arg_parser.add_argument( + "--eps_last_frame", + type=int, + default=150000, + help="what frame should epsilon stop decaying", + ) + arg_parser.add_argument("--eps_start", type=float, default=1.0, help="starting value of epsilon") + arg_parser.add_argument("--eps_end", type=float, default=0.02, help="final value of epsilon") + arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch") + arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") + arg_parser.add_argument("--lr", type=float, default=1e-4, help="learning rate") + + arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") + arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + + arg_parser.add_argument( + "--avg_reward_len", + type=int, + default=100, + help="how many episodes to include in avg reward", + ) + arg_parser.add_argument( + "--n_steps", + type=int, + default=1, + help="how many frames do we update the target network", + ) + + return arg_parser + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = DQN.add_model_specific_args(parser) + args = parser.parse_args() + + model = DQN(**args.__dict__) + + # save checkpoints based on avg_reward + checkpoint_callback = ModelCheckpoint( + save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True + ) + + seed_everything(123) + trainer = pl.Trainer.from_argparse_args( + args, deterministic=True, checkpoint_callback=checkpoint_callback) + + trainer.fit(model) + + +if __name__ == '__main__': + cli_main() diff --git a/pl_bolts/models/rl/dueling_dqn_model.py b/pl_bolts/models/rl/dueling_dqn_model.py new file mode 100644 index 0000000000..79afca2fc7 --- /dev/null +++ b/pl_bolts/models/rl/dueling_dqn_model.py @@ -0,0 +1,75 @@ +""" +Dueling DQN +""" +import argparse + +import pytorch_lightning as pl + +from pl_bolts.models.rl.common.networks import DuelingCNN +from pl_bolts.models.rl.dqn_model import DQN + + +class DuelingDQN(DQN): + """ + PyTorch Lightning implementation of `Dueling DQN `_ + + Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas + + Model implemented by: + + - `Donal Byrne ` + + Example: + + >>> from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN + ... + >>> model = DuelingDQN("PongNoFrameskip-v4") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Args: + env: gym environment tag + gpus: number of gpus being used + eps_start: starting value of epsilon for the epsilon-greedy exploration + eps_end: final value of epsilon for the epsilon-greedy exploration + eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end + sync_rate: the number of iterations between syncing up the target network with the train network + gamma: discount factor + lr: learning rate + batch_size: size of minibatch pulled from the DataLoader + replay_size: total capacity of the replay buffer + warm_start_size: how many random steps through the environment to be carried out at the start of + training to fill the buffer with a starting point + sample_len: the number of samples to pull from the dataset iterator and feed to the DataLoader + + .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` + + """ + + def build_networks(self) -> None: + """Initializes the Dueling DQN train and target networks""" + self.net = DuelingCNN(self.obs_shape, self.n_actions) + self.target_net = DuelingCNN(self.obs_shape, self.n_actions) + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = DuelingDQN.add_model_specific_args(parser) + args = parser.parse_args() + + model = DuelingDQN(**args.__dict__) + + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model) + + +if __name__ == '__main__': + cli_main() diff --git a/pl_bolts/models/rl/noisy_dqn_model.py b/pl_bolts/models/rl/noisy_dqn_model.py new file mode 100644 index 0000000000..26f960c117 --- /dev/null +++ b/pl_bolts/models/rl/noisy_dqn_model.py @@ -0,0 +1,130 @@ +""" +Noisy DQN +""" +import argparse +from typing import Tuple + +import numpy as np +import pytorch_lightning as pl +import torch + +from pl_bolts.datamodules.experience_source import Experience +from pl_bolts.models.rl.common.networks import NoisyCNN +from pl_bolts.models.rl.dqn_model import DQN + + +class NoisyDQN(DQN): + """ + PyTorch Lightning implementation of `Noisy DQN `_ + + Paper authors: Meire Fortunato, Mohammad Gheshlaghi Azar, Bilal Piot, Jacob Menick, Ian Osband, Alex Graves, + Vlad Mnih, Remi Munos, Demis Hassabis, Olivier Pietquin, Charles Blundell, Shane Legg + + Model implemented by: + + - `Donal Byrne ` + + Example: + >>> from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN + ... + >>> model = NoisyDQN("PongNoFrameskip-v4") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Args: + env: gym environment tag + gpus: number of gpus being used + eps_start: starting value of epsilon for the epsilon-greedy exploration + eps_end: final value of epsilon for the epsilon-greedy exploration + eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end + sync_rate: the number of iterations between syncing up the target network with the train network + gamma: discount factor + lr: learning rate + batch_size: size of minibatch pulled from the DataLoader + replay_size: total capacity of the replay buffer + warm_start_size: how many random steps through the environment to be carried out at the start of + training to fill the buffer with a starting point + sample_len: the number of samples to pull from the dataset iterator and feed to the DataLoader + + .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` + + """ + + def build_networks(self) -> None: + """Initializes the Noisy DQN train and target networks""" + self.net = NoisyCNN(self.obs_shape, self.n_actions) + self.target_net = NoisyCNN(self.obs_shape, self.n_actions) + + def on_train_start(self) -> None: + """Set the agents epsilon to 0 as the exploration comes from the network""" + self.agent.epsilon = 0.0 + + def train_batch( + self, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Contains the logic for generating a new batch of data to be passed to the DataLoader. + This is the same function as the standard DQN except that we dont update epsilon as it is always 0. The + exploration comes from the noisy network. + Returns: + yields a Experience tuple containing the state, action, reward, done and next_state. + """ + episode_reward = 0 + episode_steps = 0 + + while True: + self.total_steps += 1 + action = self.agent(self.state, self.device) + + next_state, r, is_done, _ = self.env.step(action[0]) + + episode_reward += r + episode_steps += 1 + + exp = Experience(state=self.state, action=action[0], reward=r, done=is_done, new_state=next_state) + + self.buffer.append(exp) + self.state = next_state + + if is_done: + self.done_episodes += 1 + self.total_rewards.append(episode_reward) + self.total_episode_steps.append(episode_steps) + self.avg_rewards = float( + np.mean(self.total_rewards[-self.avg_reward_len:]) + ) + self.state = self.env.reset() + episode_steps = 0 + episode_reward = 0 + + states, actions, rewards, dones, new_states = self.buffer.sample(self.batch_size) + + for idx, _ in enumerate(dones): + yield states[idx], actions[idx], rewards[idx], dones[idx], new_states[idx] + + # Simulates epochs + if self.total_steps % self.batches_per_epoch == 0: + break + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = NoisyDQN.add_model_specific_args(parser) + args = parser.parse_args() + + model = NoisyDQN(**args.__dict__) + + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model) + + +if __name__ == '__main__': + cli_main() diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py new file mode 100644 index 0000000000..7705a6846a --- /dev/null +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -0,0 +1,197 @@ +""" +Prioritized Experience Replay DQN +""" +import argparse +from collections import OrderedDict +from typing import Tuple + +import numpy as np +import pytorch_lightning as pl +import torch +from torch.utils.data import DataLoader + +from pl_bolts.datamodules import ExperienceSourceDataset +from pl_bolts.losses.rl import per_dqn_loss +from pl_bolts.models.rl.common.memory import PERBuffer, Experience +from pl_bolts.models.rl.dqn_model import DQN + + +class PERDQN(DQN): + """ + PyTorch Lightning implementation of `DQN With Prioritized Experience Replay `_ + + Paper authors: Tom Schaul, John Quan, Ioannis Antonoglou, David Silver + + Model implemented by: + + - `Donal Byrne ` + + Example: + + >>> from pl_bolts.models.rl.per_dqn_model import PERDQN + ... + >>> model = PERDQN("PongNoFrameskip-v4") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Args: + env: gym environment tag + gpus: number of gpus being used + eps_start: starting value of epsilon for the epsilon-greedy exploration + eps_end: final value of epsilon for the epsilon-greedy exploration + eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end + sync_rate: the number of iterations between syncing up the target network with the train network + gamma: discount factor + learning_rate: learning rate + batch_size: size of minibatch pulled from the DataLoader + replay_size: total capacity of the replay buffer + warm_start_size: how many random steps through the environment to be carried out at the start of + training to fill the buffer with a starting point + num_samples: the number of samples to pull from the dataset iterator and feed to the DataLoader + + .. note:: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition\ + /blob/master/Chapter08/05_dqn_prio_replay.py + + .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` + + """ + + def train_batch( + self, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Contains the logic for generating a new batch of data to be passed to the DataLoader + Returns: + yields a Experience tuple containing the state, action, reward, done and next_state. + """ + + episode_reward = 0 + episode_steps = 0 + + while True: + self.total_steps += 1 + action = self.agent(self.state, self.device) + + next_state, r, is_done, _ = self.env.step(action[0]) + + episode_reward += r + episode_steps += 1 + + exp = Experience( + state=self.state, + action=action[0], + reward=r, + done=is_done, + new_state=next_state, + ) + + self.agent.update_epsilon(self.global_step) + self.buffer.append(exp) + self.state = next_state + + if is_done: + self.done_episodes += 1 + self.total_rewards.append(episode_reward) + self.total_episode_steps.append(episode_steps) + self.avg_rewards = float( + np.mean(self.total_rewards[-self.avg_reward_len:]) + ) + self.state = self.env.reset() + episode_steps = 0 + episode_reward = 0 + + samples, indices, weights = self.buffer.sample(self.batch_size) + + states, actions, rewards, dones, new_states = samples + + for idx, _ in enumerate(dones): + yield ( + states[idx], + actions[idx], + rewards[idx], + dones[idx], + new_states[idx], + ), indices[idx], weights[idx] + + def training_step(self, batch, _) -> OrderedDict: + """ + Carries out a single step through the environment to update the replay buffer. + Then calculates loss based on the minibatch recieved + Args: + batch: current mini batch of replay data + _: batch number, not used + Returns: + Training loss and log metrics + """ + samples, indices, weights = batch + indices = indices.cpu().numpy() + + # calculates training loss + loss, batch_weights = per_dqn_loss(samples, weights, self.net, self.target_net) + + if self.trainer.use_dp or self.trainer.use_ddp2: + loss = loss.unsqueeze(0) + + # update priorities in buffer + self.buffer.update_priorities(indices, batch_weights) + + # update of target network + if self.global_step % self.sync_rate == 0: + self.target_net.load_state_dict(self.net.state_dict()) + + log = { + "total_reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + "train_loss": loss, + # "episodes": self.total_episode_steps, + } + status = { + "steps": self.global_step, + "avg_reward": self.avg_rewards, + "total_reward": self.total_rewards[-1], + "episodes": self.done_episodes, + # "episode_steps": self.episode_steps, + "epsilon": self.agent.epsilon, + } + + return OrderedDict( + { + "loss": loss, + "avg_reward": self.avg_rewards, + "log": log, + "progress_bar": status, + } + ) + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + self.buffer = PERBuffer(self.replay_size) + self.populate(self.warm_start_size) + + self.dataset = ExperienceSourceDataset(self.train_batch) + return DataLoader(dataset=self.dataset, batch_size=self.batch_size) + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = PERDQN.add_model_specific_args(parser) + args = parser.parse_args() + + model = PERDQN(**args.__dict__) + + trainer = pl.Trainer.from_argparse_args(args) + trainer.fit(model) + + +if __name__ == "__main__": + cli_main() diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py new file mode 100644 index 0000000000..55535a91e7 --- /dev/null +++ b/pl_bolts/models/rl/reinforce_model.py @@ -0,0 +1,318 @@ +import argparse +from collections import OrderedDict +from typing import Tuple, List +from warnings import warn + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.optim as optim +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.nn.functional import log_softmax +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from pl_bolts.datamodules import ExperienceSourceDataset +from pl_bolts.datamodules.experience_source import Experience +from pl_bolts.models.rl.common.agents import PolicyAgent +from pl_bolts.models.rl.common.networks import MLP +try: + import gym +except ModuleNotFoundError: + warn('You want to use `gym` which is not installed yet, install it with `pip install gym`.') # pragma: no-cover + _GYM_AVAILABLE = False +else: + _GYM_AVAILABLE = True + + +class Reinforce(pl.LightningModule): + def __init__( + self, + env: str, + gamma: float = 0.99, + lr: float = 0.01, + batch_size: int = 8, + n_steps: int = 10, + avg_reward_len: int = 100, + entropy_beta: float = 0.01, + epoch_len: int = 1000, + num_batch_episodes: int = 4, + **kwargs + ) -> None: + """ + PyTorch Lightning implementation of `REINFORCE + `_ + Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour + Model implemented by: + + - `Donal Byrne ` + + Example: + >>> from pl_bolts.models.rl.reinforce_model import Reinforce + ... + >>> model = Reinforce("CartPole-v0") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Args: + env: gym environment tag + gamma: discount factor + lr: learning rate + batch_size: size of minibatch pulled from the DataLoader + n_steps: number of stakes per discounted experience + entropy_beta: entropy coefficient + epoch_len: how many batches before pseudo epoch + num_batch_episodes: how many episodes to rollout for each batch of training + avg_reward_len: how many episodes to take into account when calculating the avg reward + + Note: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/02_cartpole_reinforce.py + + Note: + Currently only supports CPU and single GPU training with `distributed_backend=dp` + """ + super().__init__() + + if not _GYM_AVAILABLE: + raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.') + + # Hyperparameters + self.lr = lr + self.batch_size = batch_size + self.batches_per_epoch = self.batch_size * epoch_len + self.entropy_beta = entropy_beta + self.gamma = gamma + self.n_steps = n_steps + self.num_batch_episodes = num_batch_episodes + + self.save_hyperparameters() + + # Model components + self.env = gym.make(env) + self.net = MLP(self.env.observation_space.shape, self.env.action_space.n) + self.agent = PolicyAgent(self.net) + + # Tracking metrics + self.total_steps = 0 + self.total_rewards = [0] + self.done_episodes = 0 + self.avg_rewards = 0 + self.reward_sum = 0.0 + self.batch_episodes = 0 + self.avg_reward_len = avg_reward_len + + self.batch_states = [] + self.batch_actions = [] + self.batch_qvals = [] + self.cur_rewards = [] + + self.state = self.env.reset() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Passes in a state x through the network and gets the q_values of each action as an output + Args: + x: environment state + Returns: + q values + """ + output = self.net(x) + return output + + def calc_qvals(self, rewards: List[float]) -> List[float]: + """Calculate the discounted rewards of all rewards in list + Args: + rewards: list of rewards from latest batch + Returns: + list of discounted rewards + """ + assert isinstance(rewards[0], float) + + cumul_reward = [] + sum_r = 0.0 + + for r in reversed(rewards): + sum_r = (sum_r * self.gamma) + r + cumul_reward.append(sum_r) + + return list(reversed(cumul_reward)) + + def discount_rewards(self, experiences: Tuple[Experience]) -> float: + """ + Calculates the discounted reward over N experiences + Args: + experiences: Tuple of Experience + Returns: + total discounted reward + """ + total_reward = 0.0 + for exp in reversed(experiences): + total_reward = (self.gamma * total_reward) + exp.reward + return total_reward + + def train_batch( + self, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Contains the logic for generating a new batch of data to be passed to the DataLoader + Yield: + yields a tuple of Lists containing tensors for states, actions and rewards of the batch. + """ + + while True: + + action = self.agent(self.state, self.device) + + next_state, reward, done, _ = self.env.step(action[0]) + + self.batch_states.append(self.state) + self.batch_actions.append(action[0]) + self.cur_rewards.append(reward) + + self.state = next_state + self.total_steps += 1 + + if done: + self.batch_qvals.extend(self.calc_qvals(self.cur_rewards)) + self.batch_episodes += 1 + self.done_episodes += 1 + self.total_rewards.append(sum(self.cur_rewards)) + self.avg_rewards = float( + np.mean(self.total_rewards[-self.avg_reward_len:]) + ) + self.cur_rewards = [] + self.state = self.env.reset() + + if self.batch_episodes >= self.num_batch_episodes: + for state, action, qval in zip( + self.batch_states, self.batch_actions, self.batch_qvals + ): + yield state, action, qval + + self.batch_episodes = 0 + + # Simulates epochs + if self.total_steps % self.batches_per_epoch == 0: + break + + def loss(self, states, actions, scaled_rewards) -> torch.Tensor: + logits = self.net(states) + + # policy loss + log_prob = log_softmax(logits, dim=1) + log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions] + loss = -log_prob_actions.mean() + + return loss + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict: + """ + Carries out a single step through the environment to update the replay buffer. + Then calculates loss based on the minibatch recieved + Args: + batch: current mini batch of replay data + _: batch number, not used + Returns: + Training loss and log metrics + """ + states, actions, scaled_rewards = batch + + loss = self.loss(states, actions, scaled_rewards) + + log = { + "episodes": self.done_episodes, + "reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + } + + return OrderedDict( + { + "loss": loss, + "avg_reward": self.avg_rewards, + "log": log, + "progress_bar": log, + } + ) + + def configure_optimizers(self) -> List[Optimizer]: + """ Initialize Adam optimizer""" + optimizer = optim.Adam(self.net.parameters(), lr=self.lr) + return [optimizer] + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + dataset = ExperienceSourceDataset(self.train_batch) + dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size) + return dataloader + + def train_dataloader(self) -> DataLoader: + """Get train loader""" + return self._dataloader() + + def get_device(self, batch) -> str: + """Retrieve device currently being used by minibatch""" + return batch[0][0][0].device.index if self.on_gpu else "cpu" + + @staticmethod + def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: + """ + Adds arguments for DQN model + Note: these params are fine tuned for Pong env + Args: + arg_parser: the current argument parser to add to + Returns: + arg_parser with model specific cargs added + """ + arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch") + arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") + arg_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") + + arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") + arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + + arg_parser.add_argument( + "--avg_reward_len", + type=int, + default=100, + help="how many episodes to include in avg reward", + ) + + arg_parser.add_argument( + "--entropy_beta", type=float, default=0.01, help="entropy value", + ) + + return arg_parser + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = Reinforce.add_model_specific_args(parser) + args = parser.parse_args() + + model = Reinforce(**args.__dict__) + + # save checkpoints based on avg_reward + checkpoint_callback = ModelCheckpoint( + save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True + ) + + seed_everything(123) + trainer = pl.Trainer.from_argparse_args( + args, deterministic=True, checkpoint_callback=checkpoint_callback + ) + trainer.fit(model) + + +if __name__ == '__main__': + cli_main() diff --git a/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/pl_bolts/models/rl/vanilla_policy_gradient_model.py new file mode 100644 index 0000000000..f7d9e6586f --- /dev/null +++ b/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -0,0 +1,306 @@ +import argparse +from collections import OrderedDict +from typing import Tuple, List +from warnings import warn + +import numpy as np +import pytorch_lightning as pl +import torch +import torch.optim as optim +from pytorch_lightning import seed_everything +from pytorch_lightning.callbacks import ModelCheckpoint +from torch.nn.functional import log_softmax, softmax +from torch.optim.optimizer import Optimizer +from torch.utils.data import DataLoader + +from pl_bolts.datamodules import ExperienceSourceDataset +from pl_bolts.models.rl.common.agents import PolicyAgent +from pl_bolts.models.rl.common.networks import MLP +try: + import gym +except ModuleNotFoundError: + warn('You want to use `gym` which is not installed yet, install it with `pip install gym`.') # pragma: no-cover + _GYM_AVAILABLE = False +else: + _GYM_AVAILABLE = True + + +class VanillaPolicyGradient(pl.LightningModule): + def __init__( + self, + env: str, + gamma: float = 0.99, + lr: float = 0.01, + batch_size: int = 8, + n_steps: int = 10, + avg_reward_len: int = 100, + entropy_beta: float = 0.01, + epoch_len: int = 1000, + **kwargs + ) -> None: + """ + PyTorch Lightning implementation of `Vanilla Policy Gradient + `_ + Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour + Model implemented by: + + - `Donal Byrne ` + + Example: + >>> from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient + ... + >>> model = VanillaPolicyGradient("CartPole-v0") + + Train:: + trainer = Trainer() + trainer.fit(model) + + Args: + env: gym environment tag + gamma: discount factor + lr: learning rate + batch_size: size of minibatch pulled from the DataLoader + batch_episodes: how many episodes to rollout for each batch of training + entropy_beta: dictates the level of entropy per batch + avg_reward_len: how many episodes to take into account when calculating the avg reward + + Note: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/04_cartpole_pg.py + + Note: + Currently only supports CPU and single GPU training with `distributed_backend=dp` + """ + super().__init__() + + if not _GYM_AVAILABLE: + raise ModuleNotFoundError('This Module requires gym environment which is not installed yet.') + + # Hyperparameters + self.lr = lr + self.batch_size = batch_size + self.batches_per_epoch = self.batch_size * epoch_len + self.entropy_beta = entropy_beta + self.gamma = gamma + self.n_steps = n_steps + + self.save_hyperparameters() + + # Model components + self.env = gym.make(env) + self.net = MLP(self.env.observation_space.shape, self.env.action_space.n) + self.agent = PolicyAgent(self.net) + + # Tracking metrics + self.total_rewards = [] + self.episode_rewards = [] + self.done_episodes = 0 + self.avg_rewards = 0 + self.avg_reward_len = avg_reward_len + self.eps = np.finfo(np.float32).eps.item() + self.batch_states = [] + self.batch_actions = [] + + self.state = self.env.reset() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Passes in a state x through the network and gets the q_values of each action as an output + Args: + x: environment state + Returns: + q values + """ + output = self.net(x) + return output + + def train_batch( + self, + ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]: + """ + Contains the logic for generating a new batch of data to be passed to the DataLoader + Returns: + yields a tuple of Lists containing tensors for states, actions and rewards of the batch. + """ + + while True: + + action = self.agent(self.state, self.device) + + next_state, reward, done, _ = self.env.step(action[0]) + + self.episode_rewards.append(reward) + self.batch_actions.append(action) + self.batch_states.append(self.state) + self.state = next_state + + if done: + self.done_episodes += 1 + self.state = self.env.reset() + self.total_rewards.append(sum(self.episode_rewards)) + self.avg_rewards = float(np.mean(self.total_rewards[-self.avg_reward_len:])) + + returns = self.compute_returns(self.episode_rewards) + + for idx in range(len(self.batch_actions)): + yield self.batch_states[idx], self.batch_actions[idx], returns[idx] + + self.batch_states = [] + self.batch_actions = [] + self.episode_rewards = [] + + def compute_returns(self, rewards): + """ + Calculate the discounted rewards of the batched rewards + + Args: + rewards: list of batched rewards + + Returns: + list of discounted rewards + """ + reward = 0 + returns = [] + + for r in rewards[::-1]: + reward = r + self.gamma * reward + returns.insert(0, reward) + + returns = torch.tensor(returns) + returns = (returns - returns.mean()) / (returns.std() + self.eps) + + return returns + + def loss(self, states, actions, scaled_rewards) -> torch.Tensor: + """ + Calculates the loss for VPG + + Args: + states: batched states + actions: batch actions + scaled_rewards: batche Q values + + Returns: + loss for the current batch + """ + + logits = self.net(states) + + # policy loss + log_prob = log_softmax(logits, dim=1) + log_prob_actions = scaled_rewards * log_prob[range(self.batch_size), actions[0]] + policy_loss = -log_prob_actions.mean() + + # entropy loss + prob = softmax(logits, dim=1) + entropy = -(prob * log_prob).sum(dim=1).mean() + entropy_loss = -self.entropy_beta * entropy + + # total loss + loss = policy_loss + entropy_loss + + return loss + + def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedDict: + """ + Carries out a single step through the environment to update the replay buffer. + Then calculates loss based on the minibatch recieved + Args: + batch: current mini batch of replay data + _: batch number, not used + Returns: + Training loss and log metrics + """ + states, actions, scaled_rewards = batch + + loss = self.loss(states, actions, scaled_rewards) + + log = { + "episodes": self.done_episodes, + "reward": self.total_rewards[-1], + "avg_reward": self.avg_rewards, + } + return OrderedDict( + { + "loss": loss, + "avg_reward": self.avg_rewards, + "log": log, + "progress_bar": log, + } + ) + + def configure_optimizers(self) -> List[Optimizer]: + """ Initialize Adam optimizer""" + optimizer = optim.Adam(self.net.parameters(), lr=self.lr) + return [optimizer] + + def _dataloader(self) -> DataLoader: + """Initialize the Replay Buffer dataset used for retrieving experiences""" + dataset = ExperienceSourceDataset(self.train_batch) + dataloader = DataLoader(dataset=dataset, batch_size=self.batch_size) + return dataloader + + def train_dataloader(self) -> DataLoader: + """Get train loader""" + return self._dataloader() + + def get_device(self, batch) -> str: + """Retrieve device currently being used by minibatch""" + return batch[0][0][0].device.index if self.on_gpu else "cpu" + + @staticmethod + def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: + """ + Adds arguments for DQN model + Note: these params are fine tuned for Pong env + Args: + arg_parser: the current argument parser to add to + Returns: + arg_parser with model specific cargs added + """ + + arg_parser.add_argument("--entropy_beta", type=float, default=0.01, help="entropy value") + arg_parser.add_argument("--batches_per_epoch", type=int, default=10000, help="number of batches in an epoch") + arg_parser.add_argument("--batch_size", type=int, default=32, help="size of the batches") + arg_parser.add_argument("--lr", type=float, default=1e-3, help="learning rate") + arg_parser.add_argument("--env", type=str, required=True, help="gym environment tag") + arg_parser.add_argument("--gamma", type=float, default=0.99, help="discount factor") + arg_parser.add_argument("--seed", type=int, default=123, help="seed for training run") + + arg_parser.add_argument( + "--avg_reward_len", + type=int, + default=100, + help="how many episodes to include in avg reward", + ) + + return arg_parser + + +def cli_main(): + parser = argparse.ArgumentParser(add_help=False) + + # trainer args + parser = pl.Trainer.add_argparse_args(parser) + + # model args + parser = VanillaPolicyGradient.add_model_specific_args(parser) + args = parser.parse_args() + + model = VanillaPolicyGradient(**args.__dict__) + + # save checkpoints based on avg_reward + checkpoint_callback = ModelCheckpoint( + save_top_k=1, monitor="avg_reward", mode="max", period=1, verbose=True + ) + + seed_everything(123) + trainer = pl.Trainer.from_argparse_args( + args, deterministic=True, checkpoint_callback=checkpoint_callback + ) + trainer.fit(model) + + +if __name__ == '__main__': + cli_main() diff --git a/pl_bolts/models/vision/__init__.py b/pl_bolts/models/vision/__init__.py index 8d4ec5084e..e525036d34 100644 --- a/pl_bolts/models/vision/__init__.py +++ b/pl_bolts/models/vision/__init__.py @@ -1,2 +1,2 @@ from pl_bolts.models.vision.pixel_cnn import PixelCNN -from pl_bolts.models.vision.unet import UNet \ No newline at end of file +from pl_bolts.models.vision.unet import UNet diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..e69de29bb2 diff --git a/requirements/models.txt b/requirements/models.txt index 174ab691fc..a92a7ef6bd 100644 --- a/requirements/models.txt +++ b/requirements/models.txt @@ -1,4 +1,5 @@ torchvision>=0.7 scikit-learn>=0.23 Pillow -opencv-python \ No newline at end of file +opencv-python +gym>=0.17.2 # needed for RL \ No newline at end of file diff --git a/requirements/test.txt b/requirements/test.txt index 70b0ce4600..c97d36fc50 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -8,5 +8,4 @@ flake8-black check-manifest twine==1.13.0 -# atari-py==0.2.6 # needed for RL -# gym>=0.17.2 # needed for RL \ No newline at end of file +atari-py==0.2.6 # needed for RL \ No newline at end of file diff --git a/tests/datamodules/test_experience_sources.py b/tests/datamodules/test_experience_sources.py new file mode 100644 index 0000000000..737a1c7150 --- /dev/null +++ b/tests/datamodules/test_experience_sources.py @@ -0,0 +1,321 @@ +from unittest import TestCase +from unittest.mock import Mock + +import gym +import numpy as np +import torch +from torch.utils.data import DataLoader + +from pl_bolts.datamodules.experience_source import ( + BaseExperienceSource, + ExperienceSource, + ExperienceSourceDataset, + Experience, + DiscountedExperienceSource, +) +from pl_bolts.models.rl.common.agents import Agent + + +class DummyAgent(Agent): + def __call__(self, states, device): + return [0] * len(states) + + +class DummyExperienceSource(BaseExperienceSource): + def __iter__(self): + yield torch.ones(3) + + +class TestExperienceSourceDataset(TestCase): + def train_batch(self): + """Returns an iterator used for testing""" + return iter([i for i in range(100)]) + + def test_iterator(self): + """Tests that the iterator returns batches correctly""" + source = ExperienceSourceDataset(self.train_batch) + batch_size = 10 + data_loader = DataLoader(source, batch_size=batch_size) + + for idx, batch in enumerate(data_loader): + self.assertEqual(len(batch), batch_size) + self.assertEqual(batch[0], 0) + self.assertEqual(batch[5], 5) + break + + +class TestBaseExperienceSource(TestCase): + def setUp(self) -> None: + self.net = Mock() + self.agent = DummyAgent(net=self.net) + self.env = gym.make("CartPole-v0") + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.source = DummyExperienceSource(self.env, self.agent) + self.s1 = torch.ones(3) + self.s2 = torch.zeros(3) + + def test_dummy_base_class(self): + """Tests that base class is initialized correctly""" + self.assertTrue(isinstance(self.source.env, gym.Env)) + self.assertTrue(isinstance(self.source.agent, Agent)) + out = next(iter(self.source)) + self.assertTrue(torch.all(out.eq(torch.ones(3)))) + + +class TestExperienceSource(TestCase): + def setUp(self) -> None: + self.net = Mock() + self.agent = DummyAgent(net=self.net) + self.env = [gym.make("CartPole-v0") for _ in range(2)] + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.source = ExperienceSource(self.env, self.agent, n_steps=1) + + self.s1 = torch.ones(3) + self.s2 = torch.zeros(3) + + self.mock_env = Mock() + self.mock_env.step = Mock(return_value=(self.s1, 1, False, Mock())) + + self.exp1 = Experience(state=self.s1, action=1, reward=1, done=False, new_state=self.s2) + self.exp2 = Experience(state=self.s1, action=1, reward=1, done=False, new_state=self.s2) + + def test_init_source(self): + """Test that experience source is setup correctly""" + self.assertEqual(self.source.n_steps, 1) + self.assertIsInstance(self.source.pool, list) + + self.assertEqual(len(self.source.states), len(self.source.pool)) + self.assertEqual(len(self.source.histories), len(self.source.pool)) + self.assertEqual(len(self.source.cur_rewards), len(self.source.pool)) + self.assertEqual(len(self.source.cur_steps), len(self.source.pool)) + + def test_init_single_env(self): + """Test that if a single env is passed that it is wrapped in a list""" + self.source = ExperienceSource(self.mock_env, self.agent) + self.assertIsInstance(self.source.pool, list) + + def test_env_actions(self): + """Assert that a list of actions of shape [num_envs, action_len] is returned""" + actions = self.source.env_actions(self.device) + self.assertEqual(len(actions), len(self.env)) + self.assertTrue(isinstance(actions[0], list)) + + def test_env_step(self): + """Assert that taking a step through a single environment yields a list of history steps""" + actions = [[1], [1]] + env = self.env[0] + exp = self.source.env_step(0, env, actions[0]) + + self.assertTrue(isinstance(exp, Experience)) + + def test_source_next_single_env_single_step(self): + """Test that steps are executed correctly with one environment and 1 step""" + + self.env = [gym.make("CartPole-v0") for _ in range(1)] + self.source = ExperienceSource(self.env, self.agent, n_steps=1) + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, tuple)) + break + + def test_source_next_single_env_multi_step(self): + """Test that steps are executed correctly with one environment and 2 step""" + + self.env = [gym.make("CartPole-v0") for _ in range(1)] + n_steps = 4 + self.source = ExperienceSource(self.env, self.agent, n_steps=n_steps) + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, tuple)) + self.assertTrue(len(exp) == n_steps) + break + + def test_source_next_multi_env_single_step(self): + """Test that steps are executed correctly with 2 environment and 1 step""" + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, tuple)) + self.assertTrue(len(exp) == self.source.n_steps) + break + + def test_source_next_multi_env_multi_step(self): + """Test that steps are executed correctly with 2 environment and 2 step""" + self.source = ExperienceSource(self.env, self.agent, n_steps=2) + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, tuple)) + self.assertTrue(len(exp) == self.source.n_steps) + break + + def test_source_update_state(self): + """Test that after a step the state is updated""" + + self.env = [gym.make("CartPole-v0") for _ in range(1)] + self.source = ExperienceSource(self.env, self.agent, n_steps=2) + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, tuple)) + new = np.asarray(exp[-1].new_state) + old = np.asarray(self.source.states[0]) + self.assertTrue(np.array_equal(new, old)) + break + + def test_source_is_done_short_episode(self): + """Test that when done and the history is not full, to return the partial history""" + + self.mock_env.step = Mock(return_value=(self.s1, 1, True, Mock)) + + env = [self.mock_env for _ in range(1)] + self.source = ExperienceSource(env, self.agent, n_steps=2) + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, tuple)) + self.assertTrue(len(exp) == 1) + break + + def test_source_is_done_2step_episode(self): + """ + Test that when done and the history is full, return the full history, then start to return the tail of + the history + """ + + self.env = [self.mock_env] + self.source = ExperienceSource(self.env, self.agent, n_steps=2) + + self.mock_env.step = Mock(return_value=(self.s1, 1, True, Mock)) + + self.source.histories[0].append(self.exp1) + + for idx, exp in enumerate(self.source.runner(self.device)): + + self.assertTrue(isinstance(exp, tuple)) + + if idx == 0: + self.assertTrue(len(exp) == self.source.n_steps) + elif idx == 1: + self.assertTrue(len(exp) == self.source.n_steps - 1) + self.assertTrue(torch.equal(exp[0].new_state, self.s1)) + + break + + def test_source_is_done_metrics(self): + """Test that when done and the history is full, return the full history""" + + n_steps = 3 + n_envs = 2 + + self.mock_env.step = Mock(return_value=(self.s1, 1, True, Mock)) + + self.env = [self.mock_env for _ in range(2)] + self.source = ExperienceSource(self.env, self.agent, n_steps=3) + + history = self.source.histories[0] + history += [self.exp1, self.exp2, self.exp2] + + for idx, exp in enumerate(self.source.runner(self.device)): + + if idx == n_steps - 1: + self.assertEqual(self.source._total_rewards[0], 1) + self.assertEqual(self.source.total_steps[0], 1) + self.assertEqual(self.source.cur_rewards[0], 0) + self.assertEqual(self.source.cur_steps[0], 0) + elif idx == (3 * n_envs) - 1: + self.assertEqual(self.source.iter_idx, 1) + break + + def test_pop_total_rewards(self): + """Test that pop rewards returns correct rewards""" + self.source._total_rewards = [10, 20, 30] + + rewards = self.source.pop_total_rewards() + + self.assertEqual(rewards, [10, 20, 30]) + + +class TestDiscountedExperienceSource(TestCase): + def setUp(self) -> None: + self.net = Mock() + self.agent = DummyAgent(net=self.net) + self.env = [gym.make("CartPole-v0") for _ in range(2)] + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.n_steps = 3 + self.gamma = 0.9 + self.source = DiscountedExperienceSource( + self.env, self.agent, n_steps=self.n_steps, gamma=self.gamma + ) + + self.state = torch.ones(3) + self.next_state = torch.zeros(3) + self.reward = 1 + + self.exp1 = Experience( + state=self.state, + action=1, + reward=self.reward, + done=False, + new_state=self.next_state, + ) + self.exp2 = Experience( + state=self.next_state, + action=1, + reward=self.reward, + done=False, + new_state=self.state, + ) + + self.env1 = Mock() + self.env1.step = Mock( + return_value=(self.next_state, self.reward, True, self.state) + ) + + def test_init(self): + """Test that experience source is setup correctly""" + self.assertEqual(self.source.n_steps, self.n_steps + 1) + self.assertEqual(self.source.steps, self.n_steps) + self.assertEqual(self.source.gamma, self.gamma) + + def test_source_step(self): + """Tests that the source returns a single experience""" + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, Experience)) + break + + def test_source_step_done(self): + """Tests that the source returns a single experience""" + + self.source = DiscountedExperienceSource( + self.env1, self.agent, n_steps=self.n_steps + ) + + self.source.histories[0].append(self.exp1) + self.source.histories[0].append(self.exp2) + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, Experience)) + self.assertTrue(torch.all(torch.eq(exp.new_state, self.next_state))) + break + + def test_source_discounted_return(self): + """ + Tests that the source returns a single experience with discounted rewards + + discounted returns: G(t) = R(t+1) + γ*R(t+2) + γ^2*R(t+3) ... + γ^N-1*R(t+N) + """ + + self.source = DiscountedExperienceSource( + self.env1, self.agent, n_steps=self.n_steps + ) + + self.source.histories[0] += [self.exp1, self.exp2] + + discounted_reward = ( + self.exp1.reward + + (self.source.gamma * self.exp2.reward) + + (self.source.gamma * self.reward) ** 2 + ) + + for idx, exp in enumerate(self.source.runner(self.device)): + self.assertTrue(isinstance(exp, Experience)) + self.assertEqual(exp.reward, discounted_reward) + break diff --git a/tests/losses/test_rl_loss.py b/tests/losses/test_rl_loss.py new file mode 100644 index 0000000000..e02965f84c --- /dev/null +++ b/tests/losses/test_rl_loss.py @@ -0,0 +1,51 @@ +""" +Test RL Loss Functions +""" + +from unittest import TestCase + +import numpy as np +import torch + +from pl_bolts.losses.rl import dqn_loss, double_dqn_loss, per_dqn_loss +from pl_bolts.models.rl.common.networks import CNN +from pl_bolts.models.rl.common.gym_wrappers import make_environment + + +class TestRLLoss(TestCase): + + def setUp(self) -> None: + + self.state = torch.rand(32, 4, 84, 84) + self.next_state = torch.rand(32, 4, 84, 84) + self.action = torch.ones([32]) + self.reward = torch.ones([32]) + self.done = torch.zeros([32]).long() + + self.batch = (self.state, self.action, self.reward, self.done, self.next_state) + + self.env = make_environment("PongNoFrameskip-v4") + self.obs_shape = self.env.observation_space.shape + self.n_actions = self.env.action_space.n + self.net = CNN(self.obs_shape, self.n_actions) + self.target_net = CNN(self.obs_shape, self.n_actions) + + def test_dqn_loss(self): + """Test the dqn loss function""" + + loss = dqn_loss(self.batch, self.net, self.target_net) + self.assertIsInstance(loss, torch.Tensor) + + def test_double_dqn_loss(self): + """Test the double dqn loss function""" + + loss = double_dqn_loss(self.batch, self.net, self.target_net) + self.assertIsInstance(loss, torch.Tensor) + + def test_per_dqn_loss(self): + """Test the double dqn loss function""" + prios = torch.ones([32]) + + loss, batch_weights = per_dqn_loss(self.batch, prios, self.net, self.target_net) + self.assertIsInstance(loss, torch.Tensor) + self.assertIsInstance(batch_weights, np.ndarray) diff --git a/tests/models/rl/__init__.py b/tests/models/rl/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/rl/integration/__init__.py b/tests/models/rl/integration/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/rl/integration/test_policy_models.py b/tests/models/rl/integration/test_policy_models.py new file mode 100644 index 0000000000..3c65af9d2e --- /dev/null +++ b/tests/models/rl/integration/test_policy_models.py @@ -0,0 +1,41 @@ +import argparse +from unittest import TestCase + +import pytorch_lightning as pl + +from pl_bolts.models.rl.reinforce_model import Reinforce +from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient + + +class TestPolicyModels(TestCase): + + def setUp(self) -> None: + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = VanillaPolicyGradient.add_model_specific_args(parent_parser) + args_list = [ + "--env", "CartPole-v0" + ] + self.hparams = parent_parser.parse_args(args_list) + + self.trainer = pl.Trainer( + gpus=0, + max_steps=100, + max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early + val_check_interval=1, # This just needs 'some' value, does not effect training right now + fast_dev_run=True + ) + + def test_reinforce(self): + """Smoke test that the reinforce model runs""" + + model = Reinforce(self.hparams.env) + result = self.trainer.fit(model) + + self.assertEqual(result, 1) + + def test_policy_gradient(self): + """Smoke test that the policy gradient model runs""" + model = VanillaPolicyGradient(self.hparams.env) + result = self.trainer.fit(model) + + self.assertEqual(result, 1) diff --git a/tests/models/rl/integration/test_value_models.py b/tests/models/rl/integration/test_value_models.py new file mode 100644 index 0000000000..f3cbad43ad --- /dev/null +++ b/tests/models/rl/integration/test_value_models.py @@ -0,0 +1,74 @@ +import argparse +from unittest import TestCase + +import pytorch_lightning as pl + +from pl_bolts.models.rl.double_dqn_model import DoubleDQN +from pl_bolts.models.rl.dqn_model import DQN +from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN +from pl_bolts.models.rl.noisy_dqn_model import NoisyDQN +from pl_bolts.models.rl.per_dqn_model import PERDQN + + +class TestValueModels(TestCase): + + def setUp(self) -> None: + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = pl.Trainer.add_argparse_args(parent_parser) + parent_parser = DQN.add_model_specific_args(parent_parser) + args_list = [ + "--warm_start_size", "100", + "--gpus", "0", + "--env", "PongNoFrameskip-v4", + ] + self.hparams = parent_parser.parse_args(args_list) + + self.trainer = pl.Trainer( + gpus=self.hparams.gpus, + max_steps=100, + max_epochs=100, # Set this as the same as max steps to ensure that it doesn't stop early + val_check_interval=1, # This just needs 'some' value, does not effect training right now + fast_dev_run=True + ) + + def test_dqn(self): + """Smoke test that the DQN model runs""" + model = DQN(self.hparams.env, num_envs=5) + result = self.trainer.fit(model) + + self.assertEqual(result, 1) + + def test_double_dqn(self): + """Smoke test that the Double DQN model runs""" + model = DoubleDQN(self.hparams.env) + result = self.trainer.fit(model) + + self.assertEqual(result, 1) + + def test_dueling_dqn(self): + """Smoke test that the Dueling DQN model runs""" + model = DuelingDQN(self.hparams.env) + result = self.trainer.fit(model) + + self.assertEqual(result, 1) + + def test_noisy_dqn(self): + """Smoke test that the Noisy DQN model runs""" + model = NoisyDQN(self.hparams.env) + result = self.trainer.fit(model) + + self.assertEqual(result, 1) + + def test_per_dqn(self): + """Smoke test that the PER DQN model runs""" + model = PERDQN(self.hparams.env) + result = self.trainer.fit(model) + + self.assertEqual(result, 1) + + # def test_n_step_dqn(self): + # """Smoke test that the N Step DQN model runs""" + # model = DQN(self.hparams.env, n_steps=self.hparams.n_steps) + # result = self.trainer.fit(model) + # + # self.assertEqual(result, 1) diff --git a/tests/models/rl/test_scripts.py b/tests/models/rl/test_scripts.py new file mode 100644 index 0000000000..af1d703897 --- /dev/null +++ b/tests/models/rl/test_scripts.py @@ -0,0 +1,104 @@ +from unittest import mock + +import pytest + + +@pytest.mark.parametrize('cli_args', ['--env PongNoFrameskip-v4' + ' --max_steps 10' + ' --fast_dev_run' + ' --warm_start_size 10' + ' --n_steps 2' + ' --batch_size 10']) +def test_cli_run_rl_dqn(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.dqn_model import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() + + +@pytest.mark.parametrize('cli_args', ['--env PongNoFrameskip-v4' + ' --max_steps 10' + ' --fast_dev_run' + ' --warm_start_size 10' + ' --n_steps 2' + ' --batch_size 10']) +def test_cli_run_rl_double_dqn(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.double_dqn_model import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() + + +@pytest.mark.parametrize('cli_args', ['--env PongNoFrameskip-v4' + ' --max_steps 10' + ' --fast_dev_run' + ' --warm_start_size 10' + ' --n_steps 2' + ' --batch_size 10']) +def test_cli_run_rl_dueling_dqn(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.dueling_dqn_model import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() + + +@pytest.mark.parametrize('cli_args', ['--env PongNoFrameskip-v4' + ' --max_steps 10' + ' --fast_dev_run' + ' --warm_start_size 10' + ' --n_steps 2' + ' --batch_size 10']) +def test_cli_run_rl_noisy_dqn(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.noisy_dqn_model import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() + + +@pytest.mark.parametrize('cli_args', ['--env PongNoFrameskip-v4' + ' --max_steps 10' + ' --fast_dev_run' + ' --warm_start_size 10' + ' --n_steps 2' + ' --batch_size 10']) +def test_cli_run_rl_per_dqn(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.per_dqn_model import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() + + +@pytest.mark.parametrize('cli_args', ['--env CartPole-v0' + ' --max_steps 10' + ' --fast_dev_run' + ' --batch_size 10']) +def test_cli_run_rl_reinforce(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.reinforce_model import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() + + +@pytest.mark.parametrize('cli_args', ['--env CartPole-v0' + ' --max_steps 10' + ' --fast_dev_run' + ' --batch_size 10']) +def test_cli_run_rl_vanilla_policy_gradient(cli_args): + """Test running CLI for an example with default params.""" + from pl_bolts.models.rl.vanilla_policy_gradient_model import cli_main + + cli_args = cli_args.split(' ') if cli_args else [] + with mock.patch("argparse._sys.argv", ["any.py"] + cli_args): + cli_main() diff --git a/tests/models/rl/unit/__init__.py b/tests/models/rl/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/models/rl/unit/test_agents.py b/tests/models/rl/unit/test_agents.py new file mode 100644 index 0000000000..5d4214e59b --- /dev/null +++ b/tests/models/rl/unit/test_agents.py @@ -0,0 +1,62 @@ +"""Tests that the agent module works correctly""" +from unittest import TestCase +from unittest.mock import Mock + +import gym +import numpy as np +import torch + +from pl_bolts.models.rl.common.agents import Agent, PolicyAgent, ValueAgent + + +class TestAgents(TestCase): + + def setUp(self) -> None: + self.env = gym.make("CartPole-v0") + self.state = self.env.reset() + self.net = Mock() + + def test_base_agent(self): + agent = Agent(self.net) + action = agent(self.state, 'cuda:0') + self.assertIsInstance(action, list) + + +class TestValueAgent(TestCase): + + def setUp(self) -> None: + self.env = gym.make("CartPole-v0") + self.net = Mock(return_value=torch.Tensor([[0.0, 100.0]])) + self.state = [self.env.reset()] + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.value_agent = ValueAgent(self.net, self.env.action_space.n) + + def test_value_agent(self): + + action = self.value_agent(self.state, self.device) + self.assertIsInstance(action, list) + self.assertIsInstance(action[0], int) + + def test_value_agent_get_action(self): + action = self.value_agent.get_action(self.state, self.device) + self.assertIsInstance(action, np.ndarray) + self.assertEqual(action[0], 1) + + def test_value_agent_random(self): + action = self.value_agent.get_random_action(self.state) + self.assertIsInstance(action[0], int) + + +class TestPolicyAgent(TestCase): + + def setUp(self) -> None: + self.env = gym.make("CartPole-v0") + self.net = Mock(return_value=torch.Tensor([[0.0, 100.0]])) + self.states = [self.env.reset()] + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def test_policy_agent(self): + policy_agent = PolicyAgent(self.net) + action = policy_agent(self.states, self.device) + self.assertIsInstance(action, list) + self.assertEqual(action[0], 1) diff --git a/tests/models/rl/unit/test_memory.py b/tests/models/rl/unit/test_memory.py new file mode 100644 index 0000000000..12b89b232e --- /dev/null +++ b/tests/models/rl/unit/test_memory.py @@ -0,0 +1,286 @@ +from unittest import TestCase +from unittest.mock import Mock + +import numpy as np +import torch + +from pl_bolts.models.rl.common.memory import ReplayBuffer, Experience, PERBuffer, MultiStepBuffer, Buffer + + +class TestBuffer(TestCase): + + def train_batch(self): + """Returns an iterator used for testing""" + return iter([i for i in range(100)]) + + def setUp(self) -> None: + self.state = np.random.rand(4, 84, 84) + self.next_state = np.random.rand(4, 84, 84) + self.action = np.ones([1]) + self.reward = np.ones([1]) + self.done = np.zeros([1]) + self.experience = Experience(self.state, self.action, self.reward, self.done, self.next_state) + self.source = Mock() + self.source.step = Mock(return_value=(self.experience, torch.tensor(0), False)) + self.batch_size = 8 + self.buffer = Buffer(8) + + for _ in range(self.batch_size): + self.buffer.append(self.experience) + + def test_sample_batch(self): + """check that a sinlge sample is returned""" + sample = self.buffer.sample() + self.assertEqual(len(sample), 5) + self.assertEqual(sample[0].shape, (self.batch_size, 4, 84, 84)) + self.assertEqual(sample[1].shape, (self.batch_size, 1)) + self.assertEqual(sample[2].shape, (self.batch_size, 1)) + self.assertEqual(sample[3].shape, (self.batch_size, 1)) + self.assertEqual(sample[4].shape, (self.batch_size, 4, 84, 84)) + + +class TestReplayBuffer(TestCase): + + def setUp(self) -> None: + self.state = np.random.rand(32, 32) + self.next_state = np.random.rand(32, 32) + self.action = np.ones([1]) + self.reward = np.ones([1]) + self.done = np.zeros([1]) + self.experience = Experience(self.state, self.action, self.reward, self.done, self.next_state) + + self.source = Mock() + self.source.step = Mock(return_value=(self.experience, torch.tensor(0), False)) + self.warm_start = 10 + self.buffer = ReplayBuffer(20) + for _ in range(self.warm_start): + self.buffer.append(self.experience) + + def test_replay_buffer_append(self): + """Test that you can append to the replay buffer""" + + self.assertEqual(len(self.buffer), self.warm_start) + + self.buffer.append(self.experience) + + self.assertEqual(len(self.buffer), self.warm_start + 1) + + def test_replay_buffer_populate(self): + """Tests that the buffer is populated correctly with warm_start""" + self.assertEqual(len(self.buffer.buffer), self.warm_start) + + def test_replay_buffer_update(self): + """Tests that buffer append works correctly""" + batch_size = 3 + self.assertEqual(len(self.buffer.buffer), self.warm_start) + for i in range(batch_size): + self.buffer.append(self.experience) + self.assertEqual(len(self.buffer.buffer), self.warm_start + batch_size) + + def test_replay_buffer_sample(self): + """Test that you can sample from the buffer and the outputs are the correct shape""" + batch_size = 3 + + for i in range(10): + self.buffer.append(self.experience) + + batch = self.buffer.sample(batch_size) + + self.assertEqual(len(batch), 5) + + # states + states = batch[0] + self.assertEqual(states.shape, (batch_size, 32, 32)) + # action + actions = batch[1] + self.assertEqual(actions.shape, (batch_size, 1)) + # reward + rewards = batch[2] + self.assertEqual(rewards.shape, (batch_size, 1)) + # dones + dones = batch[3] + self.assertEqual(dones.shape, (batch_size, 1)) + # next states + next_states = batch[4] + self.assertEqual(next_states.shape, (batch_size, 32, 32)) + + +class TestPrioReplayBuffer(TestCase): + + def setUp(self) -> None: + self.buffer = PERBuffer(10) + + self.state = np.random.rand(32, 32) + self.next_state = np.random.rand(32, 32) + self.action = np.ones([1]) + self.reward = np.ones([1]) + self.done = np.zeros([1]) + self.experience = Experience(self.state, self.action, self.reward, self.done, self.next_state) + + def test_replay_buffer_append(self): + """Test that you can append to the replay buffer and the latest experience has max priority""" + + self.assertEqual(len(self.buffer), 0) + + self.buffer.append(self.experience) + + self.assertEqual(len(self.buffer), 1) + self.assertEqual(self.buffer.priorities[0], 1.0) + + def test_replay_buffer_sample(self): + """Test that you can sample from the buffer and the outputs are the correct shape""" + batch_size = 3 + + for i in range(10): + self.buffer.append(self.experience) + + batch, indices, weights = self.buffer.sample(batch_size) + + self.assertEqual(len(batch), 5) + self.assertEqual(len(indices), batch_size) + self.assertEqual(len(weights), batch_size) + + # states + states = batch[0] + self.assertEqual(states.shape, (batch_size, 32, 32)) + # action + actions = batch[1] + self.assertEqual(actions.shape, (batch_size, 1)) + # reward + rewards = batch[2] + self.assertEqual(rewards.shape, (batch_size, 1)) + # dones + dones = batch[3] + self.assertEqual(dones.shape, (batch_size, 1)) + # next states + next_states = batch[4] + self.assertEqual(next_states.shape, (batch_size, 32, 32)) + + +class TestMultiStepReplayBuffer(TestCase): + + def setUp(self) -> None: + self.gamma = 0.9 + self.buffer = MultiStepBuffer(capacity=10, n_steps=2, gamma=self.gamma) + + self.state = np.zeros([32, 32]) + self.state_02 = np.ones([32, 32]) + self.next_state = np.zeros([32, 32]) + self.next_state_02 = np.ones([32, 32]) + self.action = np.zeros([1]) + self.action_02 = np.ones([1]) + self.reward = np.zeros([1]) + self.reward_02 = np.ones([1]) + self.done = np.zeros([1]) + self.done_02 = np.zeros([1]) + + self.experience01 = Experience(self.state, self.action, self.reward, self.done, self.next_state) + self.experience02 = Experience(self.state_02, self.action_02, self.reward_02, self.done_02, self.next_state_02) + self.experience03 = Experience(self.state_02, self.action_02, self.reward_02, self.done_02, self.next_state_02) + + def test_append_single_experience_less_than_n(self): + """ + If a single experience is added and n > 1 nothing should be added to the buffer as it is waiting experiences + to equal n + """ + self.assertEqual(len(self.buffer), 0) + + self.buffer.append(self.experience01) + + self.assertEqual(len(self.buffer), 0) + + def test_append_single_experience(self): + """ + If a single experience is added and n > 1 nothing should be added to the buffer as it is waiting experiences + to equal n + """ + self.assertEqual(len(self.buffer), 0) + + self.buffer.append(self.experience01) + + self.assertEqual(len(self.buffer.exp_history_queue), 0) + self.assertEqual(len(self.buffer.history), 1) + + def test_append_single_experience2(self): + """ + If a single experience is added and the number of experiences collected >= n, the multi step experience should + be added to the full buffer. + """ + self.assertEqual(len(self.buffer), 0) + + self.buffer.append(self.experience01) + self.buffer.append(self.experience02) + + self.assertEqual(len(self.buffer.buffer), 1) + self.assertEqual(len(self.buffer.history), 2) + + def test_sample_single_experience(self): + """if there is only a single experience added, sample should return nothing""" + self.buffer.append(self.experience01) + + with self.assertRaises(Exception) as context: + _ = self.buffer.sample(batch_size=1) + + self.assertIsInstance(context.exception, Exception) + + def test_sample_multi_experience(self): + """if there is only a single experience added, sample should return nothing""" + self.buffer.append(self.experience01) + self.buffer.append(self.experience02) + + batch = self.buffer.sample(batch_size=1) + + next_state = batch[4] + self.assertEqual(next_state.all(), self.next_state_02.all()) + + def test_get_transition_info_2_step(self): + """Test that the accumulated experience is correct and""" + self.buffer.append(self.experience01) + self.buffer.append(self.experience02) + + reward = self.buffer.buffer[0].reward + next_state = self.buffer.buffer[0].new_state + done = self.buffer.buffer[0].done + + reward_gt = self.experience01.reward + (self.gamma * self.experience02.reward) * (1 - done) + + self.assertEqual(reward, reward_gt) + self.assertEqual(next_state.all(), self.next_state_02.all()) + self.assertEqual(self.experience02.done, done) + + def test_get_transition_info_3_step(self): + """Test that the accumulated experience is correct with multi step""" + self.buffer = MultiStepBuffer(capacity=10, n_steps=3, gamma=self.gamma) + + self.buffer.append(self.experience01) + self.buffer.append(self.experience02) + self.buffer.append(self.experience02) + + reward = self.buffer.buffer[0].reward + next_state = self.buffer.buffer[0].new_state + done = self.buffer.buffer[0].done + + reward_01 = self.experience02.reward + self.gamma * self.experience03.reward * (1 - done) + reward_gt = self.experience01.reward + self.gamma * reward_01 * (1 - done) + + self.assertEqual(reward, reward_gt) + self.assertEqual(next_state.all(), self.next_state_02.all()) + self.assertEqual(self.experience03.done, done) + + def test_sample_3_step(self): + """Test that final output of the 3 step sample is correct""" + self.buffer = MultiStepBuffer(capacity=10, n_steps=3, gamma=self.gamma) + + self.buffer.append(self.experience01) + self.buffer.append(self.experience02) + self.buffer.append(self.experience02) + + reward_gt = 1.71 + + batch = self.buffer.sample(1) + + self.assertEqual(batch[0].all(), self.experience01.state.all()) + self.assertEqual(batch[1], self.experience01.action) + self.assertEqual(batch[2], reward_gt) + self.assertEqual(batch[3], self.experience02.done) + self.assertEqual(batch[4].all(), self.experience02.new_state.all()) diff --git a/tests/models/rl/unit/test_reinforce.py b/tests/models/rl/unit/test_reinforce.py new file mode 100644 index 0000000000..655dc2bd54 --- /dev/null +++ b/tests/models/rl/unit/test_reinforce.py @@ -0,0 +1,65 @@ +import argparse +from unittest import TestCase + +import gym +import numpy as np +import torch + +from pl_bolts.datamodules.experience_source import DiscountedExperienceSource +from pl_bolts.models.rl.common.agents import Agent +from pl_bolts.models.rl.common.networks import MLP +from pl_bolts.models.rl.common.gym_wrappers import ToTensor +from pl_bolts.models.rl.reinforce_model import Reinforce + + +class TestReinforce(TestCase): + + def setUp(self) -> None: + self.env = ToTensor(gym.make("CartPole-v0")) + self.obs_shape = self.env.observation_space.shape + self.n_actions = self.env.action_space.n + self.net = MLP(self.obs_shape, self.n_actions) + self.agent = Agent(self.net) + self.exp_source = DiscountedExperienceSource(self.env, self.agent) + + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = Reinforce.add_model_specific_args(parent_parser) + args_list = [ + "--env", "CartPole-v0", + "--batch_size", "32", + "--gamma", "0.99" + ] + self.hparams = parent_parser.parse_args(args_list) + self.model = Reinforce(**vars(self.hparams)) + + self.rl_dataloader = self.model.train_dataloader() + + def test_loss(self): + """Test the reinforce loss function""" + + batch_states = torch.rand(32, 4) + batch_actions = torch.rand(32).long() + batch_qvals = torch.rand(32) + + loss = self.model.loss(batch_states, batch_actions, batch_qvals) + + self.assertIsInstance(loss, torch.Tensor) + + def test_get_qvals(self): + """Test that given an batch of episodes that it will return a list of qvals for each episode""" + + batch_qvals = [] + rewards = np.ones(32) + out = self.model.calc_qvals(rewards) + batch_qvals.append(out) + + self.assertIsInstance(batch_qvals[0][0], float) + self.assertEqual(batch_qvals[0][0], (batch_qvals[0][1] * self.hparams.gamma) + 1.0) + + def test_calc_q_vals(self): + rewards = np.ones(4) + gt_qvals = [3.9403989999999998, 2.9701, 1.99, 1.0] + + qvals = self.model.calc_qvals(rewards) + + self.assertEqual(gt_qvals, qvals) diff --git a/tests/models/rl/unit/test_vpg.py b/tests/models/rl/unit/test_vpg.py new file mode 100644 index 0000000000..0cbdb5a7c8 --- /dev/null +++ b/tests/models/rl/unit/test_vpg.py @@ -0,0 +1,56 @@ +import argparse +from unittest import TestCase + +import gym +import torch + +from pl_bolts.models.rl.common.agents import Agent +from pl_bolts.models.rl.common.networks import MLP +from pl_bolts.models.rl.common.gym_wrappers import ToTensor +from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient + + +class TestPolicyGradient(TestCase): + + def setUp(self) -> None: + self.env = ToTensor(gym.make("CartPole-v0")) + self.obs_shape = self.env.observation_space.shape + self.n_actions = self.env.action_space.n + self.net = MLP(self.obs_shape, self.n_actions) + self.agent = Agent(self.net) + + parent_parser = argparse.ArgumentParser(add_help=False) + parent_parser = VanillaPolicyGradient.add_model_specific_args(parent_parser) + args_list = [ + "--env", "CartPole-v0", + "--batch_size", "32" + ] + self.hparams = parent_parser.parse_args(args_list) + self.model = VanillaPolicyGradient(**vars(self.hparams)) + + def test_loss(self): + """Test the reinforce loss function""" + + batch_states = torch.rand(32, 4) + batch_actions = torch.rand(32).long() + batch_qvals = torch.rand(32) + + loss = self.model.loss(batch_states, batch_actions, batch_qvals) + + self.assertIsInstance(loss, torch.Tensor) + + def test_train_batch(self): + """Tests that a single batch generates correctly""" + + self.model.n_steps = 4 + self.model.batch_size = 1 + xp_dataloader = self.model.train_dataloader() + + batch = next(iter(xp_dataloader)) + self.assertEqual(len(batch), 3) + self.assertEqual(len(batch[0]), self.model.batch_size) + self.assertTrue(isinstance(batch, list)) + self.assertIsInstance(batch[0], torch.Tensor) + self.assertIsInstance(batch[1], list) + self.assertIsInstance(batch[1][0], torch.Tensor) + self.assertIsInstance(batch[2], torch.Tensor) diff --git a/tests/models/rl/unit/test_wrappers.py b/tests/models/rl/unit/test_wrappers.py new file mode 100644 index 0000000000..31e84ada49 --- /dev/null +++ b/tests/models/rl/unit/test_wrappers.py @@ -0,0 +1,19 @@ +from unittest import TestCase + +import gym +import torch + +from pl_bolts.models.rl.common.gym_wrappers import ToTensor + + +class TestToTensor(TestCase): + + def setUp(self) -> None: + self.env = ToTensor(gym.make("CartPole-v0")) + + def test_wrapper(self): + state = self.env.reset() + self.assertIsInstance(state, torch.Tensor) + + new_state, _, _, _ = self.env.step(1) + self.assertIsInstance(new_state, torch.Tensor) diff --git a/tests/models/test_mnist_templates.py b/tests/models/test_mnist_templates.py index 7099212cb2..0c8867eb03 100644 --- a/tests/models/test_mnist_templates.py +++ b/tests/models/test_mnist_templates.py @@ -7,11 +7,11 @@ def test_mnist(tmpdir): seed_everything() - model = LitMNIST(data_dir=tmpdir) + model = LitMNIST(data_dir=tmpdir, num_workers=0) trainer = pl.Trainer(limit_train_batches=0.01, limit_val_batches=0.01, max_epochs=1, limit_test_batches=0.01, default_root_dir=tmpdir) trainer.fit(model) trainer.test(model) - loss = trainer.callback_metrics['loss'] + loss = trainer.callback_metrics['train_loss'] - assert loss <= 2.0, 'mnist failed' + assert loss <= 2.2, 'mnist failed' diff --git a/tests/models/test_vision_models.py b/tests/models/test_vision_models.py index 0455a76320..73af207f1a 100644 --- a/tests/models/test_vision_models.py +++ b/tests/models/test_vision_models.py @@ -4,6 +4,7 @@ from pl_bolts.datamodules import MNISTDataModule, FashionMNISTDataModule from pl_bolts.models import GPT2, ImageGPT, UNet + def test_igpt(tmpdir): pl.seed_everything(0) dm = MNISTDataModule(tmpdir, normalize=False) @@ -53,4 +54,3 @@ def test_unet(tmpdir): model = UNet(num_classes=2) y = model(x) assert y.shape == torch.Size([10, 2, 28, 28]) - From eac58393631d6a2ba7d736f13cdeae8b3fd3d9df Mon Sep 17 00:00:00 2001 From: Caroline Adams <95carolineelizabeth@gmail.com> Date: Wed, 7 Oct 2020 16:08:31 -0400 Subject: [PATCH 05/32] Neatened up Bolts Documentation (#262) * Documentation clean up Looked for grammatical and markdown errors mostly * Bolts grammatical changes --- docs/source/classic_ml.rst | 4 +- docs/source/dataloaders.rst | 5 +- docs/source/datamodules.rst | 4 +- docs/source/introduction_guide.rst | 14 ++-- docs/source/models.rst | 4 +- docs/source/reinforce_learn.rst | 67 +++++++++---------- pl_bolts/models/rl/dqn_model.py | 3 +- pl_bolts/models/rl/per_dqn_model.py | 3 +- .../self_supervised/byol/byol_module.py | 2 +- 9 files changed, 53 insertions(+), 53 deletions(-) diff --git a/docs/source/classic_ml.rst b/docs/source/classic_ml.rst index d3b3c39712..8a1f8b3aa7 100644 --- a/docs/source/classic_ml.rst +++ b/docs/source/classic_ml.rst @@ -9,7 +9,7 @@ half-precision training. Linear Regression ----------------- Linear regression fits a linear model between a real-valued target variable :math:`y` and one or more features :math:`X`. We -estimate the regression coefficients that minimizes the mean squared error between the predicted and true target +estimate the regression coefficients that minimize the mean squared error between the predicted and true target values. We formulate the linear regression model as a single-layer neural network. By default we include only one neuron in @@ -69,7 +69,7 @@ Add either L1 or L2 regularization, or both, by specifying the regularization st trainer.test(test_dataloaders=dm.test_dataloader(batch_size=12)) -Any input will be flattened across all dimensions except the firs one (batch). +Any input will be flattened across all dimensions except the first one (batch). This means images, sound, etc... work out of the box. .. code-block:: python diff --git a/docs/source/dataloaders.rst b/docs/source/dataloaders.rst index efe932027b..70aab6ae78 100644 --- a/docs/source/dataloaders.rst +++ b/docs/source/dataloaders.rst @@ -3,7 +3,10 @@ AsynchronousLoader This dataloader behaves identically to the standard pytorch dataloader, but will transfer data asynchronously to the GPU with training. You can also use it to wrap an existing dataloader. -Example:: +Example: + +.. code-block:: python + dataloader = AsynchronousLoader(DataLoader(ds, batch_size=16), device=device) for b in dataloader: diff --git a/docs/source/datamodules.rst b/docs/source/datamodules.rst index 6468326e5c..94c7fc28d8 100644 --- a/docs/source/datamodules.rst +++ b/docs/source/datamodules.rst @@ -7,9 +7,9 @@ DataModules (introduced in PyTorch Lightning 0.9.0) decouple the data from a mod is simply a collection of a training dataloder, val dataloader and test dataloader. In addition, it specifies how to: -- Downloading/preparing data. +- Download/prepare data. - Train/val/test splits. -- Transforms +- Transform Then you can use it like this: diff --git a/docs/source/introduction_guide.rst b/docs/source/introduction_guide.rst index 2ff923a911..a16ba08818 100644 --- a/docs/source/introduction_guide.rst +++ b/docs/source/introduction_guide.rst @@ -10,7 +10,7 @@ Bolts is a Deep learning research and production toolbox of: - Losses. - Datasets. -**The Main goal of bolts is to enable trying new ideas as fast as possible!** +**The Main goal of Bolts is to enable trying new ideas as fast as possible!** All models are tested (daily), benchmarked, documented and work on CPUs, TPUs, GPUs and 16-bit precision. @@ -90,11 +90,11 @@ All models are tested (daily), benchmarked, documented and work on CPUs, TPUs, G Community Built --------------- -Bolts are built-by the Lightning community and contributed to bolts. +Then lightning community builds bolts and contributes them to Bolts. The lightning team guarantees that contributions are: -1. Rigorously Tested (CPUs, GPUs, TPUs). -2. Rigorously Documented. +1. Rigorously tested (CPUs, GPUs, TPUs). +2. Rigorously documented. 3. Standardized via PyTorch Lightning. 4. Optimized for speed. 5. Checked for correctness. @@ -351,7 +351,7 @@ In case your job or research doesn't need a "hammer", we offer implementations o which benefit from lightning's multi-GPU and TPU support. So, now you can run huge workloads scalably, without needing to do any engineering. -For instance, here we can run Logistic Regression on Imagenet (each epoch takes about 3 minutes)! +For instance, here we can run logistic Regression on Imagenet (each epoch takes about 3 minutes)! .. code-block:: python @@ -414,7 +414,7 @@ But more importantly, you can scale up to many GPUs, TPUs or even CPUs Logistic Regression ^^^^^^^^^^^^^^^^^^^ -Here's an example for Logistic regression +Here's an example for logistic regression .. code-block:: python @@ -436,7 +436,7 @@ Here's an example for Logistic regression trainer.test(test_dataloaders=dm.test_dataloader(batch_size=12)) -Any input will be flattened across all dimensions except the firs one (batch). +Any input will be flattened across all dimensions except the first one (batch). This means images, sound, etc... work out of the box. .. code-block:: python diff --git a/docs/source/models.rst b/docs/source/models.rst index 09ae1888c5..924b39de4d 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -15,7 +15,7 @@ by adding your contribution to bolts you get these **additional** benefits! 6. We'll pretrain expensive models for you and host weights. 7. We will improve the speed of your models! 8. Eligible for invited talks to discuss your implementation. - 9. Lightning Swag + involvement in the broader contributor community :) + 9. Lightning swag + involvement in the broader contributor community :) .. note:: You still get to keep your attribution and be recognized for your work! @@ -98,7 +98,7 @@ We request that each contribution have: - Your name and your team's name as the implementation authors. - Your team's affiliation - Any generated examples, or result plots. - - Hyperparameters configurations for the results. + - Hyperparameter configurations for the results. Thank you for all your amazing contributions! diff --git a/docs/source/reinforce_learn.rst b/docs/source/reinforce_learn.rst index 827feb395a..4737b60764 100644 --- a/docs/source/reinforce_learn.rst +++ b/docs/source/reinforce_learn.rst @@ -29,8 +29,8 @@ Contributions by: `Donal Byrne `_ DQN Models ---------- -The following models are based on DQN. DQN uses Value based learning where it is deciding what action to take based -on the models current learned value (V), or the state action value (Q) of the current state. These Values are defined +The following models are based on DQN. DQN uses value based learning where it is deciding what action to take based +on the model's current learned value (V), or the state action value (Q) of the current state. These values are defined as the discounted total reward of the agents state or state action pair. --------------- @@ -47,12 +47,12 @@ The DQN was introduced in `Playing Atari with Deep Reinforcement Learning ` the network uses two heads, one outputs the value state and the other outputs the advantage. This leads to better training stability, faster convergence and overall better results. The V head outputs a single scalar (the state value), while the advantage head outputs a tensor equal to the size of the action space, containing @@ -189,14 +189,14 @@ by subtracting the mean advantage from the Q value. This essentially pulls the m Dueling DQN Benefits ~~~~~~~~~~~~~~~~~~~~ -- Ability to efficiently learn the state value function. In the dueling network, every Q update also updates the Value - stream, where as in DQN only the value of the chosen action is updated. This provides a better approximation of the - values +- Ability to efficiently learn the state value function. In the dueling network, every Q update also updates the value + stream, where as in DQN only the value of the chosen action is updated. This provides a better approximation of the + values - The differences between total Q values for a given state are quite small in relation to the magnitude of Q. The - difference in the Q values between the best action and the second best action can be very small, while the average - state value can be much larger. The differences in scale can introduce noise, which may lead to the greedy policy - switching the priority of these actions. The seperate estimators for state value and advantage makes the Dueling - DQN robust to this type of scenario + difference in the Q values between the best action and the second best action can be very small, while the average + state value can be much larger. The differences in scale can introduce noise, which may lead to the greedy policy + switching the priority of these actions. The seperate estimators for state value and advantage makes the Dueling + DQN robust to this type of scenario Dueling DQN Results ~~~~~~~~~~~~~~~~~~~ @@ -255,11 +255,11 @@ Vlad Mnih, Remi Munos, Demis Hassabis, Olivier Pietquin, Charles Blundell, Shane Original implementation by: `Donal Byrne `_ Up until now the DQN agent uses a seperate exploration policy, generally epsilon-greedy where start and end values -are set for its exploration. [Noisy Networks For Exploration](https://arxiv.org/abs/1706.10295) introduces -a new exploration strategy by adding noise parameters to the weightsof the fully connect layers which get updated +are set for its exploration. `Noisy Networks For Exploration ` introduces +a new exploration strategy by adding noise parameters to the weights of the fully connect layers which get updated during backpropagation of the network. The noise parameters drive the exploration of the network instead of simply taking random actions more frequently at the start of training and -less frequently towards the end.The of authors of +less frequently towards the end. The of authors of propose two ways of doing this. During the optimization step a new set of noisy parameters are sampled. During training the agent acts according to @@ -270,23 +270,23 @@ distribution. The authors propose two methods of injecting noise to the network. 1) Independent Gaussian Noise: This injects noise per weight. For each weight a random value is taken from - the distribution. Noise parameters are stored inside the layer and are updated during backpropagation. - The output of the layer is calculated as normal. + the distribution. Noise parameters are stored inside the layer and are updated during backpropagation. + The output of the layer is calculated as normal. 2) Factorized Gaussian Noise: This injects nosier per input/ouput. In order to minimize the number of random values - this method stores two random vectors, one with the size of the input and the other with the size of the output. - Using these two vectors, a random matrix is generated for the layer by calculating the outer products of the vector + this method stores two random vectors, one with the size of the input and the other with the size of the output. + Using these two vectors, a random matrix is generated for the layer by calculating the outer products of the vector Noisy DQN Benefits ~~~~~~~~~~~~~~~~~~ - Improved exploration function. Instead of just performing completely random actions, we add decreasing amount of noise - and uncertainty to our policy allowing to explore while still utilising its policy + and uncertainty to our policy allowing to explore while still utilising its policy. - The fact that this method is automatically tuned means that we do not have to tune hyper parameters for - epsilon-greedy! + epsilon-greedy! .. note:: - for now I have just implemented the Independant Gaussian as it has been reported there isn't much difference + For now I have just implemented the Independant Gaussian as it has been reported there isn't much difference in results for these benchmark environments. In order to update the basic DQN to a Noisy DQN we need to do the following @@ -349,8 +349,8 @@ pair using a single step which looks like this Q(s_t,a_t)=r_t+{\gamma}\max_aQ(s_{t+1},a_{t+1}) -but because the Q function is recursive we can continue to roll this out into multiple steps, looking at the expected - return for each step into the future. +but because the Q function is recursive we can continue to roll this out into multiple steps, looking at the expected +return for each step into the future. .. math:: @@ -373,14 +373,14 @@ method like DQN with a large replay buffer will make this even worse, as there i training on experiences using an old policy that was worse than our current policy. So we need to strike a balance between looking far enough ahead to improve the convergence of our agent, but not so far - that are updates become unstable. In general, small values of 2-4 work best. +that are updates become unstable. In general, small values of 2-4 work best. N-Step Benefits ~~~~~~~~~~~~~~~ - Multi-Step learning is capable of learning faster than typical 1 step learning methods. - Note that this method introduces a new hyperparameter n. Although n=4 is generally a good starting point and provides - good results across the board. + good results across the board. N-Step Results ~~~~~~~~~~~~~~ @@ -464,7 +464,7 @@ PER Benefits ~~~~~~~~~~~~ - The benefits of this technique are that the agent sees more samples that it struggled with and gets more - chances to improve upon it. + chances to improve upon it. **Memory Buffer** @@ -500,10 +500,10 @@ on an optimal policy faster. **DQN vs PER DQN: Pong** In comparison to the base DQN, the PER DQN does show improved stability and performance. As expected, the loss - of the PER DQN is siginificantly lower. This is the main objective of PER by focusing on experiences with high loss. +of the PER DQN is siginificantly lower. This is the main objective of PER by focusing on experiences with high loss. It is important to note that loss is not the only metric we should be looking at. Although the agent may have very - low loss during training, it may still perform poorly due to lack of exploration. +low loss during training, it may still perform poorly due to lack of exploration. .. image:: _images/rl_benchmark/pong_per_dqn_baseline_v1_results_comp.jpg :width: 800 @@ -535,7 +535,7 @@ suggested by our policy gradient in order to find a policy that produces the hig Policy Gradient Key Points: - Outputs a distribution of actions instead of discrete Q values - Optimizes the policy directly, instead of indirectly through the optimization of Q values - - The policy distribution of actions allows the model to handle more complex action spaces, such as continuos actions + - The policy distribution of actions allows the model to handle more complex action spaces, such as continuous actions - The policy distribution introduces stochasticity, providing natural exploration to the model - The policy distribution provides a more stable update as a change in weights will only change the total distribution slightly, as opposed to changing weights based on the Q value of state S will change all Q values with similar states. @@ -570,17 +570,16 @@ algorithm is as follows: .. math:: L = - \sum_{k,t} Q_{k,t} \log(\pi(S_{k,t}, A_{k,t})) - + 5. Perform SGD on the loss and repeat. - What this loss function is saying is simply that we want to take the log probability of action A at state S given our policy (network output). This is then scaled by the discounted reward that we calculated in the previous step. We then take the negative of our sum. This is because the loss is minimized during SGD, but we want to maximize our policy. .. note:: - the current implementation does not actually wait for the batch episodes the complete every time as we pass in a + The current implementation does not actually wait for the batch episodes the complete every time as we pass in a fixed batch size. For the time being we simply use a large batch size to accomodate this. This approach still works well for simple tasks as it still manages to get an accurate Q value by using a large batch size, but it is not as accurate or completely correct. This will be updated in a later version. diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 52de30e975..01b4a68277 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -92,8 +92,7 @@ def __init__( Note: This example is based on: - https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition\ - /blob/master/Chapter06/02_dqn_pong.py + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter06/02_dqn_pong.py Note: Currently only supports CPU and single GPU training with `distributed_backend=dp` diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index 7705a6846a..07ad80d564 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -54,8 +54,7 @@ class PERDQN(DQN): .. note:: This example is based on: - https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition\ - /blob/master/Chapter08/05_dqn_prio_replay.py + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter08/05_dqn_prio_replay.py .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index b6fb83cf52..04e7f81ae2 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -29,7 +29,7 @@ def __init__(self, PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL) `_ - Paper authors: Jean-Bastien Grill ,Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \ + Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \ Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \ Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko. From 7b1d8951d832366344cf863d8f2f04d3ece04768 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 12 Oct 2020 07:05:59 -0400 Subject: [PATCH 06/32] add random datasets (#266) * enabled manual returns * enabled manual returns --- docs/source/dataloaders.rst | 8 ---- docs/source/index.rst | 16 ++++++- pl_bolts/datamodules/__init__.py | 1 - pl_bolts/datamodules/dummy_dataset.py | 68 --------------------------- 4 files changed, 15 insertions(+), 78 deletions(-) delete mode 100644 pl_bolts/datamodules/dummy_dataset.py diff --git a/docs/source/dataloaders.rst b/docs/source/dataloaders.rst index 70aab6ae78..4101003ba7 100644 --- a/docs/source/dataloaders.rst +++ b/docs/source/dataloaders.rst @@ -14,11 +14,3 @@ Example: .. autoclass:: pl_bolts.datamodules.async_dataloader.AsynchronousLoader :noindex: - ------------------- - -DummyDataset ------------- - -.. autoclass:: pl_bolts.datamodules.dummy_dataset.DummyDataset - :noindex: diff --git a/docs/source/index.rst b/docs/source/index.rst index e24bbeb38f..cf45a1d243 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -33,6 +33,13 @@ PyTorch-Lightning-Bolts documentation sklearn_datamodule vision_datamodules +.. toctree:: + :maxdepth: 2 + :name: datasets + :caption: Datasets + + datasets + .. toctree:: :maxdepth: 2 :name: dataloaders @@ -53,8 +60,14 @@ PyTorch-Lightning-Bolts documentation :caption: Models models_howto - autoencoders classic_ml + +.. toctree:: + :maxdepth: 2 + :name: vision + :caption: Vision models + + autoencoders convolutional gans reinforce_learn @@ -91,6 +104,7 @@ Indices and tables readme api/pl_bolts.callbacks api/pl_bolts.datamodules + api/pl_bolts.datasets api/pl_bolts.metrics api/pl_bolts.models api/pl_bolts.callbacks diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index d3523ff41a..2e3447d2ac 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -1,5 +1,4 @@ from pl_bolts.datamodules.async_dataloader import AsynchronousLoader -from pl_bolts.datamodules.dummy_dataset import DummyDataset, DummyDetectionDataset try: from pl_bolts.datamodules.binary_mnist_datamodule import BinaryMNISTDataModule diff --git a/pl_bolts/datamodules/dummy_dataset.py b/pl_bolts/datamodules/dummy_dataset.py deleted file mode 100644 index 771a7fdbb7..0000000000 --- a/pl_bolts/datamodules/dummy_dataset.py +++ /dev/null @@ -1,68 +0,0 @@ -import torch -from torch.utils.data import Dataset, DataLoader - - -class DummyDataset(Dataset): - def __init__(self, *shapes, num_samples=10000): - """ - Generate a dummy dataset - - Args: - *shapes: list of shapes - num_samples: how many samples to use in this dataset - - Example:: - - from pl_bolts.datamodules import DummyDataset - - # mnist dims - >>> ds = DummyDataset((1, 28, 28), (1,)) - >>> dl = DataLoader(ds, batch_size=7) - ... - >>> batch = next(iter(dl)) - >>> x, y = batch - >>> x.size() - torch.Size([7, 1, 28, 28]) - >>> y.size() - torch.Size([7, 1]) - """ - super().__init__() - self.shapes = shapes - self.num_samples = num_samples - - def __len__(self): - return self.num_samples - - def __getitem__(self, idx): - samples = [] - for shape in self.shapes: - sample = torch.rand(*shape) - samples.append(sample) - - return samples - - -class DummyDetectionDataset(Dataset): - def __init__( - self, img_shape=(3, 256, 256), num_boxes=1, num_classes=2, num_samples=10000 - ): - super().__init__() - self.img_shape = img_shape - self.num_samples = num_samples - self.num_boxes = num_boxes - self.num_classes = num_classes - - def __len__(self): - return self.num_samples - - def _random_bbox(self): - c, h, w = self.img_shape - xs = torch.randint(w, (2,)) - ys = torch.randint(h, (2,)) - return [min(xs), min(ys), max(xs), max(ys)] - - def __getitem__(self, idx): - img = torch.rand(self.img_shape) - boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) - labels = torch.randint(self.num_classes, (self.num_boxes,)) - return img, {"boxes": boxes, "labels": labels} From a8e3fb5e2692a8193b28dc1667e1df9bf6ffb04b Mon Sep 17 00:00:00 2001 From: Nathan Raw Date: Mon, 12 Oct 2020 05:07:19 -0600 Subject: [PATCH 07/32] Updates for lightning 0.10.0 (#264) * :art: refactor to use self.log instead of results obj * :pushpin: pin reqs * :bug: use functional accuracy * :bug: remove result * :bug: clean hparams up to fix cpc --- .../autoencoders/basic_ae/basic_ae_module.py | 10 ++-- .../basic_vae/basic_vae_module.py | 10 ++-- .../models/gans/basic/basic_gan_module.py | 10 ++-- .../models/regression/logistic_regression.py | 2 +- .../self_supervised/byol/byol_module.py | 10 ++-- .../models/self_supervised/cpc/cpc_module.py | 48 ++++++++----------- .../self_supervised/simclr/simclr_module.py | 10 ++-- .../models/self_supervised/ssl_finetuner.py | 15 +++--- requirements/base.txt | 2 +- 9 files changed, 48 insertions(+), 69 deletions(-) diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index d554769b7e..66c3a3a113 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -106,17 +106,15 @@ def step(self, batch, batch_idx): def training_step(self, batch, batch_idx): loss, logs = self.step(batch, batch_idx) - result = pl.TrainResult(minimize=loss) - result.log_dict( + self.log_dict( {f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False ) - return result + return loss def validation_step(self, batch, batch_idx): loss, logs = self.step(batch, batch_idx) - result = pl.EvalResult(checkpoint_on=loss) - result.log_dict({f"val_{k}": v for k, v in logs.items()}) - return result + self.log_dict({f"val_{k}": v for k, v in logs.items()}) + return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr) diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index 40ba1be428..ab6671d000 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -139,17 +139,15 @@ def step(self, batch, batch_idx): def training_step(self, batch, batch_idx): loss, logs = self.step(batch, batch_idx) - result = pl.TrainResult(minimize=loss) - result.log_dict( + self.log_dict( {f"train_{k}": v for k, v in logs.items()}, on_step=True, on_epoch=False ) - return result + return loss def validation_step(self, batch, batch_idx): loss, logs = self.step(batch, batch_idx) - result = pl.EvalResult(checkpoint_on=loss) - result.log_dict({f"val_{k}": v for k, v in logs.items()}) - return result + self.log_dict({f"val_{k}": v for k, v in logs.items()}) + return loss def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.lr) diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py index 1459327daf..7311cb260b 100644 --- a/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/pl_bolts/models/gans/basic/basic_gan_module.py @@ -136,18 +136,16 @@ def generator_step(self, x): # log to prog bar on each step AND for the full epoch # use the generator loss for checkpointing - result = pl.TrainResult(minimize=g_loss, checkpoint_on=g_loss) - result.log('g_loss', g_loss, on_epoch=True, prog_bar=True) - return result + self.log('g_loss', g_loss, on_epoch=True, prog_bar=True) + return g_loss def discriminator_step(self, x): # Measure discriminator's ability to classify real from generated samples d_loss = self.discriminator_loss(x) # log to prog bar on each step AND for the full epoch - result = pl.TrainResult(minimize=d_loss) - result.log('d_loss', d_loss, on_epoch=True, prog_bar=True) - return result + self.log('d_loss', d_loss, on_epoch=True, prog_bar=True) + return d_loss def configure_optimizers(self): lr = self.hparams.learning_rate diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index c9b2a2ff4f..79626cd143 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -2,7 +2,7 @@ import pytorch_lightning as pl import torch -from pytorch_lightning.metrics.classification import accuracy +from pytorch_lightning.metrics.functional import accuracy from torch import nn from torch.nn import functional as F from torch.optim import Adam diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 04e7f81ae2..95c68bbee7 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -136,19 +136,17 @@ def training_step(self, batch, batch_idx): loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) # log results - result = pl.TrainResult(minimize=total_loss) - result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) + self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) - return result + return total_loss def validation_step(self, batch, batch_idx): loss_a, loss_b, total_loss = self.shared_step(batch, batch_idx) # log results - result = pl.EvalResult(early_stop_on=total_loss, checkpoint_on=total_loss) - result.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) + self.log_dict({'1_2_loss': loss_a, '2_1_loss': loss_b, 'train_loss': total_loss}) - return result + return total_loss def configure_optimizers(self): optimizer = Adam(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index b94d4611c0..e6aa3b5b0c 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -35,7 +35,7 @@ class CPCV2(pl.LightningModule): def __init__( self, datamodule: pl.LightningDataModule = None, - encoder: Union[str, torch.nn.Module, pl.LightningModule] = 'cpc_encoder', + encoder_name: str = 'cpc_encoder', patch_size: int = 8, patch_overlap: int = 4, online_ft: int = True, @@ -50,7 +50,7 @@ def __init__( """ Args: datamodule: A Datamodule (optional). Otherwise set the dataloaders directly - encoder: A string for any of the resnets in torchvision, or the original CPC encoder, + encoder_name: A string for any of the resnets in torchvision, or the original CPC encoder, or a custon nn.Module encoder patch_size: How big to make the image patches patch_overlap: How much overlap should each patch have. @@ -66,28 +66,20 @@ def __init__( super().__init__() self.save_hyperparameters() + # HACK - datamodule not pickleable so we remove it from hparams. + # TODO - remove datamodule from init. data should be decoupled from models. + del self.hparams['datamodule'] + self.online_evaluator = self.hparams.online_ft if pretrained: self.hparams.dataset = pretrained self.online_evaluator = True - # link data - # if datamodule is None: - # datamodule = CIFAR10DataModule( - # self.hparams.data_dir, - # num_workers=self.hparams.num_workers, - # batch_size=batch_size - # ) - # datamodule.train_transforms = CPCTrainTransformsCIFAR10() - # datamodule.val_transforms = CPCEvalTransformsCIFAR10() assert datamodule self.datamodule = datamodule - # init encoder - self.encoder = encoder - if isinstance(encoder, str): - self.encoder = self.init_encoder() + self.encoder = self.init_encoder() # info nce loss c, h = self.__compute_final_nb_c(self.hparams.patch_size) @@ -97,20 +89,22 @@ def __init__( self.num_classes = self.datamodule.num_classes if pretrained: - self.load_pretrained(encoder) + self.load_pretrained(self.hparams.encoder_name) + + print(self.hparams) - def load_pretrained(self, encoder): + def load_pretrained(self, encoder_name): available_weights = {'resnet18'} - if encoder in available_weights: - load_pretrained(self, f'CPCV2-{encoder}') - elif available_weights not in available_weights: - rank_zero_warn(f'{encoder} not yet available') + if encoder_name in available_weights: + load_pretrained(self, f'CPCV2-{encoder_name}') + elif encoder_name not in available_weights: + rank_zero_warn(f'{encoder_name} not yet available') def init_encoder(self): dummy_batch = torch.zeros((2, 3, self.hparams.patch_size, self.hparams.patch_size)) - encoder_name = self.hparams.encoder + encoder_name = self.hparams.encoder_name if encoder_name == 'cpc_encoder': return cpc_resnet101(dummy_batch) else: @@ -160,18 +154,16 @@ def training_step(self, batch, batch_nb): nce_loss = self.shared_step(batch) # result - result = pl.TrainResult(nce_loss) - result.log('train_nce_loss', nce_loss) - return result + self.log('train_nce_loss', nce_loss) + return nce_loss def validation_step(self, batch, batch_nb): # calculate loss nce_loss = self.shared_step(batch) # result - result = pl.EvalResult(checkpoint_on=nce_loss) - result.log('val_nce', nce_loss, prog_bar=True) - return result + self.log('val_nce', nce_loss, prog_bar=True) + return nce_loss def shared_step(self, batch): try: diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 7fbe562827..582883991a 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -157,16 +157,14 @@ def forward(self, x): def training_step(self, batch, batch_idx): loss = self.shared_step(batch, batch_idx) - result = pl.TrainResult(minimize=loss) - result.log('train_loss', loss, on_epoch=True) - return result + self.log('train_loss', loss, on_epoch=True) + return loss def validation_step(self, batch, batch_idx): loss = self.shared_step(batch, batch_idx) - result = pl.EvalResult(checkpoint_on=loss) - result.log('avg_val_loss', loss) - return result + self.log('avg_val_loss', loss) + return loss def shared_step(self, batch, batch_idx): (img1, img2), y = batch diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py index d3a3e95377..f07e697a42 100644 --- a/pl_bolts/models/self_supervised/ssl_finetuner.py +++ b/pl_bolts/models/self_supervised/ssl_finetuner.py @@ -59,21 +59,18 @@ def on_train_epoch_start(self) -> None: def training_step(self, batch, batch_idx): loss, acc = self.shared_step(batch) - result = pl.TrainResult(loss) - result.log('train_acc', acc, prog_bar=True) - return result + self.log('train_acc', acc, prog_bar=True) + return loss def validation_step(self, batch, batch_idx): loss, acc = self.shared_step(batch) - result = pl.EvalResult(checkpoint_on=loss, early_stop_on=loss) - result.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True) - return result + self.log_dict({'val_acc': acc, 'val_loss': loss}, prog_bar=True) + return loss def test_step(self, batch, batch_idx): loss, acc = self.shared_step(batch) - result = pl.EvalResult() - result.log_dict({'test_acc': acc, 'test_loss': loss}) - return result + self.log_dict({'test_acc': acc, 'test_loss': loss}) + return loss def shared_step(self, batch): x, y = batch diff --git a/requirements/base.txt b/requirements/base.txt index 62766e3de2..c434a7c377 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,2 +1,2 @@ -pytorch-lightning>=0.9.1rc3 +pytorch-lightning>=0.10.0 torch>=1.6 \ No newline at end of file From 2664934c75bc7e823f4b4e5ae2b0487cf19a51a7 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 12 Oct 2020 07:09:58 -0400 Subject: [PATCH 08/32] update mnist --- pl_bolts/models/mnist_module.py | 35 ++++++--------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/pl_bolts/models/mnist_module.py b/pl_bolts/models/mnist_module.py index 365b481437..3dc71b22b1 100644 --- a/pl_bolts/models/mnist_module.py +++ b/pl_bolts/models/mnist_module.py @@ -36,43 +36,20 @@ def training_step(self, batch, batch_idx): x, y = batch y_hat = self(x) loss = F.cross_entropy(y_hat, y) - tensorboard_logs = {'train_loss': loss} - progress_bar_metrics = tensorboard_logs - return { - 'loss': loss, - 'log': tensorboard_logs, - 'progress_bar': progress_bar_metrics - } + self.log('train_loss', loss) + return loss def validation_step(self, batch, batch_idx): x, y = batch y_hat = self(x) - return {'val_loss': F.cross_entropy(y_hat, y)} - - def validation_epoch_end(self, outputs): - avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean() - tensorboard_logs = {'val_loss': avg_loss} - progress_bar_metrics = tensorboard_logs - return { - 'val_loss': avg_loss, - 'log': tensorboard_logs, - 'progress_bar': progress_bar_metrics - } + loss = F.cross_entropy(y_hat, y) + self.log('val_loss', loss) def test_step(self, batch, batch_idx): x, y = batch y_hat = self(x) - return {'test_loss': F.cross_entropy(y_hat, y)} - - def test_epoch_end(self, outputs): - avg_loss = torch.stack([x['test_loss'] for x in outputs]).mean() - tensorboard_logs = {'test_loss': avg_loss} - progress_bar_metrics = tensorboard_logs - return { - 'test_loss': avg_loss, - 'log': tensorboard_logs, - 'progress_bar': progress_bar_metrics - } + loss = F.cross_entropy(y_hat, y) + self.log('test_loss', loss) def configure_optimizers(self): return torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate) From b3349c6a3b02bb1b75d5711a226ae7fd4d98f505 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 12 Oct 2020 07:11:33 -0400 Subject: [PATCH 09/32] Update __init__.py --- pl_bolts/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index 271653f015..b3170a8502 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -2,7 +2,7 @@ import os -__version__ = '0.2.2' +__version__ = '0.2.3' __author__ = 'PyTorchLightning et al.' __author_email__ = 'name@pytorchlightning.ai' __license__ = 'Apache-2.0' From fda4600415e6da708e0b9bc85b15f6e04cdf467e Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 12 Oct 2020 07:22:05 -0400 Subject: [PATCH 10/32] enabled manual returns (#267) --- docs/source/datasets.rst | 41 ++++++++++++++++++++++++++++++++++++++++ pl_bolts/__init__.py | 5 +++-- 2 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 docs/source/datasets.rst diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst new file mode 100644 index 0000000000..4e54095022 --- /dev/null +++ b/docs/source/datasets.rst @@ -0,0 +1,41 @@ +######## +Datasets +######## +Collection of useful datasets + +-------- + +********* +Debugging +********* +Use these datasets to debug + +DummyDataset +============ + +.. autoclass:: pl_bolts.datasets.dummy_dataset.DummyDataset + :noindex: + +DummyDetectionDataset +===================== + +.. autoclass:: pl_bolts.datasets.dummy_dataset.DummyDetectionDataset + :noindex: + +RandomDataset +============= + +.. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDataset + :noindex: + +RandomDictDataset +================= + +.. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDictDataset + :noindex: + +RandomDictStringDataset +======================= + +.. autoclass:: pl_bolts.datasets.dummy_dataset.RandomDictStringDataset + :noindex: diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index b3170a8502..31399af111 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -45,12 +45,13 @@ else: # from pl_bolts.models.mnist_module import LitMNIST - from pl_bolts import models, metrics, callbacks, datamodules, transforms + from pl_bolts import models, metrics, callbacks, datamodules, transforms, datasets __all__ = [ # 'LitMNIST', 'models', 'metrics', 'callbacks', - 'datamodules' + 'datamodules', + 'datasets' ] From f092b8a34dea245bde4e4a99e9cc580656f9740f Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 12 Oct 2020 07:22:18 -0400 Subject: [PATCH 11/32] Update __init__.py --- pl_bolts/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index 31399af111..7972098997 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -2,7 +2,7 @@ import os -__version__ = '0.2.3' +__version__ = '0.2.4' __author__ = 'PyTorchLightning et al.' __author_email__ = 'name@pytorchlightning.ai' __license__ = 'Apache-2.0' From 16302a567ab61e23f96bd1cc8ce1c61784bdf95d Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 12 Oct 2020 07:26:26 -0400 Subject: [PATCH 12/32] enabled manual returns (#268) --- .gitignore | 1 - pl_bolts/__init__.py | 2 +- pl_bolts/datasets/__init__.py | 6 ++ pl_bolts/datasets/dummy_dataset.py | 162 +++++++++++++++++++++++++++++ tests/datasets/__init__.py | 0 tests/datasets/test_datasets.py | 34 ++++++ 6 files changed, 203 insertions(+), 2 deletions(-) create mode 100644 pl_bolts/datasets/__init__.py create mode 100644 pl_bolts/datasets/dummy_dataset.py create mode 100644 tests/datasets/__init__.py create mode 100644 tests/datasets/test_datasets.py diff --git a/.gitignore b/.gitignore index 2c179cd36b..96f7417d80 100644 --- a/.gitignore +++ b/.gitignore @@ -138,7 +138,6 @@ MNIST # Lightning logs lightning_logs -datasets *.gz *-batches-py simclr.py diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index 7972098997..2945c2ba3c 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -53,5 +53,5 @@ 'metrics', 'callbacks', 'datamodules', - 'datasets' + 'datasets', ] diff --git a/pl_bolts/datasets/__init__.py b/pl_bolts/datasets/__init__.py new file mode 100644 index 0000000000..9962a8f717 --- /dev/null +++ b/pl_bolts/datasets/__init__.py @@ -0,0 +1,6 @@ +from pl_bolts.datasets.dummy_dataset import \ + (RandomDictStringDataset, + RandomDictDataset, + RandomDataset, + DummyDataset, + DummyDetectionDataset) \ No newline at end of file diff --git a/pl_bolts/datasets/dummy_dataset.py b/pl_bolts/datasets/dummy_dataset.py new file mode 100644 index 0000000000..3212b37795 --- /dev/null +++ b/pl_bolts/datasets/dummy_dataset.py @@ -0,0 +1,162 @@ +import torch +from torch.utils.data import Dataset, DataLoader + + +class DummyDataset(Dataset): + def __init__(self, *shapes, num_samples=10000): + """ + Generate a dummy dataset + + Args: + *shapes: list of shapes + num_samples: how many samples to use in this dataset + + Example:: + + from pl_bolts.datasets import DummyDataset + + # mnist dims + >>> ds = DummyDataset((1, 28, 28), (1,)) + >>> dl = DataLoader(ds, batch_size=7) + ... + >>> batch = next(iter(dl)) + >>> x, y = batch + >>> x.size() + torch.Size([7, 1, 28, 28]) + >>> y.size() + torch.Size([7, 1]) + """ + super().__init__() + self.shapes = shapes + self.num_samples = num_samples + + def __len__(self): + return self.num_samples + + def __getitem__(self, idx): + samples = [] + for shape in self.shapes: + sample = torch.rand(*shape) + samples.append(sample) + + return samples + + +class DummyDetectionDataset(Dataset): + def __init__( + self, img_shape=(3, 256, 256), num_boxes=1, num_classes=2, num_samples=10000 + ): + """ + Generate a dummy dataset for detection + + Args: + *shapes: list of shapes + num_samples: how many samples to use in this dataset + + Example:: + + from pl_bolts.datasets import DummyDetectionDataset + + >>> ds = DummyDetectionDataset() + >>> dl = DataLoader(ds, batch_size=7) + """ + super().__init__() + self.img_shape = img_shape + self.num_samples = num_samples + self.num_boxes = num_boxes + self.num_classes = num_classes + + def __len__(self): + return self.num_samples + + def _random_bbox(self): + c, h, w = self.img_shape + xs = torch.randint(w, (2,)) + ys = torch.randint(h, (2,)) + return [min(xs), min(ys), max(xs), max(ys)] + + def __getitem__(self, idx): + img = torch.rand(self.img_shape) + boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) + labels = torch.randint(self.num_classes, (self.num_boxes,)) + return img, {"boxes": boxes, "labels": labels} + + +class RandomDictDataset(Dataset): + def __init__(self, size, num_samples): + """ + Generate a dummy dataset with a dict structure + + Args: + size: tuple + num_samples: number of samples + + Example:: + + from pl_bolts.datasets import RandomDictDataset + + >>> ds = RandomDictDataset() + >>> dl = DataLoader(ds, batch_size=7) + """ + self.len = num_samples + self.data = torch.randn(num_samples, size) + + def __getitem__(self, index): + a = self.data[index] + b = a + 2 + return {'a': a, 'b': b} + + def __len__(self): + return self.len + + +class RandomDictStringDataset(Dataset): + def __init__(self, size, num_samples): + """ + Generate a dummy dataset with strings + + Args: + size: tuple + num_samples: number of samples + + Example:: + + from pl_bolts.datasets import RandomDictStringDataset + + >>> ds = RandomDictStringDataset() + >>> dl = DataLoader(ds, batch_size=7) + """ + self.len = num_samples + self.data = torch.randn(num_samples, size) + + def __getitem__(self, index): + return {"id": str(index), "x": self.data[index]} + + def __len__(self): + return self.len + + +class RandomDataset(Dataset): + def __init__(self, size, num_samples): + """ + Generate a dummy dataset + + Args: + size: tuple + num_samples: number of samples + + Example:: + + from pl_bolts.datasets import RandomDataset + + >>> ds = RandomDataset() + >>> dl = DataLoader(ds, batch_size=7) + """ + self.len = num_samples + self.data = torch.randn(num_samples, size) + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return self.len diff --git a/tests/datasets/__init__.py b/tests/datasets/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py new file mode 100644 index 0000000000..c7adda3cda --- /dev/null +++ b/tests/datasets/test_datasets.py @@ -0,0 +1,34 @@ +from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset +from torch.utils.data import DataLoader + + +def test_dummy_ds(tmpdir): + ds = DummyDataset((1, 2), num_samples=100) + dl = DataLoader(ds) + + for b in dl: + pass + + +def test_rand_ds(tmpdir): + ds = RandomDataset(32, num_samples=100) + dl = DataLoader(ds) + + for b in dl: + pass + + +def test_rand_dict_ds(tmpdir): + ds = RandomDictDataset(32, num_samples=100) + dl = DataLoader(ds) + + for b in dl: + pass + + +def test_rand_str_dict_ds(tmpdir): + ds = RandomDictStringDataset(32, num_samples=100) + dl = DataLoader(ds) + + for b in dl: + pass From 6c6a6411fff509e4011ba450a640a62f0f2d23a2 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Mon, 12 Oct 2020 07:26:48 -0400 Subject: [PATCH 13/32] Update __init__.py --- pl_bolts/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pl_bolts/__init__.py b/pl_bolts/__init__.py index 2945c2ba3c..25f14672bd 100644 --- a/pl_bolts/__init__.py +++ b/pl_bolts/__init__.py @@ -2,7 +2,7 @@ import os -__version__ = '0.2.4' +__version__ = '0.2.5' __author__ = 'PyTorchLightning et al.' __author_email__ = 'name@pytorchlightning.ai' __license__ = 'Apache-2.0' From 4376a458ebd6b8616cd056508d496cb1cf514c7b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Oct 2020 12:45:06 +0200 Subject: [PATCH 14/32] add symlink to req. (#271) --- requirements.txt | 1 + 1 file changed, 1 insertion(+) mode change 100644 => 120000 requirements.txt diff --git a/requirements.txt b/requirements.txt deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/requirements.txt b/requirements.txt new file mode 120000 index 0000000000..2d8105f82c --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +requirements/base.txt \ No newline at end of file From 84dd111e9ff419cfcf6e7523b910e84c205fd5b7 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Tue, 13 Oct 2020 21:20:04 +0900 Subject: [PATCH 15/32] Fix typo (#263) --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5e905ce783..28f398f7cc 100644 --- a/README.md +++ b/README.md @@ -57,7 +57,7 @@ Install bleeding-edge (no guarantees) pip install git+https://github.com/PytorchLightning/pytorch-lightning-bolts.git@master --upgrade ``` -In case you wan to have full experience you can install all optional packages at once +In case you want to have full experience you can install all optional packages at once ```bash pip install pytorch-lightning-bolts["extra"] ``` From 4f8ede65ca78ca98f82351218a3d376a5febc890 Mon Sep 17 00:00:00 2001 From: JackLangerman Date: Tue, 13 Oct 2020 08:21:03 -0400 Subject: [PATCH 16/32] Update README.md (#260) From ecb852df28d6d47cc3140c858c02f253ee496abc Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Oct 2020 15:09:41 +0200 Subject: [PATCH 17/32] move requirements & bug fixing (#272) * move requirements * fix path * fix dummy * manifest * imports --- .github/workflows/ci_test-base.yml | 4 +- .github/workflows/ci_test-full.yml | 4 +- .github/workflows/code-format.yml | 4 +- .github/workflows/docs-check.yml | 6 +- MANIFEST.in | 1 + docs/source/conf.py | 2 +- pl_bolts/datasets/__init__.py | 13 +-- pl_bolts/datasets/dummy_dataset.py | 151 ++++++++++++++--------------- requirements.txt | 3 +- requirements/base.txt | 2 - requirements/devel.txt | 2 +- setup.py | 10 +- tests/models/test_detection.py | 2 +- 13 files changed, 102 insertions(+), 102 deletions(-) mode change 120000 => 100644 requirements.txt delete mode 100644 requirements/base.txt diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index 5103a13731..a6d808a816 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -41,14 +41,14 @@ jobs: uses: actions/cache@v2 with: path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements/base.txt') }} + key: ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip-${{ hashFiles('requirements.txt') }} restore-keys: | ${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}-pip- - name: Install dependencies run: | python -m pip install --upgrade --user pip - pip install --requirement ./requirements/base.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade + pip install --requirement ./requirements.txt --quiet --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --upgrade pip install --requirement ./requirements/test.txt --quiet --upgrade-strategy only-if-needed # pip install tox coverage python --version diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index f4b8f72d88..ba7a661a69 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -45,7 +45,7 @@ jobs: - name: Set min. dependencies if: matrix.requires == 'minimal' run: | - python -c "fpath = 'requirements/base.txt' ; req = open(fpath).read().replace('>=', '==') ; open(fpath, 'w').write(req)" + python -c "fpath = 'requirements.txt' ; req = open(fpath).read().replace('>=', '==') ; open(fpath, 'w').write(req)" python -c "fpath = 'requirements/models.txt' ; req = open(fpath).read().replace('>=', '==') ; open(fpath, 'w').write(req)" python -c "fpath = 'requirements/loggers.txt' ; req = open(fpath).read().replace('>=', '==') ; open(fpath, 'w').write(req)" python -c "fpath = 'requirements/test.txt' ; req = open(fpath).read().replace('>=', '==') ; open(fpath, 'w').write(req)" @@ -61,7 +61,7 @@ jobs: uses: actions/cache@v2 with: path: ${{ steps.pip-cache.outputs.dir }} - key: ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ hashFiles('requirements/base.txt') }}-${{ hashFiles('requirements/modules.txt') }} + key: ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('requirements/modules.txt') }} restore-keys: | ${{ runner.os }}-pip-py${{ matrix.python-version }}-${{ matrix.requires }}- diff --git a/.github/workflows/code-format.yml b/.github/workflows/code-format.yml index 62f19d3001..813ee1b862 100644 --- a/.github/workflows/code-format.yml +++ b/.github/workflows/code-format.yml @@ -23,14 +23,14 @@ jobs: uses: actions/cache@v2 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('requirements/base.txt') }} + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} restore-keys: | ${{ runner.os }}-pip- - name: Install dependencies run: | # python -m pip install --upgrade --user pip - pip install -r requirements/base.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q + pip install -r requirements.txt -U -f https://download.pytorch.org/whl/torch_stable.html -q pip install flake8 python --version pip --version diff --git a/.github/workflows/docs-check.yml b/.github/workflows/docs-check.yml index 4de81bdc27..4aaac0d41e 100644 --- a/.github/workflows/docs-check.yml +++ b/.github/workflows/docs-check.yml @@ -36,7 +36,7 @@ jobs: # uses: actions/cache@v2 # with: # path: ~/.cache/pip -# key: ${{ runner.os }}-pip-${{ hashFiles('requirements/base.txt') }} +# key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} # restore-keys: | # ${{ runner.os }}-pip- # @@ -75,13 +75,13 @@ jobs: uses: actions/cache@v2 with: path: ~/.cache/pip - key: ${{ runner.os }}-pip-${{ hashFiles('requirements/base.txt') }} + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} restore-keys: | ${{ runner.os }}-pip- - name: Install dependencies run: | - pip install --requirement requirements/base.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet + pip install --requirement requirements.txt --upgrade-strategy only-if-needed --find-links https://download.pytorch.org/whl/cpu/torch_stable.html --quiet pip install --requirement docs/requirements.txt # install Texlive, see https://linuxconfig.org/how-to-install-latex-on-ubuntu-20-04-focal-fossa-linux sudo apt-get update && sudo apt-get install -y texlive-latex-extra dvipng texlive-pictures diff --git a/MANIFEST.in b/MANIFEST.in index d3f4c4f33d..e306b2618d 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -25,6 +25,7 @@ recursive-exclude docs * exclude docs # Include the Requirements +include requirements.txt recursive-include requirements *.txt # Exclude build configs diff --git a/docs/source/conf.py b/docs/source/conf.py index f2dc1442d9..7114016ba1 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -328,7 +328,7 @@ def package_list_from_file(file): MOCK_PACKAGES = [] if SPHINX_MOCK_REQUIREMENTS: # mock also base packages when we are on RTD since we don't install them there - MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'base.txt')) + MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements.txt')) MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'models.txt')) MOCK_PACKAGES += package_list_from_file(os.path.join(PATH_ROOT, 'requirements', 'loggers.txt')) diff --git a/pl_bolts/datasets/__init__.py b/pl_bolts/datasets/__init__.py index 9962a8f717..e2d319ce2f 100644 --- a/pl_bolts/datasets/__init__.py +++ b/pl_bolts/datasets/__init__.py @@ -1,6 +1,7 @@ -from pl_bolts.datasets.dummy_dataset import \ - (RandomDictStringDataset, - RandomDictDataset, - RandomDataset, - DummyDataset, - DummyDetectionDataset) \ No newline at end of file +from pl_bolts.datasets.dummy_dataset import ( + RandomDictStringDataset, + RandomDictDataset, + RandomDataset, + DummyDataset, + DummyDetectionDataset +) diff --git a/pl_bolts/datasets/dummy_dataset.py b/pl_bolts/datasets/dummy_dataset.py index 3212b37795..44b728422e 100644 --- a/pl_bolts/datasets/dummy_dataset.py +++ b/pl_bolts/datasets/dummy_dataset.py @@ -3,29 +3,29 @@ class DummyDataset(Dataset): - def __init__(self, *shapes, num_samples=10000): - """ - Generate a dummy dataset - - Args: - *shapes: list of shapes - num_samples: how many samples to use in this dataset - - Example:: - - from pl_bolts.datasets import DummyDataset - - # mnist dims - >>> ds = DummyDataset((1, 28, 28), (1,)) - >>> dl = DataLoader(ds, batch_size=7) - ... - >>> batch = next(iter(dl)) - >>> x, y = batch - >>> x.size() - torch.Size([7, 1, 28, 28]) - >>> y.size() - torch.Size([7, 1]) - """ + """ + Generate a dummy dataset + + Args: + *shapes: list of shapes + num_samples: how many samples to use in this dataset + + Example:: + + from pl_bolts.datasets import DummyDataset + + >>> # mnist dims + >>> ds = DummyDataset((1, 28, 28), (1, )) + >>> dl = DataLoader(ds, batch_size=7) + >>> # get first batch + >>> batch = next(iter(dl)) + >>> x, y = batch + >>> x.size() + torch.Size([7, 1, 28, 28]) + >>> y.size() + torch.Size([7, 1]) + """ + def __init__(self, *shapes, num_samples: int = 10000): super().__init__() self.shapes = shapes self.num_samples = num_samples @@ -33,33 +33,32 @@ def __init__(self, *shapes, num_samples=10000): def __len__(self): return self.num_samples - def __getitem__(self, idx): - samples = [] + def __getitem__(self, idx: int): + sample = [] for shape in self.shapes: - sample = torch.rand(*shape) - samples.append(sample) - - return samples + spl = torch.rand(*shape) + sample.append(spl) + return sample class DummyDetectionDataset(Dataset): - def __init__( - self, img_shape=(3, 256, 256), num_boxes=1, num_classes=2, num_samples=10000 - ): - """ - Generate a dummy dataset for detection + """ + Generate a dummy dataset for detection - Args: - *shapes: list of shapes - num_samples: how many samples to use in this dataset + Args: + *shapes: list of shapes + num_samples: how many samples to use in this dataset - Example:: + Example:: - from pl_bolts.datasets import DummyDetectionDataset + from pl_bolts.datasets import DummyDetectionDataset - >>> ds = DummyDetectionDataset() - >>> dl = DataLoader(ds, batch_size=7) - """ + >>> ds = DummyDetectionDataset() + >>> dl = DataLoader(ds, batch_size=7) + """ + def __init__( + self, img_shape: tuple = (3, 256, 256), num_boxes: int = 1, num_classes: int = 2, num_samples: int = 10000 + ): super().__init__() self.img_shape = img_shape self.num_samples = num_samples @@ -75,7 +74,7 @@ def _random_bbox(self): ys = torch.randint(h, (2,)) return [min(xs), min(ys), max(xs), max(ys)] - def __getitem__(self, idx): + def __getitem__(self, idx: int): img = torch.rand(self.img_shape) boxes = torch.tensor([self._random_bbox() for _ in range(self.num_boxes)]) labels = torch.randint(self.num_classes, (self.num_boxes,)) @@ -83,21 +82,21 @@ def __getitem__(self, idx): class RandomDictDataset(Dataset): - def __init__(self, size, num_samples): - """ - Generate a dummy dataset with a dict structure + """ + Generate a dummy dataset with a dict structure - Args: - size: tuple - num_samples: number of samples + Args: + size: tuple + num_samples: number of samples - Example:: + Example:: - from pl_bolts.datasets import RandomDictDataset + from pl_bolts.datasets import RandomDictDataset - >>> ds = RandomDictDataset() - >>> dl = DataLoader(ds, batch_size=7) - """ + >>> ds = RandomDictDataset(10) + >>> dl = DataLoader(ds, batch_size=7) + """ + def __init__(self, size: int, num_samples: int = 250): self.len = num_samples self.data = torch.randn(num_samples, size) @@ -111,21 +110,21 @@ def __len__(self): class RandomDictStringDataset(Dataset): - def __init__(self, size, num_samples): - """ - Generate a dummy dataset with strings + """ + Generate a dummy dataset with strings - Args: - size: tuple - num_samples: number of samples + Args: + size: tuple + num_samples: number of samples - Example:: + Example:: - from pl_bolts.datasets import RandomDictStringDataset + from pl_bolts.datasets import RandomDictStringDataset - >>> ds = RandomDictStringDataset() - >>> dl = DataLoader(ds, batch_size=7) - """ + >>> ds = RandomDictStringDataset(10) + >>> dl = DataLoader(ds, batch_size=7) + """ + def __init__(self, size: int, num_samples: int = 250): self.len = num_samples self.data = torch.randn(num_samples, size) @@ -137,21 +136,21 @@ def __len__(self): class RandomDataset(Dataset): - def __init__(self, size, num_samples): - """ - Generate a dummy dataset + """ + Generate a dummy dataset - Args: - size: tuple - num_samples: number of samples + Args: + size: tuple + num_samples: number of samples - Example:: + Example:: - from pl_bolts.datasets import RandomDataset + from pl_bolts.datasets import RandomDataset - >>> ds = RandomDataset() - >>> dl = DataLoader(ds, batch_size=7) - """ + >>> ds = RandomDataset(10) + >>> dl = DataLoader(ds, batch_size=7) + """ + def __init__(self, size: int, num_samples: int = 250): self.len = num_samples self.data = torch.randn(num_samples, size) diff --git a/requirements.txt b/requirements.txt deleted file mode 120000 index 2d8105f82c..0000000000 --- a/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -requirements/base.txt \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..c434a7c377 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,2 @@ +pytorch-lightning>=0.10.0 +torch>=1.6 \ No newline at end of file diff --git a/requirements/base.txt b/requirements/base.txt deleted file mode 100644 index c434a7c377..0000000000 --- a/requirements/base.txt +++ /dev/null @@ -1,2 +0,0 @@ -pytorch-lightning>=0.10.0 -torch>=1.6 \ No newline at end of file diff --git a/requirements/devel.txt b/requirements/devel.txt index 53b6b26d05..3574b167e4 100644 --- a/requirements/devel.txt +++ b/requirements/devel.txt @@ -1,5 +1,5 @@ # install all mandatory dependencies --r ./base.txt +-r ../requirements.txt # install all extra dependencies for full package experience -r ./models.txt diff --git a/setup.py b/setup.py index 456c6b9153..29931f94ff 100755 --- a/setup.py +++ b/setup.py @@ -19,8 +19,8 @@ import pl_bolts # noqa: E402 -def load_requirements(path_dir=PATH_ROOT, file_name='base.txt', comment_char='#'): - with open(os.path.join(path_dir, 'requirements', file_name), 'r') as file: +def load_requirements(path_dir=PATH_ROOT, file_name='requirements.txt', comment_char='#'): + with open(os.path.join(path_dir, file_name), 'r') as file: lines = [ln.strip() for ln in file.readlines()] reqs = [] for ln in lines: @@ -45,9 +45,9 @@ def load_long_describtion(): extras = { - 'loggers': load_requirements(file_name='loggers.txt'), - 'models': load_requirements(file_name='models.txt'), - 'test': load_requirements(file_name='test.txt'), + 'loggers': load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='loggers.txt'), + 'models': load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='models.txt'), + 'test': load_requirements(path_dir=os.path.join(PATH_ROOT, 'requirements'), file_name='test.txt'), } extras['extra'] = extras['models'] + extras['loggers'] extras['dev'] = extras['extra'] + extras['test'] diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index 61edf2e875..a312fbc9d7 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -3,7 +3,7 @@ import torch from torch.utils.data import DataLoader -from pl_bolts.datamodules import DummyDetectionDataset +from pl_bolts.datasets import DummyDetectionDataset from pl_bolts.models.detection import FasterRCNN From 07e33766b6fd9a25a68cd02c9c940725389a7a22 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Oct 2020 16:30:03 +0200 Subject: [PATCH 18/32] skip hanging tests & update imports (#258) * fix imports * timeout * flake8 * skip * skip * skip * skip * Apply suggestions from code review --- .github/workflows/ci_test-full.yml | 2 +- pl_bolts/callbacks/variational.py | 2 +- pl_bolts/callbacks/vision/image_generation.py | 2 +- pl_bolts/datamodules/__init__.py | 2 +- .../datamodules/binary_mnist_datamodule.py | 6 ++-- pl_bolts/datamodules/cifar10_datamodule.py | 6 ++-- pl_bolts/datamodules/cifar10_dataset.py | 2 +- pl_bolts/datamodules/cityscapes_datamodule.py | 4 +-- .../datamodules/fashion_mnist_datamodule.py | 4 +-- pl_bolts/datamodules/imagenet_datamodule.py | 6 ++-- pl_bolts/datamodules/imagenet_dataset.py | 9 ++--- pl_bolts/datamodules/mnist_datamodule.py | 6 ++-- pl_bolts/datamodules/mnist_dataset.py | 9 ++--- pl_bolts/datamodules/sklearn_datamodule.py | 6 ++-- pl_bolts/datamodules/ssl_amdim_datasets.py | 4 +-- .../datamodules/ssl_imagenet_datamodule.py | 6 ++-- pl_bolts/datamodules/stl10_datamodule.py | 6 ++-- .../datamodules/vocdetection_datamodule.py | 6 ++-- pl_bolts/models/detection/__init__.py | 2 +- pl_bolts/models/detection/faster_rcnn.py | 2 +- pl_bolts/models/mnist_module.py | 2 +- .../models/regression/linear_regression.py | 7 ++-- .../models/regression/logistic_regression.py | 7 ++-- .../models/self_supervised/amdim/__init__.py | 2 +- .../self_supervised/amdim/amdim_module.py | 2 +- .../models/self_supervised/amdim/datasets.py | 2 +- .../self_supervised/amdim/transforms.py | 26 +++++++++++---- .../models/self_supervised/cpc/cpc_module.py | 7 ++-- .../models/self_supervised/cpc/transforms.py | 26 +++++++++++---- .../self_supervised/moco/moco2_module.py | 2 +- .../models/self_supervised/moco/transforms.py | 33 +++++++++++++------ pl_bolts/models/self_supervised/resnets.py | 2 +- .../self_supervised/simclr/simclr_module.py | 2 +- .../self_supervised/simclr/transforms.py | 16 ++++++--- pl_bolts/transforms/dataset_normalizations.py | 2 +- .../self_supervised/ssl_transforms.py | 7 ++-- pl_bolts/utils/semi_supervised.py | 2 +- tests/models/self_supervised/test_models.py | 6 ++++ tests/models/self_supervised/test_scripts.py | 3 ++ tests/models/test_executable_scripts.py | 3 ++ .../{test_vision_models.py => test_vision.py} | 0 41 files changed, 164 insertions(+), 87 deletions(-) rename tests/models/{test_vision_models.py => test_vision.py} (100%) diff --git a/.github/workflows/ci_test-full.yml b/.github/workflows/ci_test-full.yml index ba7a661a69..19c085ddcc 100644 --- a/.github/workflows/ci_test-full.yml +++ b/.github/workflows/ci_test-full.yml @@ -27,7 +27,7 @@ jobs: os: windows-2019 # Timeout: https://stackoverflow.com/a/59076067/4521646 - timeout-minutes: 35 + timeout-minutes: 45 steps: - uses: actions/checkout@v2 diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index 49d35f49fc..eb4469b0ed 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -5,7 +5,7 @@ try: import torchvision -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/callbacks/vision/image_generation.py b/pl_bolts/callbacks/vision/image_generation.py index 947497089a..60f5c2e172 100644 --- a/pl_bolts/callbacks/vision/image_generation.py +++ b/pl_bolts/callbacks/vision/image_generation.py @@ -5,7 +5,7 @@ try: import torchvision -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index 2e3447d2ac..cb810e1b6e 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -25,5 +25,5 @@ from pl_bolts.datamodules.kitti_dataset import KittiDataset from pl_bolts.datamodules.kitti_datamodule import KittiDataModule -except ImportError: +except ModuleNotFoundError: pass diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 4913874b39..2cf73777f3 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -8,7 +8,7 @@ from torchvision import transforms as transform_lib from torchvision.datasets import MNIST from pl_bolts.datamodules.mnist_dataset import BinaryMNIST -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -65,7 +65,9 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use MNIST dataset loaded from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.' + ) self.dims = (1, 28, 28) self.data_dir = data_dir diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index cf7e67377f..53289ca639 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -13,7 +13,7 @@ from torchvision import transforms as transform_lib from torchvision.datasets import CIFAR10 -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -84,7 +84,9 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use CIFAR10 dataset loaded from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use CIFAR10 dataset loaded from `torchvision` which is not installed yet.' + ) self.dims = (3, 32, 32) self.DATASET = CIFAR10 diff --git a/pl_bolts/datamodules/cifar10_dataset.py b/pl_bolts/datamodules/cifar10_dataset.py index 5ddb44ab36..5242e50b70 100644 --- a/pl_bolts/datamodules/cifar10_dataset.py +++ b/pl_bolts/datamodules/cifar10_dataset.py @@ -9,7 +9,7 @@ try: from PIL import Image -except ImportError: +except ModuleNotFoundError: warn('You want to use `Pillow` which is not installed yet,' # pragma: no-cover ' install it with `pip install Pillow`.') _PIL_AVAILABLE = False diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 19cf7d7db9..6983e932ff 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -7,7 +7,7 @@ try: from torchvision import transforms as transform_lib from torchvision.datasets import Cityscapes -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -77,7 +77,7 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError( + raise ModuleNotFoundError( # pragma: no-cover 'You want to use CityScapes dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index c5544cc109..467bc36c4c 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -7,7 +7,7 @@ try: from torchvision import transforms as transform_lib from torchvision.datasets import FashionMNIST -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -63,7 +63,7 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError( + raise ModuleNotFoundError( # pragma: no-cover 'You want to use fashion MNIST dataset loaded from `torchvision` which is not installed yet.' ) diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 9226e88171..3d458b6160 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -9,7 +9,7 @@ try: from torchvision import transforms as transform_lib from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -72,7 +72,9 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.' + ) self.image_size = image_size self.dims = (3, self.image_size, self.image_size) diff --git a/pl_bolts/datamodules/imagenet_dataset.py b/pl_bolts/datamodules/imagenet_dataset.py index eca8513d6a..6238e41a7c 100644 --- a/pl_bolts/datamodules/imagenet_dataset.py +++ b/pl_bolts/datamodules/imagenet_dataset.py @@ -13,16 +13,17 @@ try: from sklearn.utils import shuffle -except ImportError: +except ModuleNotFoundError: warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover ' install it with `pip install sklearn`.') try: from torchvision.datasets import ImageNet from torchvision.datasets.imagenet import load_meta_file -except ImportError: - raise ImportError('You want to use `torchvision` which is not installed yet,' # pragma: no-cover - ' install it with `pip install torchvision`.') +except ModuleNotFoundError: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' + ) class UnlabeledImagenet(ImageNet): diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index e67e42ed88..d1964e8346 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -7,7 +7,7 @@ try: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -65,7 +65,9 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use MNIST dataset loaded from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use MNIST dataset loaded from `torchvision` which is not installed yet.' + ) self.dims = (1, 28, 28) self.data_dir = data_dir diff --git a/pl_bolts/datamodules/mnist_dataset.py b/pl_bolts/datamodules/mnist_dataset.py index ac6b2f4694..4ecd00c954 100644 --- a/pl_bolts/datamodules/mnist_dataset.py +++ b/pl_bolts/datamodules/mnist_dataset.py @@ -3,13 +3,14 @@ try: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -except ImportError: - raise ImportError('You want to use `torchvision` which is not installed yet,' # pragma: no-cover - ' install it with `pip install torchvision`.') +except ModuleNotFoundError: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' + ) try: from PIL import Image -except ImportError: +except ModuleNotFoundError: warn('You want to use `Pillow` which is not installed yet,' # pragma: no-cover ' install it with `pip install Pillow`.') _PIL_AVAILABLE = False diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 7bb57b93a5..704d629c1b 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -9,7 +9,7 @@ try: from sklearn.utils import shuffle as sk_shuffle -except ImportError: +except ModuleNotFoundError: warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover ' install it with `pip install sklearn`.') _SKLEARN_AVAILABLE = False @@ -162,7 +162,9 @@ def __init__( if shuffle and _SKLEARN_AVAILABLE: X, y = sk_shuffle(X, y, random_state=random_state) elif shuffle and not _SKLEARN_AVAILABLE: - raise ImportError('You want to use shuffle function from `scikit-learn` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use shuffle function from `scikit-learn` which is not installed yet.' + ) val_split = 0 if x_val is not None or y_val is not None else val_split test_split = 0 if x_test is not None or y_test is not None else test_split diff --git a/pl_bolts/datamodules/ssl_amdim_datasets.py b/pl_bolts/datamodules/ssl_amdim_datasets.py index bb660d5426..8fdcba8b62 100644 --- a/pl_bolts/datamodules/ssl_amdim_datasets.py +++ b/pl_bolts/datamodules/ssl_amdim_datasets.py @@ -6,13 +6,13 @@ try: from sklearn.utils import shuffle -except ImportError: +except ModuleNotFoundError: warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover ' install it with `pip install sklearn`.') try: from torchvision.datasets import CIFAR10 -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 64ef47aefb..56196dae6e 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -10,7 +10,7 @@ try: from torchvision import transforms as transform_lib -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -33,7 +33,9 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use ImageNet dataset loaded from `torchvision` which is not installed yet.' + ) self.data_dir = data_dir self.num_workers = num_workers diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index e8a9d9f87b..46885b25df 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -11,7 +11,7 @@ try: from torchvision import transforms as transform_lib from torchvision.datasets import STL10 -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -75,7 +75,9 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use STL10 dataset loaded from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use STL10 dataset loaded from `torchvision` which is not installed yet.' + ) self.dims = (3, 96, 96) self.data_dir = data_dir if data_dir is not None else os.getcwd() diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index 6ed5687dc7..c17d5fedfe 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -8,7 +8,7 @@ try: from torchvision.datasets import VOCDetection -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -120,7 +120,9 @@ def __init__( super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use VOC dataset loaded from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use VOC dataset loaded from `torchvision` which is not installed yet.' + ) self.year = year self.data_dir = data_dir diff --git a/pl_bolts/models/detection/__init__.py b/pl_bolts/models/detection/__init__.py index b739fe3c3e..205420bcf1 100644 --- a/pl_bolts/models/detection/__init__.py +++ b/pl_bolts/models/detection/__init__.py @@ -2,7 +2,7 @@ try: from pl_bolts.models.detection.faster_rcnn import FasterRCNN -except ImportError: # pragma: no-cover +except ModuleNotFoundError: # pragma: no-cover pass # pragma: no-cover else: __all__.append('FasterRCNN') diff --git a/pl_bolts/models/detection/faster_rcnn.py b/pl_bolts/models/detection/faster_rcnn.py index 910dad1a53..4a6173f982 100644 --- a/pl_bolts/models/detection/faster_rcnn.py +++ b/pl_bolts/models/detection/faster_rcnn.py @@ -7,7 +7,7 @@ try: from torchvision.models.detection import faster_rcnn, fasterrcnn_resnet50_fpn from torchvision.ops import box_iou -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/models/mnist_module.py b/pl_bolts/models/mnist_module.py index 3dc71b22b1..d9521e1db8 100644 --- a/pl_bolts/models/mnist_module.py +++ b/pl_bolts/models/mnist_module.py @@ -10,7 +10,7 @@ try: from torchvision import transforms from torchvision.datasets import MNIST -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index 202ade6f28..f69fb7020e 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -126,9 +126,10 @@ def cli_main(): # create dataset try: from sklearn.datasets import load_boston - except ImportError: - raise ImportError('You want to use `sklearn` which is not installed yet,' # pragma: no-cover - ' install it with `pip install sklearn`.') + except ModuleNotFoundError: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' + ) X, y = load_boston(return_X_y=True) # these are numpy arrays loaders = SklearnDataModule(X, y) diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index 79626cd143..047c05747d 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -132,9 +132,10 @@ def cli_main(): # Example: Iris dataset in Sklearn (4 features, 3 class labels) try: from sklearn.datasets import load_iris - except ImportError: - raise ImportError('You want to use `sklearn` which is not installed yet,' # pragma: no-cover - ' install it with `pip install sklearn`.') + except ModuleNotFoundError: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' + ) X, y = load_iris(return_X_y=True) loaders = SklearnDataModule(X, y) diff --git a/pl_bolts/models/self_supervised/amdim/__init__.py b/pl_bolts/models/self_supervised/amdim/__init__.py index c3ccf8b3f2..2e30ad15d8 100644 --- a/pl_bolts/models/self_supervised/amdim/__init__.py +++ b/pl_bolts/models/self_supervised/amdim/__init__.py @@ -9,5 +9,5 @@ AMDIMTrainTransformsImageNet128, AMDIMEvalTransformsImageNet128, ) -except ImportError: +except ModuleNotFoundError: pass diff --git a/pl_bolts/models/self_supervised/amdim/amdim_module.py b/pl_bolts/models/self_supervised/amdim/amdim_module.py index 0190b2c1bb..c098801907 100644 --- a/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -9,7 +9,7 @@ try: from pl_bolts.models.self_supervised.amdim.datasets import AMDIMPretraining -except ImportError: +except ModuleNotFoundError: pass from pl_bolts.losses.self_supervised_learning import FeatureMapContrastiveTask from pl_bolts.models.self_supervised.amdim.networks import AMDIMEncoder diff --git a/pl_bolts/models/self_supervised/amdim/datasets.py b/pl_bolts/models/self_supervised/amdim/datasets.py index 993a47c56c..037c8bf4be 100644 --- a/pl_bolts/models/self_supervised/amdim/datasets.py +++ b/pl_bolts/models/self_supervised/amdim/datasets.py @@ -7,7 +7,7 @@ from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet from pl_bolts.datamodules.ssl_amdim_datasets import CIFAR10Mixed from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/models/self_supervised/amdim/transforms.py b/pl_bolts/models/self_supervised/amdim/transforms.py index 0ff5f232c1..bb146c338d 100644 --- a/pl_bolts/models/self_supervised/amdim/transforms.py +++ b/pl_bolts/models/self_supervised/amdim/transforms.py @@ -4,7 +4,7 @@ try: from torchvision import transforms -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -34,7 +34,9 @@ def __init__(self): """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # flipping image along vertical axis self.flip_lr = transforms.RandomHorizontalFlip(p=0.5) @@ -79,7 +81,9 @@ def __init__(self): (view1, view2) = transform(x) """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # flipping image along vertical axis self.flip_lr = transforms.RandomHorizontalFlip(p=0.5) @@ -119,7 +123,9 @@ def __init__(self, height=64): (view1, view2) = transform(x) """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # flipping image along vertical axis self.flip_lr = transforms.RandomHorizontalFlip(p=0.5) @@ -165,7 +171,9 @@ def __init__(self, height=64): view1 = transform(x) """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # flipping image along vertical axis self.flip_lr = transforms.RandomHorizontalFlip(p=0.5) @@ -207,7 +215,9 @@ def __init__(self, height=128): (view1, view2) = transform(x) """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # image augmentation functions self.flip_lr = transforms.RandomHorizontalFlip(p=0.5) @@ -253,7 +263,9 @@ def __init__(self, height=128): view1 = transform(x) """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # image augmentation functions self.flip_lr = transforms.RandomHorizontalFlip(p=0.5) diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index e6aa3b5b0c..84fe1be0e4 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -168,9 +168,10 @@ def validation_step(self, batch, batch_nb): def shared_step(self, batch): try: from pl_bolts.datamodules.stl10_datamodule import STL10DataModule - except ImportError: - raise ImportError('You want to use `torchvision` which is not installed yet,' # pragma: no-cover - ' install it with `pip install torchvision`.') + except ModuleNotFoundError: + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' + ) if isinstance(self.datamodule, STL10DataModule): unlabeled_batch = batch[0] diff --git a/pl_bolts/models/self_supervised/cpc/transforms.py b/pl_bolts/models/self_supervised/cpc/transforms.py index d08525f856..fc870a0251 100644 --- a/pl_bolts/models/self_supervised/cpc/transforms.py +++ b/pl_bolts/models/self_supervised/cpc/transforms.py @@ -4,7 +4,7 @@ try: from torchvision import transforms -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -44,7 +44,9 @@ def __init__(self, patch_size=8, overlap=4): """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) self.patch_size = patch_size self.overlap = overlap @@ -100,7 +102,9 @@ def __init__(self, patch_size=8, overlap=4): """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # flipping image along vertical axis self.patch_size = patch_size @@ -154,7 +158,9 @@ def __init__(self, patch_size=16, overlap=8): """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # flipping image along vertical axis self.patch_size = patch_size @@ -211,7 +217,9 @@ def __init__(self, patch_size=16, overlap=8): """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # flipping image along vertical axis self.patch_size = patch_size @@ -259,7 +267,9 @@ def __init__(self, patch_size=32, overlap=16): train_loader = module.train_dataloader(batch_size=32, transforms=CPCTrainTransformsImageNet128()) """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # image augmentation functions self.patch_size = patch_size @@ -316,7 +326,9 @@ def __init__(self, patch_size=32, overlap=16): train_loader = module.train_dataloader(batch_size=32, transforms=CPCEvalTransformsImageNet128()) """ if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # image augmentation functions self.patch_size = patch_size diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index d7df748e83..8eeb72c5cf 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -16,7 +16,7 @@ try: import torchvision -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/models/self_supervised/moco/transforms.py b/pl_bolts/models/self_supervised/moco/transforms.py index 5f2b313c8d..39acda8ef8 100644 --- a/pl_bolts/models/self_supervised/moco/transforms.py +++ b/pl_bolts/models/self_supervised/moco/transforms.py @@ -6,7 +6,7 @@ try: from torchvision import transforms -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -15,7 +15,7 @@ try: from PIL import ImageFilter -except ImportError: +except ModuleNotFoundError: warn('You want to use `Pillow` which is not installed yet,' # pragma: no-cover ' install it with `pip install Pillow`.') _PIL_AVAILABLE = False @@ -31,7 +31,9 @@ class Moco2TrainCIFAR10Transforms: """ def __init__(self, height=32): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # image augmentation functions self.train_transform = transforms.Compose([ @@ -60,7 +62,9 @@ class Moco2EvalCIFAR10Transforms: """ def __init__(self, height=32): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) self.test_transform = transforms.Compose([ transforms.Resize(height + 12), @@ -82,7 +86,9 @@ class Moco2TrainSTL10Transforms: """ def __init__(self, height=64): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # image augmentation functions self.train_transform = transforms.Compose([ @@ -110,7 +116,9 @@ class Moco2EvalSTL10Transforms: """ def __init__(self, height=64): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) self.test_augmentation = transforms.Compose([ transforms.Resize(height + 11), @@ -134,7 +142,9 @@ class Moco2TrainImagenetTransforms: def __init__(self, height=128): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) # image augmentation functions self.train_transform = transforms.Compose([ @@ -163,7 +173,9 @@ class Moco2EvalImagenetTransforms: """ def __init__(self, height=128): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) self.test_transform = transforms.Compose([ transforms.Resize(height + 32), @@ -183,8 +195,9 @@ class GaussianBlur(object): def __init__(self, sigma=(0.1, 2.0)): if not _PIL_AVAILABLE: - raise ImportError('You want to use `Pillow` which is not installed yet,' - ' install it with `pip install Pillow`.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `Pillow` which is not installed yet, install it with `pip install Pillow`.' + ) self.sigma = sigma def __call__(self, x): diff --git a/pl_bolts/models/self_supervised/resnets.py b/pl_bolts/models/self_supervised/resnets.py index a82f161c9b..6cd3e5f683 100644 --- a/pl_bolts/models/self_supervised/resnets.py +++ b/pl_bolts/models/self_supervised/resnets.py @@ -4,7 +4,7 @@ try: from torchvision.models.utils import load_state_dict_from_url -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/models/self_supervised/simclr/simclr_module.py b/pl_bolts/models/self_supervised/simclr/simclr_module.py index 582883991a..11e410b15b 100644 --- a/pl_bolts/models/self_supervised/simclr/simclr_module.py +++ b/pl_bolts/models/self_supervised/simclr/simclr_module.py @@ -8,7 +8,7 @@ try: from torchvision.models import densenet -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/models/self_supervised/simclr/transforms.py b/pl_bolts/models/self_supervised/simclr/transforms.py index cc419f3a44..55b975c66f 100644 --- a/pl_bolts/models/self_supervised/simclr/transforms.py +++ b/pl_bolts/models/self_supervised/simclr/transforms.py @@ -4,7 +4,7 @@ try: import torchvision.transforms as transforms -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') _TORCHVISION_AVAILABLE = False @@ -13,7 +13,7 @@ try: import cv2 -except ImportError: +except ModuleNotFoundError: warn('You want to use `opencv-python` which is not installed yet,' # pragma: no-cover ' install it with `pip install opencv-python`.') @@ -41,7 +41,9 @@ class SimCLRTrainDataTransform(object): """ def __init__(self, input_height, s=1): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) self.s = s self.input_height = input_height @@ -81,7 +83,9 @@ class SimCLREvalDataTransform(object): """ def __init__(self, input_height, s=1): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) self.s = s self.input_height = input_height @@ -102,7 +106,9 @@ class GaussianBlur(object): # Implements Gaussian blur as described in the SimCLR paper def __init__(self, kernel_size, min=0.1, max=2.0): if not _TORCHVISION_AVAILABLE: - raise ImportError('You want to use `transforms` from `torchvision` which is not installed yet.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `transforms` from `torchvision` which is not installed yet.' + ) self.min = min self.max = max diff --git a/pl_bolts/transforms/dataset_normalizations.py b/pl_bolts/transforms/dataset_normalizations.py index 0777b5833d..12115c3b1f 100644 --- a/pl_bolts/transforms/dataset_normalizations.py +++ b/pl_bolts/transforms/dataset_normalizations.py @@ -1,7 +1,7 @@ from warnings import warn try: from torchvision import transforms -except ImportError: +except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/transforms/self_supervised/ssl_transforms.py b/pl_bolts/transforms/self_supervised/ssl_transforms.py index 131e1163e5..70e6e762ea 100644 --- a/pl_bolts/transforms/self_supervised/ssl_transforms.py +++ b/pl_bolts/transforms/self_supervised/ssl_transforms.py @@ -5,7 +5,7 @@ try: from PIL import Image -except ImportError: +except ModuleNotFoundError: warn('You want to use `Pillow` which is not installed yet,' # pragma: no-cover ' install it with `pip install Pillow`.') _PIL_AVAILABLE = False @@ -27,8 +27,9 @@ def __init__(self, max_translation): def __call__(self, old_image): if not _PIL_AVAILABLE: - raise ImportError('You want to use `Pillow` which is not installed yet,' - ' install it with `pip install Pillow`.') + raise ModuleNotFoundError( # pragma: no-cover + 'You want to use `Pillow` which is not installed yet, install it with `pip install Pillow`.' + ) xtranslation, ytranslation = np.random.randint(-self.max_translation, self.max_translation + 1, size=2) diff --git a/pl_bolts/utils/semi_supervised.py b/pl_bolts/utils/semi_supervised.py index 4192ba0877..b6eb64e0da 100644 --- a/pl_bolts/utils/semi_supervised.py +++ b/pl_bolts/utils/semi_supervised.py @@ -6,7 +6,7 @@ try: from sklearn.utils import shuffle as sk_shuffle -except ImportError: +except ModuleNotFoundError: warn('You want to use `sklearn` which is not installed yet,' # pragma: no-cover ' install it with `pip install sklearn`.') diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index c432d76744..b2df8af307 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -1,4 +1,6 @@ +import pytest import pytorch_lightning as pl +import torch from pytorch_lightning import seed_everything from pl_bolts.datamodules import CIFAR10DataModule @@ -9,6 +11,8 @@ from pl_bolts.models.self_supervised.simclr.transforms import SimCLREvalDataTransform, SimCLRTrainDataTransform +# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it... +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_cpcv2(tmpdir): seed_everything() @@ -24,6 +28,8 @@ def test_cpcv2(tmpdir): assert float(loss) > 0 +# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it... +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") def test_byol(tmpdir): seed_everything() diff --git a/tests/models/self_supervised/test_scripts.py b/tests/models/self_supervised/test_scripts.py index ce01743cdc..a9dbdc07e7 100644 --- a/tests/models/self_supervised/test_scripts.py +++ b/tests/models/self_supervised/test_scripts.py @@ -1,6 +1,7 @@ from unittest import mock import pytest +import torch @pytest.mark.parametrize('cli_args', ["--max_epochs 1 --max_steps 3 --fast_dev_run --batch_size 2"]) @@ -13,6 +14,8 @@ def test_cli_run_self_supervised_amdim(cli_args): cli_main() +# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it... +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.parametrize('cli_args', ['--max_epochs 1 --max_steps 3 --fast_dev_run --batch_size 2 --encoder resnet18']) def test_cli_run_self_supervised_cpc(cli_args): """Test running CLI for an example with default params.""" diff --git a/tests/models/test_executable_scripts.py b/tests/models/test_executable_scripts.py index 5fcea04897..2f0953f9df 100644 --- a/tests/models/test_executable_scripts.py +++ b/tests/models/test_executable_scripts.py @@ -1,6 +1,7 @@ from unittest import mock import pytest +import torch @pytest.mark.parametrize('cli_args', [ @@ -14,6 +15,8 @@ def test_cli_run_basic_gan(cli_args): cli_main() +# TODO: this test is hanging (runs for more then 10min) so we need to use GPU or optimize it... +@pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine") @pytest.mark.parametrize('cli_args', [ '--max_epochs 1 --limit_train_batches 2 --limit_val_batches 2 --batch_size 2 --encoder resnet18', ]) diff --git a/tests/models/test_vision_models.py b/tests/models/test_vision.py similarity index 100% rename from tests/models/test_vision_models.py rename to tests/models/test_vision.py From f49958de9f9ea57eb048bcec90b2454312866356 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 13 Oct 2020 19:37:53 +0200 Subject: [PATCH 19/32] set min requirements PL 1.0 (#274) --- requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements.txt b/requirements.txt index c434a7c377..fbf71902e1 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1,2 @@ -pytorch-lightning>=0.10.0 +pytorch-lightning>=1.0 torch>=1.6 \ No newline at end of file From 40fd35bf17b085c0de3a399dceafa167d6e62375 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Wed, 14 Oct 2020 17:46:46 +0900 Subject: [PATCH 20/32] Use explicit exception chaining (#261) --- pl_bolts/callbacks/vision/confused_logit.py | 4 ++-- pl_bolts/datamodules/base_dataset.py | 4 ++-- pl_bolts/datamodules/imagenet_dataset.py | 4 ++-- pl_bolts/datamodules/mnist_dataset.py | 4 ++-- pl_bolts/models/regression/linear_regression.py | 4 ++-- pl_bolts/models/regression/logistic_regression.py | 4 ++-- pl_bolts/models/self_supervised/cpc/cpc_module.py | 4 ++-- 7 files changed, 14 insertions(+), 14 deletions(-) diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index b3b07007f7..3e6f92da50 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -64,12 +64,12 @@ def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_id x, y = batch try: logits = pl_module.last_logits - except AttributeError as e: + except AttributeError as err: m = """please track the last_logits in the training_step like so: def training_step(...): self.last_logits = your_logits """ - raise AttributeError(m) + raise AttributeError(m) from err # only check when it has opinions (ie: the logit > 5) if logits.max() > self.min_logit_value: diff --git a/pl_bolts/datamodules/base_dataset.py b/pl_bolts/datamodules/base_dataset.py index 3397778859..a5c1e50898 100644 --- a/pl_bolts/datamodules/base_dataset.py +++ b/pl_bolts/datamodules/base_dataset.py @@ -54,5 +54,5 @@ def _download_from_url(self, base_url: str, data_folder: str, file_name: str): fpath = os.path.join(data_folder, file_name) try: urllib.request.urlretrieve(url, fpath) - except HTTPError: - raise RuntimeError(f'Failed download from {url}') + except HTTPError as err: + raise RuntimeError(f'Failed download from {url}') from err diff --git a/pl_bolts/datamodules/imagenet_dataset.py b/pl_bolts/datamodules/imagenet_dataset.py index 6238e41a7c..4a6cbd6677 100644 --- a/pl_bolts/datamodules/imagenet_dataset.py +++ b/pl_bolts/datamodules/imagenet_dataset.py @@ -20,10 +20,10 @@ try: from torchvision.datasets import ImageNet from torchvision.datasets.imagenet import load_meta_file -except ModuleNotFoundError: +except ModuleNotFoundError as err: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' - ) + ) from err class UnlabeledImagenet(ImageNet): diff --git a/pl_bolts/datamodules/mnist_dataset.py b/pl_bolts/datamodules/mnist_dataset.py index 4ecd00c954..0144f96345 100644 --- a/pl_bolts/datamodules/mnist_dataset.py +++ b/pl_bolts/datamodules/mnist_dataset.py @@ -3,10 +3,10 @@ try: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST -except ModuleNotFoundError: +except ModuleNotFoundError as err: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' - ) + ) from err try: from PIL import Image diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index f69fb7020e..241ff9f33e 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -126,10 +126,10 @@ def cli_main(): # create dataset try: from sklearn.datasets import load_boston - except ModuleNotFoundError: + except ModuleNotFoundError as err: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' - ) + ) from err X, y = load_boston(return_X_y=True) # these are numpy arrays loaders = SklearnDataModule(X, y) diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index 047c05747d..d39949c5a1 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -132,10 +132,10 @@ def cli_main(): # Example: Iris dataset in Sklearn (4 features, 3 class labels) try: from sklearn.datasets import load_iris - except ModuleNotFoundError: + except ModuleNotFoundError as err: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `sklearn` which is not installed yet, install it with `pip install sklearn`.' - ) + ) from err X, y = load_iris(return_X_y=True) loaders = SklearnDataModule(X, y) diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index 84fe1be0e4..195f84e781 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -168,10 +168,10 @@ def validation_step(self, batch, batch_nb): def shared_step(self, batch): try: from pl_bolts.datamodules.stl10_datamodule import STL10DataModule - except ModuleNotFoundError: + except ModuleNotFoundError as err: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `torchvision` which is not installed yet, install it with `pip install torchvision`.' - ) + ) from err if isinstance(self.datamodule, STL10DataModule): unlabeled_batch = batch[0] From c1204d572f8d1082b9e8de34226098dd684b1f57 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 14 Oct 2020 19:44:53 +0200 Subject: [PATCH 21/32] move datasets to existing package (#275) * move datasets * CI --- .github/workflows/ci_test-base.yml | 2 +- pl_bolts/datamodules/__init__.py | 2 +- pl_bolts/datamodules/binary_mnist_datamodule.py | 2 +- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 2 +- pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/datamodules/ssl_imagenet_datamodule.py | 2 +- pl_bolts/datamodules/stl10_datamodule.py | 2 +- pl_bolts/{datamodules => datasets}/base_dataset.py | 0 pl_bolts/{datamodules => datasets}/cifar10_dataset.py | 2 +- pl_bolts/{datamodules => datasets}/concat_dataset.py | 0 pl_bolts/{datamodules => datasets}/imagenet_dataset.py | 0 pl_bolts/{datamodules => datasets}/kitti_dataset.py | 0 pl_bolts/{datamodules => datasets}/mnist_dataset.py | 0 pl_bolts/{datamodules => datasets}/ssl_amdim_datasets.py | 0 pl_bolts/models/self_supervised/amdim/datasets.py | 4 ++-- tests/datamodules/test_dataloader.py | 2 +- tests/datamodules/test_datamodules.py | 2 +- 18 files changed, 13 insertions(+), 13 deletions(-) rename pl_bolts/{datamodules => datasets}/base_dataset.py (100%) rename pl_bolts/{datamodules => datasets}/cifar10_dataset.py (99%) rename pl_bolts/{datamodules => datasets}/concat_dataset.py (100%) rename pl_bolts/{datamodules => datasets}/imagenet_dataset.py (100%) rename pl_bolts/{datamodules => datasets}/kitti_dataset.py (100%) rename pl_bolts/{datamodules => datasets}/mnist_dataset.py (100%) rename pl_bolts/{datamodules => datasets}/ssl_amdim_datasets.py (100%) diff --git a/.github/workflows/ci_test-base.yml b/.github/workflows/ci_test-base.yml index a6d808a816..627789dc57 100644 --- a/.github/workflows/ci_test-base.yml +++ b/.github/workflows/ci_test-base.yml @@ -66,7 +66,7 @@ jobs: - name: Test Package [only] run: | # NOTE: run coverage on tests does not propagare faler status for Win, https://github.com/nedbat/coveragepy/issues/1003 - coverage run --source pl_bolts -m pytest pl_bolts -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml --ignore=pl_bolts/datamodules --ignore=pl_bolts/models/self_supervised/amdim/transforms.py --ignore=pl_bolts/models/rl + coverage run --source pl_bolts -m pytest pl_bolts -v --junitxml=junit/test-results-${{ runner.os }}-${{ matrix.python-version }}-${{ matrix.requires }}.xml --ignore=pl_bolts/datamodules --ignore=pl_bolts/datasets --ignore=pl_bolts/models/self_supervised/amdim/transforms.py --ignore=pl_bolts/models/rl - name: Upload pytest test results uses: actions/upload-artifact@master diff --git a/pl_bolts/datamodules/__init__.py b/pl_bolts/datamodules/__init__.py index cb810e1b6e..dd2d0423f4 100644 --- a/pl_bolts/datamodules/__init__.py +++ b/pl_bolts/datamodules/__init__.py @@ -23,7 +23,7 @@ from pl_bolts.datamodules.stl10_datamodule import STL10DataModule from pl_bolts.datamodules.vocdetection_datamodule import VOCDetectionDataModule - from pl_bolts.datamodules.kitti_dataset import KittiDataset + from pl_bolts.datasets.kitti_dataset import KittiDataset from pl_bolts.datamodules.kitti_datamodule import KittiDataModule except ModuleNotFoundError: pass diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index 2cf73777f3..ac5c5467d0 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -7,7 +7,7 @@ try: from torchvision import transforms as transform_lib from torchvision.datasets import MNIST - from pl_bolts.datamodules.mnist_dataset import BinaryMNIST + from pl_bolts.datasets.mnist_dataset import BinaryMNIST except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 53289ca639..80d44c17df 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -6,7 +6,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts.datamodules.cifar10_dataset import TrialCIFAR10 +from pl_bolts.datasets.cifar10_dataset import TrialCIFAR10 from pl_bolts.transforms.dataset_normalizations import cifar10_normalization try: diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 3d458b6160..ec86b3f8f1 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -8,7 +8,7 @@ try: from torchvision import transforms as transform_lib - from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet + from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover ' install it with `pip install torchvision`.') diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 6858af6629..6cc6401e1e 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -2,7 +2,7 @@ import torch from pytorch_lightning import LightningDataModule -from pl_bolts.datamodules.kitti_dataset import KittiDataset +from pl_bolts.datasets.kitti_dataset import KittiDataset from torch.utils.data import DataLoader import torchvision.transforms as transforms diff --git a/pl_bolts/datamodules/ssl_imagenet_datamodule.py b/pl_bolts/datamodules/ssl_imagenet_datamodule.py index 56196dae6e..3ce1f6d6d8 100644 --- a/pl_bolts/datamodules/ssl_imagenet_datamodule.py +++ b/pl_bolts/datamodules/ssl_imagenet_datamodule.py @@ -4,7 +4,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader -from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet +from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet from pl_bolts.transforms.dataset_normalizations import imagenet_normalization try: diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 46885b25df..4d6dcc611e 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -5,7 +5,7 @@ from pytorch_lightning import LightningDataModule from torch.utils.data import DataLoader, random_split -from pl_bolts.datamodules.concat_dataset import ConcatDataset +from pl_bolts.datasets.concat_dataset import ConcatDataset from pl_bolts.transforms.dataset_normalizations import stl10_normalization try: diff --git a/pl_bolts/datamodules/base_dataset.py b/pl_bolts/datasets/base_dataset.py similarity index 100% rename from pl_bolts/datamodules/base_dataset.py rename to pl_bolts/datasets/base_dataset.py diff --git a/pl_bolts/datamodules/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py similarity index 99% rename from pl_bolts/datamodules/cifar10_dataset.py rename to pl_bolts/datasets/cifar10_dataset.py index 5242e50b70..a1b808b219 100644 --- a/pl_bolts/datamodules/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -16,7 +16,7 @@ else: _PIL_AVAILABLE = True -from pl_bolts.datamodules.base_dataset import LightDataset +from pl_bolts.datasets.base_dataset import LightDataset class CIFAR10(LightDataset): diff --git a/pl_bolts/datamodules/concat_dataset.py b/pl_bolts/datasets/concat_dataset.py similarity index 100% rename from pl_bolts/datamodules/concat_dataset.py rename to pl_bolts/datasets/concat_dataset.py diff --git a/pl_bolts/datamodules/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py similarity index 100% rename from pl_bolts/datamodules/imagenet_dataset.py rename to pl_bolts/datasets/imagenet_dataset.py diff --git a/pl_bolts/datamodules/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py similarity index 100% rename from pl_bolts/datamodules/kitti_dataset.py rename to pl_bolts/datasets/kitti_dataset.py diff --git a/pl_bolts/datamodules/mnist_dataset.py b/pl_bolts/datasets/mnist_dataset.py similarity index 100% rename from pl_bolts/datamodules/mnist_dataset.py rename to pl_bolts/datasets/mnist_dataset.py diff --git a/pl_bolts/datamodules/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py similarity index 100% rename from pl_bolts/datamodules/ssl_amdim_datasets.py rename to pl_bolts/datasets/ssl_amdim_datasets.py diff --git a/pl_bolts/models/self_supervised/amdim/datasets.py b/pl_bolts/models/self_supervised/amdim/datasets.py index 037c8bf4be..a4bc65e60b 100644 --- a/pl_bolts/models/self_supervised/amdim/datasets.py +++ b/pl_bolts/models/self_supervised/amdim/datasets.py @@ -4,8 +4,8 @@ try: from torchvision.datasets import STL10 - from pl_bolts.datamodules.imagenet_dataset import UnlabeledImagenet - from pl_bolts.datamodules.ssl_amdim_datasets import CIFAR10Mixed + from pl_bolts.datasets.imagenet_dataset import UnlabeledImagenet + from pl_bolts.datasets.ssl_amdim_datasets import CIFAR10Mixed from pl_bolts.models.self_supervised.amdim import transforms as amdim_transforms except ModuleNotFoundError: warn('You want to use `torchvision` which is not installed yet,' # pragma: no-cover diff --git a/tests/datamodules/test_dataloader.py b/tests/datamodules/test_dataloader.py index cd2a70ea52..9627a25da2 100644 --- a/tests/datamodules/test_dataloader.py +++ b/tests/datamodules/test_dataloader.py @@ -2,7 +2,7 @@ from torch.utils.data import DataLoader from pl_bolts.datamodules.async_dataloader import AsynchronousLoader -from pl_bolts.datamodules.cifar10_dataset import CIFAR10 +from pl_bolts.datasets.cifar10_dataset import CIFAR10 def test_async_dataloader(tmpdir): diff --git a/tests/datamodules/test_datamodules.py b/tests/datamodules/test_datamodules.py index 379cafb9b9..34219fcddd 100644 --- a/tests/datamodules/test_datamodules.py +++ b/tests/datamodules/test_datamodules.py @@ -1,4 +1,4 @@ -from pl_bolts.datamodules.cifar10_dataset import CIFAR10 +from pl_bolts.datasets.cifar10_dataset import CIFAR10 def test_dev_datasets(tmpdir): From b825425ebbe0fdf4677f6ad9142114ad0decd414 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 14 Oct 2020 19:46:19 +0200 Subject: [PATCH 22/32] clean imports in tests --- tests/datasets/test_datasets.py | 3 ++- tests/losses/test_rl_loss.py | 2 +- tests/models/rl/unit/test_reinforce.py | 2 +- tests/models/rl/unit/test_vpg.py | 2 +- tests/models/test_detection.py | 1 - 5 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/datasets/test_datasets.py b/tests/datasets/test_datasets.py index c7adda3cda..1e9ae6d79f 100644 --- a/tests/datasets/test_datasets.py +++ b/tests/datasets/test_datasets.py @@ -1,6 +1,7 @@ -from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset from torch.utils.data import DataLoader +from pl_bolts.datasets import DummyDataset, RandomDataset, RandomDictDataset, RandomDictStringDataset + def test_dummy_ds(tmpdir): ds = DummyDataset((1, 2), num_samples=100) diff --git a/tests/losses/test_rl_loss.py b/tests/losses/test_rl_loss.py index e02965f84c..7db9757d6c 100644 --- a/tests/losses/test_rl_loss.py +++ b/tests/losses/test_rl_loss.py @@ -8,8 +8,8 @@ import torch from pl_bolts.losses.rl import dqn_loss, double_dqn_loss, per_dqn_loss -from pl_bolts.models.rl.common.networks import CNN from pl_bolts.models.rl.common.gym_wrappers import make_environment +from pl_bolts.models.rl.common.networks import CNN class TestRLLoss(TestCase): diff --git a/tests/models/rl/unit/test_reinforce.py b/tests/models/rl/unit/test_reinforce.py index 655dc2bd54..9eb5ca8796 100644 --- a/tests/models/rl/unit/test_reinforce.py +++ b/tests/models/rl/unit/test_reinforce.py @@ -7,8 +7,8 @@ from pl_bolts.datamodules.experience_source import DiscountedExperienceSource from pl_bolts.models.rl.common.agents import Agent -from pl_bolts.models.rl.common.networks import MLP from pl_bolts.models.rl.common.gym_wrappers import ToTensor +from pl_bolts.models.rl.common.networks import MLP from pl_bolts.models.rl.reinforce_model import Reinforce diff --git a/tests/models/rl/unit/test_vpg.py b/tests/models/rl/unit/test_vpg.py index 0cbdb5a7c8..9812f0f906 100644 --- a/tests/models/rl/unit/test_vpg.py +++ b/tests/models/rl/unit/test_vpg.py @@ -5,8 +5,8 @@ import torch from pl_bolts.models.rl.common.agents import Agent -from pl_bolts.models.rl.common.networks import MLP from pl_bolts.models.rl.common.gym_wrappers import ToTensor +from pl_bolts.models.rl.common.networks import MLP from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient diff --git a/tests/models/test_detection.py b/tests/models/test_detection.py index a312fbc9d7..730f9b5f91 100644 --- a/tests/models/test_detection.py +++ b/tests/models/test_detection.py @@ -1,4 +1,3 @@ -import pytest import pytorch_lightning as pl import torch from torch.utils.data import DataLoader From ccdc9f9952153abf9b12f00e05294fa332d5f424 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Wed, 14 Oct 2020 19:47:22 +0200 Subject: [PATCH 23/32] clean imports in pl_bolts --- pl_bolts/datamodules/kitti_datamodule.py | 8 ++++---- pl_bolts/datasets/kitti_dataset.py | 2 +- pl_bolts/models/__init__.py | 2 +- pl_bolts/models/rl/common/gym_wrappers.py | 1 + pl_bolts/models/rl/dqn_model.py | 1 + pl_bolts/models/rl/reinforce_model.py | 1 + pl_bolts/models/rl/vanilla_policy_gradient_model.py | 1 + pl_bolts/models/self_supervised/cpc/cpc_module.py | 1 - 8 files changed, 10 insertions(+), 7 deletions(-) diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 6cc6401e1e..127fb2e26f 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -1,13 +1,13 @@ import os -import torch +import torch +import torchvision.transforms as transforms from pytorch_lightning import LightningDataModule -from pl_bolts.datasets.kitti_dataset import KittiDataset - from torch.utils.data import DataLoader -import torchvision.transforms as transforms from torch.utils.data.dataset import random_split +from pl_bolts.datasets.kitti_dataset import KittiDataset + class KittiDataModule(LightningDataModule): diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index 937f106fa2..f887eb82b8 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -1,7 +1,7 @@ import os + import numpy as np from PIL import Image - from torch.utils.data import Dataset DEFAULT_VOID_LABELS = (0, 1, 2, 3, 4, 5, 6, 9, 10, 14, 15, 16, 18, 29, 30, -1) diff --git a/pl_bolts/models/__init__.py b/pl_bolts/models/__init__.py index 2fae5936c7..c7866067d2 100644 --- a/pl_bolts/models/__init__.py +++ b/pl_bolts/models/__init__.py @@ -7,5 +7,5 @@ from pl_bolts.models.mnist_module import LitMNIST from pl_bolts.models.regression import LinearRegression, LogisticRegression from pl_bolts.models.vision import PixelCNN -from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT from pl_bolts.models.vision import UNet +from pl_bolts.models.vision.image_gpt.igpt_module import GPT2, ImageGPT diff --git a/pl_bolts/models/rl/common/gym_wrappers.py b/pl_bolts/models/rl/common/gym_wrappers.py index 8f492a27c1..16bfd2d2e0 100644 --- a/pl_bolts/models/rl/common/gym_wrappers.py +++ b/pl_bolts/models/rl/common/gym_wrappers.py @@ -9,6 +9,7 @@ import gym.spaces import numpy as np import torch + try: import cv2 except ModuleNotFoundError: diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 01b4a68277..64bf6e7ce7 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -21,6 +21,7 @@ from pl_bolts.models.rl.common.agents import ValueAgent from pl_bolts.models.rl.common.memory import MultiStepBuffer from pl_bolts.models.rl.common.networks import CNN + try: from pl_bolts.models.rl.common.gym_wrappers import gym, make_environment except ModuleNotFoundError: diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index 55535a91e7..0308067ccd 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -17,6 +17,7 @@ from pl_bolts.datamodules.experience_source import Experience from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP + try: import gym except ModuleNotFoundError: diff --git a/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/pl_bolts/models/rl/vanilla_policy_gradient_model.py index f7d9e6586f..c318d25917 100644 --- a/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -16,6 +16,7 @@ from pl_bolts.datamodules import ExperienceSourceDataset from pl_bolts.models.rl.common.agents import PolicyAgent from pl_bolts.models.rl.common.networks import MLP + try: import gym except ModuleNotFoundError: diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index 195f84e781..958f01978a 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -4,7 +4,6 @@ """ import math from argparse import ArgumentParser -from typing import Union import pytorch_lightning as pl import torch From 50d3fe150fe1bb779e1d6adbc742491c4900f8f8 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Thu, 15 Oct 2020 17:34:36 +0900 Subject: [PATCH 24/32] Remove unused exception reference (#276) --- tests/models/test_autoencoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/test_autoencoders.py b/tests/models/test_autoencoders.py index 804fa80dcc..b27ab518f3 100644 --- a/tests/models/test_autoencoders.py +++ b/tests/models/test_autoencoders.py @@ -110,7 +110,7 @@ def test_from_pretrained(tmpdir): x_hat = ae(x) break - except Exception as e: + except Exception: exception_raised = True assert exception_raised is False, "error in loading weights" From 3d71ac69e01db18835f0d78ad148a48a5ee34f31 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 15 Oct 2020 11:08:54 +0200 Subject: [PATCH 25/32] fix SSL hooks fro pl 1.0 (#277) * pl 1.0 compatible * fix --- pl_bolts/callbacks/self_supervised.py | 4 ++-- pl_bolts/callbacks/vision/confused_logit.py | 2 +- pl_bolts/models/self_supervised/byol/byol_module.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pl_bolts/callbacks/self_supervised.py b/pl_bolts/callbacks/self_supervised.py index 4ff7216a18..9fb7bf57f1 100644 --- a/pl_bolts/callbacks/self_supervised.py +++ b/pl_bolts/callbacks/self_supervised.py @@ -72,7 +72,7 @@ def to_device(self, batch, device): return x, y - def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): x, y = self.to_device(batch, pl_module.device) with torch.no_grad(): @@ -131,7 +131,7 @@ def __init__(self, initial_tau=0.996): self.initial_tau = initial_tau self.current_tau = initial_tau - def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # get networks online_net = pl_module.online_network target_net = pl_module.target_network diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index 3e6f92da50..2af5d0d72a 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -55,7 +55,7 @@ def __init__( self.logging_batch_interval = logging_batch_interval self.min_logit_value = min_logit_value - def on_train_batch_end(self, trainer, pl_module, batch, batch_idx, dataloader_idx): + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): # show images only every 20 batches if (trainer.batch_idx + 1) % self.logging_batch_interval != 0: return diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index 95c68bbee7..b983988ea7 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -97,9 +97,9 @@ def __init__(self, self.target_network = deepcopy(self.online_network) self.weight_callback = BYOLMAWeightUpdate() - def on_train_batch_end(self, batch: Any, batch_idx: int, dataloader_idx: int) -> None: + def on_train_batch_end(self, outputs, batch: Any, batch_idx: int, dataloader_idx: int) -> None: # Add callback for user automatically since it's key to BYOL weight update - self.weight_callback.on_train_batch_end(self.trainer, self, batch, batch_idx, dataloader_idx) + self.weight_callback.on_train_batch_end(self.trainer, self, outputs, batch, batch_idx, dataloader_idx) def forward(self, x): y, _, _ = self.online_network(x) From bdfe159482877c4397ddaa0712ed4043b279a03b Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Thu, 15 Oct 2020 12:40:28 +0200 Subject: [PATCH 26/32] consistent class docs (#278) --- pl_bolts/callbacks/printing.py | 35 +-- pl_bolts/callbacks/self_supervised.py | 56 ++-- pl_bolts/callbacks/variational.py | 19 +- pl_bolts/callbacks/vision/confused_logit.py | 54 ++-- pl_bolts/callbacks/vision/image_generation.py | 29 ++- .../datamodules/binary_mnist_datamodule.py | 51 ++-- pl_bolts/datamodules/cifar10_datamodule.py | 112 ++++---- pl_bolts/datamodules/cityscapes_datamodule.py | 75 +++--- pl_bolts/datamodules/experience_source.py | 18 +- .../datamodules/fashion_mnist_datamodule.py | 53 ++-- pl_bolts/datamodules/imagenet_datamodule.py | 56 ++-- pl_bolts/datamodules/mnist_datamodule.py | 51 ++-- pl_bolts/datamodules/sklearn_datamodule.py | 111 ++++---- pl_bolts/datamodules/stl10_datamodule.py | 61 ++--- .../datamodules/vocdetection_datamodule.py | 7 +- pl_bolts/datasets/cifar10_dataset.py | 24 +- pl_bolts/datasets/dummy_dataset.py | 45 ++-- pl_bolts/datasets/imagenet_dataset.py | 11 +- pl_bolts/datasets/kitti_dataset.py | 13 +- pl_bolts/datasets/ssl_amdim_datasets.py | 8 - pl_bolts/losses/self_supervised_learning.py | 74 +++--- .../autoencoders/basic_ae/basic_ae_module.py | 26 +- .../basic_vae/basic_vae_module.py | 32 +-- pl_bolts/models/detection/faster_rcnn.py | 35 +-- .../models/gans/basic/basic_gan_module.py | 41 ++- .../models/regression/linear_regression.py | 8 +- .../models/regression/logistic_regression.py | 6 +- pl_bolts/models/rl/common/memory.py | 17 +- pl_bolts/models/rl/common/networks.py | 59 +++-- pl_bolts/models/rl/dqn_model.py | 52 ++-- pl_bolts/models/rl/dueling_dqn_model.py | 54 ++-- pl_bolts/models/rl/per_dqn_model.py | 40 +-- pl_bolts/models/rl/reinforce_model.py | 51 ++-- .../rl/vanilla_policy_gradient_model.py | 49 ++-- .../self_supervised/amdim/amdim_module.py | 43 ++-- .../models/self_supervised/amdim/networks.py | 16 +- .../self_supervised/amdim/transforms.py | 164 ++++++------ .../self_supervised/byol/byol_module.py | 103 ++++---- .../models/self_supervised/cpc/transforms.py | 241 +++++++++--------- .../self_supervised/moco/moco2_module.py | 55 ++-- .../self_supervised/simclr/transforms.py | 5 +- .../models/self_supervised/ssl_finetuner.py | 49 ++-- pl_bolts/models/vision/image_gpt/gpt2.py | 43 ++-- .../models/vision/image_gpt/igpt_module.py | 160 ++++++------ pl_bolts/models/vision/pixel_cnn.py | 36 +-- pl_bolts/models/vision/unet.py | 16 +- pl_bolts/optimizers/lars_scheduling.py | 6 +- pl_bolts/optimizers/lr_scheduler.py | 18 +- .../self_supervised/ssl_transforms.py | 4 +- pl_bolts/utils/arguments.py | 22 +- 50 files changed, 1233 insertions(+), 1181 deletions(-) diff --git a/pl_bolts/callbacks/printing.py b/pl_bolts/callbacks/printing.py index 60c9435037..d2364b4085 100644 --- a/pl_bolts/callbacks/printing.py +++ b/pl_bolts/callbacks/printing.py @@ -7,32 +7,33 @@ class PrintTableMetricsCallback(Callback): - def __init__(self): - """ - Prints a table with the metrics in columns on every epoch end + """ + Prints a table with the metrics in columns on every epoch end - Example:: + Example:: - from pl_bolts.callbacks import PrintTableMetricsCallback + from pl_bolts.callbacks import PrintTableMetricsCallback - callback = PrintTableMetricsCallback() + callback = PrintTableMetricsCallback() - pass into trainer like so: + pass into trainer like so: - .. code-block:: python + .. code-block:: python - trainer = pl.Trainer(callbacks=[callback]) - trainer.fit(...) + trainer = pl.Trainer(callbacks=[callback]) + trainer.fit(...) - # ------------------------------ - # at the end of every epoch it will print - # ------------------------------ + # ------------------------------ + # at the end of every epoch it will print + # ------------------------------ - # loss│train_loss│val_loss│epoch - # ────────────────────────────── - # 2.2541470527648926│2.2541470527648926│2.2158432006835938│0 + # loss│train_loss│val_loss│epoch + # ────────────────────────────── + # 2.2541470527648926│2.2541470527648926│2.2158432006835938│0 - """ + """ + + def __init__(self): self.metrics = [] def on_epoch_end(self, trainer, pl_module): diff --git a/pl_bolts/callbacks/self_supervised.py b/pl_bolts/callbacks/self_supervised.py index 9fb7bf57f1..63c17e6d5d 100644 --- a/pl_bolts/callbacks/self_supervised.py +++ b/pl_bolts/callbacks/self_supervised.py @@ -7,20 +7,21 @@ class SSLOnlineEvaluator(pl.Callback): # pragma: no-cover + """ + Attaches a MLP for finetuning using the standard self-supervised protocol. - def __init__(self, drop_p: float = 0.2, hidden_dim: int = 1024, z_dim: int = None, num_classes: int = None): - """ - Attaches a MLP for finetuning using the standard self-supervised protocol. - - Example:: + Example:: - from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator + from pl_bolts.callbacks.self_supervised import SSLOnlineEvaluator - # your model must have 2 attributes - model = Model() - model.z_dim = ... # the representation dim - model.num_classes = ... # the num of classes in the model + # your model must have 2 attributes + model = Model() + model.z_dim = ... # the representation dim + model.num_classes = ... # the num of classes in the model + """ + def __init__(self, drop_p: float = 0.2, hidden_dim: int = 1024, z_dim: int = None, num_classes: int = None): + """ Args: drop_p: (0.2) dropout probability hidden_dim: (1024) the hidden dimension for the finetune MLP @@ -98,32 +99,33 @@ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, data class BYOLMAWeightUpdate(pl.Callback): + """ + Weight update rule from BYOL. - def __init__(self, initial_tau=0.996): - """ - Weight update rule from BYOL. - - Your model should have a: + Your model should have a: - - self.online_network. - - self.target_network. + - self.online_network. + - self.target_network. - Updates the target_network params using an exponential moving average update rule weighted by tau. - BYOL claims this keeps the online_network from collapsing. + Updates the target_network params using an exponential moving average update rule weighted by tau. + BYOL claims this keeps the online_network from collapsing. - .. note:: Automatically increases tau from `initial_tau` to 1.0 with every training step + .. note:: Automatically increases tau from `initial_tau` to 1.0 with every training step - Example:: + Example:: - from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate + from pl_bolts.callbacks.self_supervised import BYOLMAWeightUpdate - # model must have 2 attributes - model = Model() - model.online_network = ... - model.target_network = ... + # model must have 2 attributes + model = Model() + model.online_network = ... + model.target_network = ... - trainer = Trainer(callbacks=[BYOLMAWeightUpdate()]) + trainer = Trainer(callbacks=[BYOLMAWeightUpdate()]) + """ + def __init__(self, initial_tau=0.996): + """ Args: initial_tau: starting tau. Auto-updates with every training step """ diff --git a/pl_bolts/callbacks/variational.py b/pl_bolts/callbacks/variational.py index eb4469b0ed..c845f3ecc2 100644 --- a/pl_bolts/callbacks/variational.py +++ b/pl_bolts/callbacks/variational.py @@ -11,20 +11,21 @@ class LatentDimInterpolator(Callback): + """ + Interpolates the latent space for a model by setting all dims to zero and stepping + through the first two dims increasing one unit at a time. - def __init__(self, interpolate_epoch_interval=20, range_start=-5, range_end=5, num_samples=2): - """ - Interpolates the latent space for a model by setting all dims to zero and stepping - through the first two dims increasing one unit at a time. - - Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5) + Default interpolates between [-5, 5] (-5, -4, -3, ..., 3, 4, 5) - Example:: + Example:: - from pl_bolts.callbacks import LatentDimInterpolator + from pl_bolts.callbacks import LatentDimInterpolator - Trainer(callbacks=[LatentDimInterpolator()]) + Trainer(callbacks=[LatentDimInterpolator()]) + """ + def __init__(self, interpolate_epoch_interval=20, range_start=-5, range_end=5, num_samples=2): + """ Args: interpolate_epoch_interval: range_start: default -5 diff --git a/pl_bolts/callbacks/vision/confused_logit.py b/pl_bolts/callbacks/vision/confused_logit.py index 2af5d0d72a..f8ab81a5d6 100644 --- a/pl_bolts/callbacks/vision/confused_logit.py +++ b/pl_bolts/callbacks/vision/confused_logit.py @@ -4,49 +4,49 @@ class ConfusedLogitCallback(Callback): # pragma: no-cover + """ + Takes the logit predictions of a model and when the probabilities of two classes are very close, the model + doesn't have high certainty that it should pick one vs the other class. - def __init__( - self, - top_k, - projection_factor=3, - min_logit_value=5.0, - logging_batch_interval=20, - max_logit_difference=0.1 - ): - """ - Takes the logit predictions of a model and when the probabilities of two classes are very close, the model - doesn't have high certainty that it should pick one vs the other class. + This callback shows how the input would have to change to swing the model from one label prediction + to the other. - This callback shows how the input would have to change to swing the model from one label prediction - to the other. + In this case, the network predicts a 5... but gives almost equal probability to an 8. + The images show what about the original 5 would have to change to make it more like a 5 or more like an 8. - In this case, the network predicts a 5... but gives almost equal probability to an 8. - The images show what about the original 5 would have to change to make it more like a 5 or more like an 8. + For each confused logit the confused images are generated by taking the gradient from a logit wrt an input + for the top two closest logits. - For each confused logit the confused images are generated by taking the gradient from a logit wrt an input - for the top two closest logits. + Example:: - Example:: + from pl_bolts.callbacks.vision import ConfusedLogitCallback + trainer = Trainer(callbacks=[ConfusedLogitCallback()]) - from pl_bolts.callbacks.vision import ConfusedLogitCallback - trainer = Trainer(callbacks=[ConfusedLogitCallback()]) + .. note:: whenever called, this model will look for self.last_batch and self.last_logits in the LightningModule - .. note:: whenever called, this model will look for self.last_batch and self.last_logits in the LightningModule + .. note:: this callback supports tensorboard only right now - .. note:: this callback supports tensorboard only right now + Authored by: + - Alfredo Canziani + """ + + def __init__( + self, + top_k, + projection_factor=3, + min_logit_value=5.0, + logging_batch_interval=20, + max_logit_difference=0.1 + ): + """ Args: top_k: How many "offending" images we should plot projection_factor: How much to multiply the input image to make it look more like this logit label min_logit_value: Only consider logit values above this threshold logging_batch_interval: how frequently to inspect/potentially plot something max_logit_difference: when the top 2 logits are within this threshold we consider them confused - - Authored by: - - - Alfredo Canziani - """ super().__init__() self.top_k = top_k diff --git a/pl_bolts/callbacks/vision/image_generation.py b/pl_bolts/callbacks/vision/image_generation.py index 60f5c2e172..f4ffb0b0ea 100644 --- a/pl_bolts/callbacks/vision/image_generation.py +++ b/pl_bolts/callbacks/vision/image_generation.py @@ -11,26 +11,27 @@ class TensorboardGenerativeModelImageSampler(Callback): - def __init__(self, num_samples: int = 3): - """ - Generates images and logs to tensorboard. - Your model must implement the forward function for generation + """ + Generates images and logs to tensorboard. + Your model must implement the forward function for generation + + Requirements:: - Requirements:: + # model must have img_dim arg + model.img_dim = (1, 28, 28) - # model must have img_dim arg - model.img_dim = (1, 28, 28) + # model forward must work for sampling + z = torch.rand(batch_size, latent_dim) + img_samples = your_model(z) - # model forward must work for sampling - z = torch.rand(batch_size, latent_dim) - img_samples = your_model(z) + Example:: - Example:: + from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler - from pl_bolts.callbacks import TensorboardGenerativeModelImageSampler + trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()]) + """ - trainer = Trainer(callbacks=[TensorboardGenerativeModelImageSampler()]) - """ + def __init__(self, num_samples: int = 3): super().__init__() self.num_samples = num_samples diff --git a/pl_bolts/datamodules/binary_mnist_datamodule.py b/pl_bolts/datamodules/binary_mnist_datamodule.py index ac5c5467d0..5b982e69ca 100644 --- a/pl_bolts/datamodules/binary_mnist_datamodule.py +++ b/pl_bolts/datamodules/binary_mnist_datamodule.py @@ -17,6 +17,32 @@ class BinaryMNISTDataModule(LightningDataModule): + """ + .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png + :width: 400 + :alt: MNIST + + Specs: + - 10 classes (1 per digit) + - Each image is (1 x 28 x 28) + + Binary MNIST, train, val, test splits and transforms + + Transforms:: + + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor() + ]) + + Example:: + + from pl_bolts.datamodules import BinaryMNISTDataModule + + dm = BinaryMNISTDataModule('.') + model = LitModel() + + Trainer().fit(model, dm) + """ name = 'mnist' @@ -31,31 +57,6 @@ def __init__( **kwargs, ): """ - .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png - :width: 400 - :alt: MNIST - - Specs: - - 10 classes (1 per digit) - - Each image is (1 x 28 x 28) - - Binary MNIST, train, val, test splits and transforms - - Transforms:: - - mnist_transforms = transform_lib.Compose([ - transform_lib.ToTensor() - ]) - - Example:: - - from pl_bolts.datamodules import BinaryMNISTDataModule - - dm = BinaryMNISTDataModule('.') - model = LitModel() - - Trainer().fit(model, dm) - Args: data_dir: where to save/load the data val_split: how many of the training images to use for the validation split diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index 80d44c17df..b6e80ef53a 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -22,6 +22,45 @@ class CIFAR10DataModule(LightningDataModule): + """ + .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/ + Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png + :width: 400 + :alt: CIFAR-10 + + Specs: + - 10 classes (1 per class) + - Each image is (3 x 32 x 32) + + Standard CIFAR10, train, val, test splits and transforms + + Transforms:: + + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + transforms.Normalize( + mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], + std=[x / 255.0 for x in [63.0, 62.1, 66.7]] + ) + ]) + + Example:: + + from pl_bolts.datamodules import CIFAR10DataModule + + dm = CIFAR10DataModule(PATH) + model = LitModel() + + Trainer().fit(model, dm) + + Or you can set your own transforms + + Example:: + + dm.train_transforms = ... + dm.test_transforms = ... + dm.val_transforms = ... + """ name = 'cifar10' extra_args = {} @@ -37,44 +76,6 @@ def __init__( **kwargs, ): """ - .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2019/01/ - Plot-of-a-Subset-of-Images-from-the-CIFAR-10-Dataset.png - :width: 400 - :alt: CIFAR-10 - - Specs: - - 10 classes (1 per class) - - Each image is (3 x 32 x 32) - - Standard CIFAR10, train, val, test splits and transforms - - Transforms:: - - mnist_transforms = transform_lib.Compose([ - transform_lib.ToTensor(), - transforms.Normalize( - mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], - std=[x / 255.0 for x in [63.0, 62.1, 66.7]] - ) - ]) - - Example:: - - from pl_bolts.datamodules import CIFAR10DataModule - - dm = CIFAR10DataModule(PATH) - model = LitModel() - - Trainer().fit(model, dm) - - Or you can set your own transforms - - Example:: - - dm.train_transforms = ... - dm.test_transforms = ... - dm.val_transforms = ... - Args: data_dir: where to save/load the data val_split: how many of the training images to use for the validation split @@ -184,6 +185,24 @@ def default_transforms(self): class TinyCIFAR10DataModule(CIFAR10DataModule): + """ + Standard CIFAR10, train, val, test splits and transforms + + Transforms:: + + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], + std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) + ]) + + Example:: + + from pl_bolts.datamodules import CIFAR10DataModule + + dm = CIFAR10DataModule(PATH) + model = LitModel(datamodule=dm) + """ def __init__( self, @@ -196,23 +215,6 @@ def __init__( **kwargs, ): """ - Standard CIFAR10, train, val, test splits and transforms - - Transforms:: - - mnist_transforms = transform_lib.Compose([ - transform_lib.ToTensor(), - transforms.Normalize(mean=[x / 255.0 for x in [125.3, 123.0, 113.9]], - std=[x / 255.0 for x in [63.0, 62.1, 66.7]]) - ]) - - Example:: - - from pl_bolts.datamodules import CIFAR10DataModule - - dm = CIFAR10DataModule(PATH) - model = LitModel(datamodule=dm) - Args: data_dir: where to save/load the data val_split: how many of the training images to use for the validation split diff --git a/pl_bolts/datamodules/cityscapes_datamodule.py b/pl_bolts/datamodules/cityscapes_datamodule.py index 6983e932ff..a43870ce24 100644 --- a/pl_bolts/datamodules/cityscapes_datamodule.py +++ b/pl_bolts/datamodules/cityscapes_datamodule.py @@ -16,6 +16,44 @@ class CityscapesDataModule(LightningDataModule): + """ + .. figure:: https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/muenster00-1024x510.png + :width: 400 + :alt: Cityscape + + Standard Cityscapes, train, val, test splits and transforms + + Specs: + - 30 classes (road, person, sidewalk, etc...) + - (image, target) - image dims: (3 x 32 x 32), target dims: (3 x 32 x 32) + + Transforms:: + + transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + transform_lib.Normalize( + mean=[0.28689554, 0.32513303, 0.28389177], + std=[0.18696375, 0.19017339, 0.18720214] + ) + ]) + + Example:: + + from pl_bolts.datamodules import CityscapesDataModule + + dm = CityscapesDataModule(PATH) + model = LitModel() + + Trainer().fit(model, dm) + + Or you can set your own transforms + + Example:: + + dm.train_transforms = ... + dm.test_transforms = ... + dm.val_transforms = ... + """ name = 'Cityscapes' extra_args = {} @@ -31,43 +69,6 @@ def __init__( **kwargs, ): """ - .. figure:: https://www.cityscapes-dataset.com/wordpress/wp-content/uploads/2015/07/muenster00-1024x510.png - :width: 400 - :alt: Cityscape - - Standard Cityscapes, train, val, test splits and transforms - - Specs: - - 30 classes (road, person, sidewalk, etc...) - - (image, target) - image dims: (3 x 32 x 32), target dims: (3 x 32 x 32) - - Transforms:: - - transforms = transform_lib.Compose([ - transform_lib.ToTensor(), - transform_lib.Normalize( - mean=[0.28689554, 0.32513303, 0.28389177], - std=[0.18696375, 0.19017339, 0.18720214] - ) - ]) - - Example:: - - from pl_bolts.datamodules import CityscapesDataModule - - dm = CityscapesDataModule(PATH) - model = LitModel() - - Trainer().fit(model, dm) - - Or you can set your own transforms - - Example:: - - dm.train_transforms = ... - dm.test_transforms = ... - dm.val_transforms = ... - Args: data_dir: where to save/load the data val_split: how many of the training images to use for the validation split diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index 6a4671234f..bfdd66d529 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -35,12 +35,14 @@ def __iter__(self) -> Iterable: class BaseExperienceSource(ABC): """ Simplest form of the experience source - Args: - env: Environment that is being used - agent: Agent being used to make decisions """ def __init__(self, env, agent) -> None: + """ + Args: + env: Environment that is being used + agent: Agent being used to make decisions + """ self.env = env self.agent = agent @@ -52,13 +54,15 @@ def runner(self) -> Experience: class ExperienceSource(BaseExperienceSource): """ Experience source class handling single and multiple environment steps - Args: - env: Environment that is being used - agent: Agent being used to make decisions - n_steps: Number of steps to return from each environment at once """ def __init__(self, env, agent, n_steps: int = 1) -> None: + """ + Args: + env: Environment that is being used + agent: Agent being used to make decisions + n_steps: Number of steps to return from each environment at once + """ super().__init__(env, agent) self.pool = env if isinstance(env, (list, tuple)) else [env] diff --git a/pl_bolts/datamodules/fashion_mnist_datamodule.py b/pl_bolts/datamodules/fashion_mnist_datamodule.py index 467bc36c4c..4237d1ef01 100644 --- a/pl_bolts/datamodules/fashion_mnist_datamodule.py +++ b/pl_bolts/datamodules/fashion_mnist_datamodule.py @@ -16,6 +16,33 @@ class FashionMNISTDataModule(LightningDataModule): + """ + .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/ + wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png + :width: 400 + :alt: Fashion MNIST + + Specs: + - 10 classes (1 per type) + - Each image is (1 x 28 x 28) + + Standard FashionMNIST, train, val, test splits and transforms + + Transforms:: + + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor() + ]) + + Example:: + + from pl_bolts.datamodules import FashionMNISTDataModule + + dm = FashionMNISTDataModule('.') + model = LitModel() + + Trainer().fit(model, dm) + """ name = 'fashion_mnist' @@ -29,32 +56,6 @@ def __init__( **kwargs, ): """ - .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/ - wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-Fashion-MNIST-Dataset.png - :width: 400 - :alt: Fashion MNIST - - Specs: - - 10 classes (1 per type) - - Each image is (1 x 28 x 28) - - Standard FashionMNIST, train, val, test splits and transforms - - Transforms:: - - mnist_transforms = transform_lib.Compose([ - transform_lib.ToTensor() - ]) - - Example:: - - from pl_bolts.datamodules import FashionMNISTDataModule - - dm = FashionMNISTDataModule('.') - model = LitModel() - - Trainer().fit(model, dm) - Args: data_dir: where to save/load the data val_split: how many of the training images to use for the validation split diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index ec86b3f8f1..4ab505989e 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -18,6 +18,34 @@ class ImagenetDataModule(LightningDataModule): + """ + .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/ + Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png + :width: 400 + :alt: Imagenet + + Specs: + - 1000 classes + - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224) + + Imagenet train, val and test dataloaders. + + The train set is the imagenet train. + + The val set is taken from the train set with `num_imgs_per_val_class` images per class. + For example if `num_imgs_per_val_class=2` then there will be 2,000 images in the validation set. + + The test set is the official imagenet validation set. + + Example:: + + from pl_bolts.datamodules import ImagenetDataModule + + dm = ImagenetDataModule(IMAGENET_PATH) + model = LitModel() + + Trainer().fit(model, dm) + """ name = 'imagenet' @@ -33,35 +61,7 @@ def __init__( **kwargs, ): """ - .. figure:: https://3qeqpr26caki16dnhd19sv6by6v-wpengine.netdna-ssl.com/wp-content/uploads/2017/08/ - Sample-of-Images-from-the-ImageNet-Dataset-used-in-the-ILSVRC-Challenge.png - :width: 400 - :alt: Imagenet - - Specs: - - 1000 classes - - Each image is (3 x varies x varies) (here we default to 3 x 224 x 224) - - Imagenet train, val and test dataloaders. - - The train set is the imagenet train. - - The val set is taken from the train set with `num_imgs_per_val_class` images per class. - For example if `num_imgs_per_val_class=2` then there will be 2,000 images in the validation set. - - The test set is the official imagenet validation set. - - Example:: - - from pl_bolts.datamodules import ImagenetDataModule - - dm = ImagenetDataModule(IMAGENET_PATH) - model = LitModel() - - Trainer().fit(model, dm) - Args: - data_dir: path to the imagenet dataset file meta_dir: path to meta.bin file num_imgs_per_val_class: how many images per class for the validation set diff --git a/pl_bolts/datamodules/mnist_datamodule.py b/pl_bolts/datamodules/mnist_datamodule.py index d1964e8346..ea416442c8 100644 --- a/pl_bolts/datamodules/mnist_datamodule.py +++ b/pl_bolts/datamodules/mnist_datamodule.py @@ -16,6 +16,32 @@ class MNISTDataModule(LightningDataModule): + """ + .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png + :width: 400 + :alt: MNIST + + Specs: + - 10 classes (1 per digit) + - Each image is (1 x 28 x 28) + + Standard MNIST, train, val, test splits and transforms + + Transforms:: + + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor() + ]) + + Example:: + + from pl_bolts.datamodules import MNISTDataModule + + dm = MNISTDataModule('.') + model = LitModel() + + Trainer().fit(model, dm) + """ name = "mnist" @@ -31,31 +57,6 @@ def __init__( **kwargs, ): """ - .. figure:: https://miro.medium.com/max/744/1*AO2rIhzRYzFVQlFLx9DM9A.png - :width: 400 - :alt: MNIST - - Specs: - - 10 classes (1 per digit) - - Each image is (1 x 28 x 28) - - Standard MNIST, train, val, test splits and transforms - - Transforms:: - - mnist_transforms = transform_lib.Compose([ - transform_lib.ToTensor() - ]) - - Example:: - - from pl_bolts.datamodules import MNISTDataModule - - dm = MNISTDataModule('.') - model = LitModel() - - Trainer().fit(model, dm) - Args: data_dir: where to save/load the data val_split: how many of the training images to use for the validation split diff --git a/pl_bolts/datamodules/sklearn_datamodule.py b/pl_bolts/datamodules/sklearn_datamodule.py index 704d629c1b..81eda9649d 100644 --- a/pl_bolts/datamodules/sklearn_datamodule.py +++ b/pl_bolts/datamodules/sklearn_datamodule.py @@ -18,25 +18,25 @@ class SklearnDataset(Dataset): + """ + Mapping between numpy (or sklearn) datasets to PyTorch datasets. + + Example: + >>> from sklearn.datasets import load_boston + >>> from pl_bolts.datamodules import SklearnDataset + ... + >>> X, y = load_boston(return_X_y=True) + >>> dataset = SklearnDataset(X, y) + >>> len(dataset) + 506 + """ def __init__(self, X: np.ndarray, y: np.ndarray, X_transform: Any = None, y_transform: Any = None): """ - Mapping between numpy (or sklearn) datasets to PyTorch datasets. - Args: X: Numpy ndarray y: Numpy ndarray X_transform: Any transform that works with Numpy arrays y_transform: Any transform that works with Numpy arrays - - Example: - >>> from sklearn.datasets import load_boston - >>> from pl_bolts.datamodules import SklearnDataset - ... - >>> X, y = load_boston(return_X_y=True) - >>> dataset = SklearnDataset(X, y) - >>> len(dataset) - 506 - """ super().__init__() self.X = X @@ -65,25 +65,25 @@ def __getitem__(self, idx): class TensorDataset(Dataset): + """ + Prepare PyTorch tensor dataset for data loaders. + + Example: + >>> from pl_bolts.datamodules import TensorDataset + ... + >>> X = torch.rand(10, 3) + >>> y = torch.rand(10) + >>> dataset = TensorDataset(X, y) + >>> len(dataset) + 10 + """ def __init__(self, X: torch.Tensor, y: torch.Tensor, X_transform: Any = None, y_transform: Any = None): """ - Prepare PyTorch tensor dataset for data loaders. - Args: X: PyTorch tensor y: PyTorch tensor X_transform: Any transform that works with PyTorch tensors y_transform: Any transform that works with PyTorch tensors - - Example: - >>> from pl_bolts.datamodules import TensorDataset - ... - >>> X = torch.rand(10, 3) - >>> y = torch.rand(10) - >>> dataset = TensorDataset(X, y) - >>> len(dataset) - 10 - """ super().__init__() self.X = X @@ -108,6 +108,37 @@ def __getitem__(self, idx): class SklearnDataModule(LightningDataModule): + """ + Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as + dataloaders for convenience. Optionally, you can pass in your own validation and test splits. + + Example: + + >>> from sklearn.datasets import load_boston + >>> from pl_bolts.datamodules import SklearnDataModule + ... + >>> X, y = load_boston(return_X_y=True) + >>> loaders = SklearnDataModule(X, y) + ... + >>> # train set + >>> train_loader = loaders.train_dataloader(batch_size=32) + >>> len(train_loader.dataset) + 355 + >>> len(train_loader) + 11 + >>> # validation set + >>> val_loader = loaders.val_dataloader(batch_size=32) + >>> len(val_loader.dataset) + 100 + >>> len(val_loader) + 3 + >>> # test set + >>> test_loader = loaders.test_dataloader(batch_size=32) + >>> len(test_loader.dataset) + 51 + >>> len(test_loader) + 1 + """ name = 'sklearn' @@ -122,38 +153,6 @@ def __init__( *args, **kwargs, ): - """ - Automatically generates the train, validation and test splits for a Numpy dataset. They are set up as - dataloaders for convenience. Optionally, you can pass in your own validation and test splits. - - Example: - - >>> from sklearn.datasets import load_boston - >>> from pl_bolts.datamodules import SklearnDataModule - ... - >>> X, y = load_boston(return_X_y=True) - >>> loaders = SklearnDataModule(X, y) - ... - >>> # train set - >>> train_loader = loaders.train_dataloader(batch_size=32) - >>> len(train_loader.dataset) - 355 - >>> len(train_loader) - 11 - >>> # validation set - >>> val_loader = loaders.val_dataloader(batch_size=32) - >>> len(val_loader.dataset) - 100 - >>> len(val_loader) - 3 - >>> # test set - >>> test_loader = loaders.test_dataloader(batch_size=32) - >>> len(test_loader.dataset) - 51 - >>> len(test_loader) - 1 - - """ super().__init__(*args, **kwargs) self.num_workers = num_workers diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index 4d6dcc611e..df49164579 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -20,6 +20,37 @@ class STL10DataModule(LightningDataModule): # pragma: no cover + """ + .. figure:: https://samyzaf.com/ML/cifar10/cifar1.jpg + :width: 400 + :alt: STL-10 + + Specs: + - 10 classes (1 per type) + - Each image is (3 x 96 x 96) + + Standard STL-10, train, val, test splits and transforms. + STL-10 has support for doing validation splits on the labeled or unlabeled splits + + Transforms:: + + mnist_transforms = transform_lib.Compose([ + transform_lib.ToTensor(), + transforms.Normalize( + mean=(0.43, 0.42, 0.39), + std=(0.27, 0.26, 0.27) + ) + ]) + + Example:: + + from pl_bolts.datamodules import STL10DataModule + + dm = STL10DataModule(PATH) + model = LitModel() + + Trainer().fit(model, dm) + """ name = 'stl10' @@ -35,36 +66,6 @@ def __init__( **kwargs, ): """ - .. figure:: https://samyzaf.com/ML/cifar10/cifar1.jpg - :width: 400 - :alt: STL-10 - - Specs: - - 10 classes (1 per type) - - Each image is (3 x 96 x 96) - - Standard STL-10, train, val, test splits and transforms. - STL-10 has support for doing validation splits on the labeled or unlabeled splits - - Transforms:: - - mnist_transforms = transform_lib.Compose([ - transform_lib.ToTensor(), - transforms.Normalize( - mean=(0.43, 0.42, 0.39), - std=(0.27, 0.26, 0.27) - ) - ]) - - Example:: - - from pl_bolts.datamodules import STL10DataModule - - dm = STL10DataModule(PATH) - model = LitModel() - - Trainer().fit(model, dm) - Args: data_dir: where to save/load the data unlabeled_val_split: how many images from the unlabeled training split to use for validation diff --git a/pl_bolts/datamodules/vocdetection_datamodule.py b/pl_bolts/datamodules/vocdetection_datamodule.py index c17d5fedfe..580aca09f2 100644 --- a/pl_bolts/datamodules/vocdetection_datamodule.py +++ b/pl_bolts/datamodules/vocdetection_datamodule.py @@ -103,6 +103,10 @@ def _prepare_voc_instance(image, target): class VOCDetectionDataModule(LightningDataModule): + """ + TODO(teddykoker) docstring + """ + name = "vocdetection" def __init__( @@ -114,9 +118,6 @@ def __init__( *args, **kwargs, ): - """ - TODO(teddykoker) docstring - """ super().__init__(*args, **kwargs) if not _TORCHVISION_AVAILABLE: diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index a1b808b219..51af47530e 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -165,18 +165,6 @@ class TrialCIFAR10(CIFAR10): """ Customized `CIFAR10 `_ dataset for testing Pytorch Lightning without the torchvision dependency. - - Args: - data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt`` - and ``CIFAR10/processed/test.pt`` exist. - train: If ``True``, creates dataset from ``training.pt``, - otherwise from ``test.pt``. - download: If true, downloads the dataset from the internet and - puts it in root directory. If dataset is already downloaded, it is not - downloaded again. - num_samples: number of examples per selected class/digit - labels: list selected CIFAR10 digits/classes - Examples: >>> dataset = TrialCIFAR10(download=True, num_samples=150, labels=(1, 5, 8)) @@ -200,6 +188,18 @@ def __init__( labels: Optional[Sequence] = (1, 5, 8), relabel: bool = True, ): + """ + Args: + data_dir: Root directory of dataset where ``CIFAR10/processed/training.pt`` + and ``CIFAR10/processed/test.pt`` exist. + train: If ``True``, creates dataset from ``training.pt``, + otherwise from ``test.pt``. + download: If true, downloads the dataset from the internet and + puts it in root directory. If dataset is already downloaded, it is not + downloaded again. + num_samples: number of examples per selected class/digit + labels: list selected CIFAR10 digits/classes + """ # number of examples per class self.num_samples = num_samples # take just a subset of CIFAR dataset diff --git a/pl_bolts/datasets/dummy_dataset.py b/pl_bolts/datasets/dummy_dataset.py index 44b728422e..30221156bb 100644 --- a/pl_bolts/datasets/dummy_dataset.py +++ b/pl_bolts/datasets/dummy_dataset.py @@ -6,10 +6,6 @@ class DummyDataset(Dataset): """ Generate a dummy dataset - Args: - *shapes: list of shapes - num_samples: how many samples to use in this dataset - Example:: from pl_bolts.datasets import DummyDataset @@ -26,6 +22,11 @@ class DummyDataset(Dataset): torch.Size([7, 1]) """ def __init__(self, *shapes, num_samples: int = 10000): + """ + Args: + *shapes: list of shapes + num_samples: how many samples to use in this dataset + """ super().__init__() self.shapes = shapes self.num_samples = num_samples @@ -45,10 +46,6 @@ class DummyDetectionDataset(Dataset): """ Generate a dummy dataset for detection - Args: - *shapes: list of shapes - num_samples: how many samples to use in this dataset - Example:: from pl_bolts.datasets import DummyDetectionDataset @@ -59,6 +56,11 @@ class DummyDetectionDataset(Dataset): def __init__( self, img_shape: tuple = (3, 256, 256), num_boxes: int = 1, num_classes: int = 2, num_samples: int = 10000 ): + """ + Args: + *shapes: list of shapes + num_samples: how many samples to use in this dataset + """ super().__init__() self.img_shape = img_shape self.num_samples = num_samples @@ -85,10 +87,6 @@ class RandomDictDataset(Dataset): """ Generate a dummy dataset with a dict structure - Args: - size: tuple - num_samples: number of samples - Example:: from pl_bolts.datasets import RandomDictDataset @@ -97,6 +95,11 @@ class RandomDictDataset(Dataset): >>> dl = DataLoader(ds, batch_size=7) """ def __init__(self, size: int, num_samples: int = 250): + """ + Args: + size: tuple + num_samples: number of samples + """ self.len = num_samples self.data = torch.randn(num_samples, size) @@ -113,10 +116,6 @@ class RandomDictStringDataset(Dataset): """ Generate a dummy dataset with strings - Args: - size: tuple - num_samples: number of samples - Example:: from pl_bolts.datasets import RandomDictStringDataset @@ -125,6 +124,11 @@ class RandomDictStringDataset(Dataset): >>> dl = DataLoader(ds, batch_size=7) """ def __init__(self, size: int, num_samples: int = 250): + """ + Args: + size: tuple + num_samples: number of samples + """ self.len = num_samples self.data = torch.randn(num_samples, size) @@ -139,10 +143,6 @@ class RandomDataset(Dataset): """ Generate a dummy dataset - Args: - size: tuple - num_samples: number of samples - Example:: from pl_bolts.datasets import RandomDataset @@ -151,6 +151,11 @@ class RandomDataset(Dataset): >>> dl = DataLoader(ds, batch_size=7) """ def __init__(self, size: int, num_samples: int = 250): + """ + Args: + size: tuple + num_samples: number of samples + """ self.len = num_samples self.data = torch.randn(num_samples, size) diff --git a/pl_bolts/datasets/imagenet_dataset.py b/pl_bolts/datasets/imagenet_dataset.py index 4a6cbd6677..9f61c6ce77 100644 --- a/pl_bolts/datasets/imagenet_dataset.py +++ b/pl_bolts/datasets/imagenet_dataset.py @@ -27,6 +27,12 @@ class UnlabeledImagenet(ImageNet): + """ + Official train set gets split into train, val. (using nb_imgs_per_val_class for each class). + Official validation becomes test set + + Within each class, we further allow limiting the number of samples per class (for semi-sup lng) + """ def __init__( self, @@ -39,11 +45,6 @@ def __init__( **kwargs, ): """ - Official train set gets split into train, val. (using nb_imgs_per_val_class for each class). - Official validation becomes test set - - Within each class, we further allow limiting the number of samples per class (for semi-sup lng) - Args: root: path of dataset split: diff --git a/pl_bolts/datasets/kitti_dataset.py b/pl_bolts/datasets/kitti_dataset.py index f887eb82b8..4fd6e4bdc5 100644 --- a/pl_bolts/datasets/kitti_dataset.py +++ b/pl_bolts/datasets/kitti_dataset.py @@ -21,12 +21,6 @@ class KittiDataset(Dataset): (250 by default). It also sets all of the valid pixels to the appropriate value between 0 and `len(valid_labels)` (since that is the number of valid classes), so it can be used properly by the loss function when comparing with the output. - - Args: - data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' - img_size: image dimensions (width, height) - void_labels: useless classes to be excluded from training - valid_labels: useful classes to include """ IMAGE_PATH = os.path.join('training', 'image_2') MASK_PATH = os.path.join('training', 'semantic') @@ -39,6 +33,13 @@ def __init__( valid_labels: list = DEFAULT_VALID_LABELS, transform=None ): + """ + Args: + data_dir (str): where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' + img_size: image dimensions (width, height) + void_labels: useless classes to be excluded from training + valid_labels: useful classes to include + """ self.img_size = img_size self.void_labels = void_labels self.valid_labels = valid_labels diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index 8fdcba8b62..d9bb17942b 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -23,10 +23,6 @@ class SSLDatasetMixin(ABC): def generate_train_val_split(cls, examples, labels, pct_val): """ Splits dataset uniformly across classes - :param examples: - :param labels: - :param pct_val: - :return: """ nb_classes = len(set(labels)) @@ -58,10 +54,6 @@ def select_nb_imgs_per_class(cls, examples, labels, nb_imgs_in_val): """ Splits a dataset into two parts. The labeled split has nb_imgs_in_val per class - :param examples: - :param labels: - :param nb_imgs_in_val: - :return: """ nb_classes = len(set(labels)) diff --git a/pl_bolts/losses/self_supervised_learning.py b/pl_bolts/losses/self_supervised_learning.py index e6348e7a05..591c755f4e 100644 --- a/pl_bolts/losses/self_supervised_learning.py +++ b/pl_bolts/losses/self_supervised_learning.py @@ -92,22 +92,24 @@ def forward(self, Z): class AmdimNCELoss(nn.Module): + """ + Compute the NCE scores for predicting r_src->r_trg. + """ def __init__(self, tclip): super().__init__() self.tclip = tclip def forward(self, anchor_representations, positive_representations, mask_mat): """ - Compute the NCE scores for predicting r_src->r_trg. Args: - anchor_representations : (batch_size, emb_dim) - positive_representations : (emb_dim, n_batch * w* h) (ie: nb_feat_vectors x embedding_dim) - mask_mat : (n_batch_gpu, n_batch) + anchor_representations: (batch_size, emb_dim) + positive_representations: (emb_dim, n_batch * w* h) (ie: nb_feat_vectors x embedding_dim) + mask_mat: (n_batch_gpu, n_batch) Output: - raw_scores : (n_batch_gpu, n_locs) - nce_scores : (n_batch_gpu, n_locs) - lgt_reg : scalar + raw_scores: (n_batch_gpu, n_locs) + nce_scores: (n_batch_gpu, n_locs) + lgt_reg : scalar """ r_src = anchor_representations r_trg = positive_representations @@ -140,11 +142,11 @@ def forward(self, anchor_representations, positive_representations, mask_mat): # trick 2: tanh clip raw_scores = tanh_clip(raw_scores, clip_val=self.tclip) - ''' + """ pos_scores includes scores for all the positive samples neg_scores includes scores for all the negative samples, with scores for positive samples set to the min score (-self.tclip here) - ''' + """ # ---------------------- # EXTRACT POSITIVE SCORES # use the index mask to pull all the diagonals which are b1 x b1 @@ -186,44 +188,44 @@ def forward(self, anchor_representations, positive_representations, mask_mat): class FeatureMapContrastiveTask(nn.Module): + """ + Performs an anchor, positive negative pair comparison for each each tuple of feature maps passed. - def __init__(self, comparisons: str = '00, 11', tclip: float = 10.0, bidirectional: bool = True): - """ - Performs an anchor, positive negative pair comparison for each each tuple of feature maps passed. + .. code-block:: python + + # extract feature maps + pos_0, pos_1, pos_2 = encoder(x_pos) + anc_0, anc_1, anc_2 = encoder(x_anchor) - .. code-block:: python + # compare only the 0th feature maps + task = FeatureMapContrastiveTask('00') + loss, regularizer = task((pos_0), (anc_0)) - # extract feature maps - pos_0, pos_1, pos_2 = encoder(x_pos) - anc_0, anc_1, anc_2 = encoder(x_anchor) + # compare (pos_0 to anc_1) and (pos_0, anc_2) + task = FeatureMapContrastiveTask('01, 02') + losses, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2)) + loss = losses.sum() - # compare only the 0th feature maps - task = FeatureMapContrastiveTask('00') - loss, regularizer = task((pos_0), (anc_0)) + # compare (pos_1 vs a anc_random) + task = FeatureMapContrastiveTask('0r') + loss, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2)) - # compare (pos_0 to anc_1) and (pos_0, anc_2) - task = FeatureMapContrastiveTask('01, 02') - losses, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2)) - loss = losses.sum() + .. code-block:: python - # compare (pos_1 vs a anc_random) - task = FeatureMapContrastiveTask('0r') - loss, regularizer = task((pos_0, pos_1, pos_2), (anc_0, anc_1, anc_2)) + # with bidirectional the comparisons are done both ways + task = FeatureMapContrastiveTask('01, 02') + # will compare the following: + # 01: (pos_0, anc_1), (anc_0, pos_1) + # 02: (pos_0, anc_2), (anc_0, pos_2) + """ + + def __init__(self, comparisons: str = '00, 11', tclip: float = 10.0, bidirectional: bool = True): + """ Args: comparisons: groupings of feature map indices to compare (zero indexed, 'r' means random) ex: '00, 1r' tclip: stability clipping value bidirectional: if true, does the comparison both ways - - .. code-block:: python - - # with bidirectional the comparisons are done both ways - task = FeatureMapContrastiveTask('01, 02') - - # will compare the following: - # 01: (pos_0, anc_1), (anc_0, pos_1) - # 02: (pos_0, anc_2), (anc_0, pos_2) - """ super().__init__() self.tclip = tclip diff --git a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py index 66c3a3a113..d98c7d1416 100644 --- a/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py +++ b/pl_bolts/models/autoencoders/basic_ae/basic_ae_module.py @@ -10,6 +10,19 @@ class AE(pl.LightningModule): + """ + Standard AE + + Model is available pretrained on different datasets: + + Example:: + + # not pretrained + ae = AE() + + # pretrained on cifar10 + ae = AE.from_pretrained('cifar10-resnet18') + """ pretrained_urls = { 'cifar10-resnet18': @@ -29,20 +42,7 @@ def __init__( **kwargs ): """ - Standard AE - - Model is available pretrained on different datasets: - - Example:: - - # not pretrained - ae = AE() - - # pretrained on cifar10 - ae = AE.from_pretrained('cifar10-resnet18') - Args: - input_height: height of the images enc_type: option between resnet18 or resnet50 first_conv: use standard kernel_size 7, stride 2 at start or diff --git a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py index ab6671d000..dc5768a28e 100644 --- a/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py +++ b/pl_bolts/models/autoencoders/basic_vae/basic_vae_module.py @@ -10,6 +10,22 @@ class VAE(pl.LightningModule): + """ + Standard VAE with Gaussian Prior and approx posterior. + + Model is available pretrained on different datasets: + + Example:: + + # not pretrained + vae = VAE() + + # pretrained on cifar10 + vae = VAE.from_pretrained('cifar10-resnet18') + + # pretrained on stl10 + vae = VAE.from_pretrained('stl10-resnet18') + """ pretrained_urls = { 'cifar10-resnet18': @@ -31,23 +47,7 @@ def __init__( **kwargs ): """ - Standard VAE with Gaussian Prior and approx posterior. - - Model is available pretrained on different datasets: - - Example:: - - # not pretrained - vae = VAE() - - # pretrained on cifar10 - vae = VAE.from_pretrained('cifar10-resnet18') - - # pretrained on stl10 - vae = VAE.from_pretrained('stl10-resnet18') - Args: - input_height: height of the images enc_type: option between resnet18 or resnet50 first_conv: use standard kernel_size 7, stride 2 at start or diff --git a/pl_bolts/models/detection/faster_rcnn.py b/pl_bolts/models/detection/faster_rcnn.py index 4a6173f982..c6fbb819e8 100644 --- a/pl_bolts/models/detection/faster_rcnn.py +++ b/pl_bolts/models/detection/faster_rcnn.py @@ -24,6 +24,24 @@ def _evaluate_iou(target, pred): class FasterRCNN(pl.LightningModule): + """ + PyTorch Lightning implementation of `Faster R-CNN: Towards Real-Time Object Detection with + Region Proposal Networks `_. + + Paper authors: Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun + + Model implemented by: + - `Teddy Koker ` + + During training, the model expects both the input tensors, as well as targets (list of dictionary), containing: + - boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format. + - labels (`Int64Tensor[N]`): the class label for each ground truh box + + CLI command:: + + # PascalVOC + python faster_rcnn.py --gpus 1 --pretrained True + """ def __init__( self, learning_rate: float = 0.0001, @@ -35,23 +53,6 @@ def __init__( **kwargs, ): """ - PyTorch Lightning implementation of `Faster R-CNN: Towards Real-Time Object Detection with - Region Proposal Networks `_. - - Paper authors: Shaoqing Ren, Kaiming He, Ross Girshick, Jian Sun - - Model implemented by: - - `Teddy Koker ` - - During training, the model expects both the input tensors, as well as targets (list of dictionary), containing: - - boxes (`FloatTensor[N, 4]`): the ground truth boxes in `[x1, y1, x2, y2]` format. - - labels (`Int64Tensor[N]`): the class label for each ground truh box - - CLI command:: - - # PascalVOC - python faster_rcnn.py --gpus 1 --pretrained True - Args: learning_rate: the learning rate num_classes: number of detection classes (including background) diff --git a/pl_bolts/models/gans/basic/basic_gan_module.py b/pl_bolts/models/gans/basic/basic_gan_module.py index 7311cb260b..365c5d0fc8 100644 --- a/pl_bolts/models/gans/basic/basic_gan_module.py +++ b/pl_bolts/models/gans/basic/basic_gan_module.py @@ -8,6 +8,26 @@ class GAN(pl.LightningModule): + """ + Vanilla GAN implementation. + + Example:: + + from pl_bolts.models.gan import GAN + + m = GAN() + Trainer(gpus=2).fit(m) + + Example CLI:: + + # mnist + python basic_gan_module.py --gpus 1 + + # imagenet + python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' + --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder + --batch_size 256 --learning_rate 0.0001 + """ def __init__( self, @@ -19,34 +39,13 @@ def __init__( **kwargs ): """ - Vanilla GAN implementation. - - Example:: - - from pl_bolts.models.gan import GAN - - m = GAN() - Trainer(gpus=2).fit(m) - - Example CLI:: - - # mnist - python basic_gan_module.py --gpus 1 - - # imagenet - python basic_gan_module.py --gpus 1 --dataset 'imagenet2012' - --data_dir /path/to/imagenet/folder/ --meta_dir ~/path/to/meta/bin/folder - --batch_size 256 --learning_rate 0.0001 - Args: - datamodule: the datamodule (train, val, test splits) latent_dim: emb dim for encoder batch_size: the batch size learning_rate: the learning rate data_dir: where to store data num_workers: data workers - """ super().__init__() diff --git a/pl_bolts/models/regression/linear_regression.py b/pl_bolts/models/regression/linear_regression.py index 241ff9f33e..33012d56ef 100644 --- a/pl_bolts/models/regression/linear_regression.py +++ b/pl_bolts/models/regression/linear_regression.py @@ -9,6 +9,10 @@ class LinearRegression(pl.LightningModule): + """ + Linear regression model implementing - with optional L1/L2 regularization + $$min_{W} ||(Wx + b) - y ||_2^2 $$ + """ def __init__(self, input_dim: int, @@ -20,9 +24,6 @@ def __init__(self, l2_strength: float = 0.0, **kwargs): """ - Linear regression model implementing - with optional L1/L2 regularization - $$min_{W} ||(Wx + b) - y ||_2^2 $$ - Args: input_dim: number of dimensions of the input (1+) output_dim: number of dimensions of the output (default=1) @@ -31,7 +32,6 @@ def __init__(self, optimizer: the optimizer to use (default='Adam') l1_strength: L1 regularization strength (default=None) l2_strength: L2 regularization strength (default=None) - """ super().__init__() self.save_hyperparameters() diff --git a/pl_bolts/models/regression/logistic_regression.py b/pl_bolts/models/regression/logistic_regression.py index d39949c5a1..e4ee2cd3bd 100644 --- a/pl_bolts/models/regression/logistic_regression.py +++ b/pl_bolts/models/regression/logistic_regression.py @@ -10,6 +10,9 @@ class LogisticRegression(pl.LightningModule): + """ + Logistic regression model + """ def __init__(self, input_dim: int, @@ -21,8 +24,6 @@ def __init__(self, l2_strength: float = 0.0, **kwargs): """ - Logistic regression model - Args: input_dim: number of dimensions of the input (at least 1) num_classes: number of class labels (binary: 2, multi-class: >2) @@ -31,7 +32,6 @@ def __init__(self, optimizer: the optimizer to use (default='Adam') l1_strength: L1 regularization strength (default=None) l2_strength: L2 regularization strength (default=None) - """ super().__init__() self.save_hyperparameters() diff --git a/pl_bolts/models/rl/common/memory.py b/pl_bolts/models/rl/common/memory.py index 0cd058ee43..2bc704ee4d 100644 --- a/pl_bolts/models/rl/common/memory.py +++ b/pl_bolts/models/rl/common/memory.py @@ -15,11 +15,13 @@ class Buffer: """ Basic Buffer for storing a single experience at a time - Args: - capacity: size of the buffer """ def __init__(self, capacity: int) -> None: + """ + Args: + capacity: size of the buffer + """ self.buffer = deque(maxlen=capacity) def __len__(self) -> None: @@ -86,14 +88,15 @@ def sample(self, batch_size: int) -> Tuple: class MultiStepBuffer(ReplayBuffer): """ N Step Replay Buffer - - Args: - capacity: max number of experiences that will be stored in the buffer - n_steps: number of steps used for calculating discounted reward/experience - gamma: discount factor when calculating n_step discounted reward of the experience being stored in buffer """ def __init__(self, capacity: int, n_steps: int = 1, gamma: float = 0.99) -> None: + """ + Args: + capacity: max number of experiences that will be stored in the buffer + n_steps: number of steps used for calculating discounted reward/experience + gamma: discount factor when calculating n_step discounted reward of the experience being stored in buffer + """ super().__init__(capacity) self.n_steps = n_steps diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 4776424d39..1f7398055e 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -14,12 +14,14 @@ class CNN(nn.Module): """ Simple MLP network - Args: - input_shape: observation shape of the environment - n_actions: number of discrete actions available in the environment """ def __init__(self, input_shape, n_actions): + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + """ super(CNN, self).__init__() self.conv = nn.Sequential( @@ -62,13 +64,15 @@ def forward(self, input_x) -> Tensor: class MLP(nn.Module): """ Simple MLP network - Args: - input_shape: observation shape of the environment - n_actions: number of discrete actions available in the environment - hidden_size: size of hidden layers """ def __init__(self, input_shape: Tuple, n_actions: int, hidden_size: int = 128): + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ super(MLP, self).__init__() self.net = nn.Sequential( nn.Linear(input_shape[0], hidden_size), @@ -90,13 +94,15 @@ def forward(self, input_x): class DuelingMLP(nn.Module): """ MLP network with duel heads for val and advantage - Args: - input_shape: observation shape of the environment - n_actions: number of discrete actions available in the environment - hidden_size: size of hidden layers """ def __init__(self, input_shape: Tuple, n_actions: int, hidden_size: int = 128): + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ super(DuelingMLP, self).__init__() self.net = nn.Sequential( @@ -143,14 +149,15 @@ def adv_val(self, input_x) -> Tuple[Tensor, Tensor]: class DuelingCNN(nn.Module): """ CNN network with duel heads for val and advantage - Args: - input_shape: observation shape of the environment - n_actions: number of discrete actions available in the environment - hidden_size: size of hidden layers """ def __init__(self, input_shape: Tuple, n_actions: int, _: int = 128): - + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + hidden_size: size of hidden layers + """ super().__init__() self.conv = nn.Sequential( @@ -214,12 +221,14 @@ def adv_val(self, input_x): class NoisyCNN(nn.Module): """ CNN with Noisy Linear layers for exploration - Args: - input_shape: observation shape of the environment - n_actions: number of discrete actions available in the environment """ def __init__(self, input_shape, n_actions): + """ + Args: + input_shape: observation shape of the environment + n_actions: number of discrete actions available in the environment + """ super().__init__() self.conv = nn.Sequential( @@ -269,14 +278,16 @@ class NoisyLinear(nn.Linear): Noisy Layer using Independent Gaussian Noise. based on https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/ Chapter08/lib/dqn_extra.py#L19 - Args: - in_features: number of inputs - out_features: number of outputs - sigma_init: initial fill value of noisy weights - bias: flag to include bias to linear layer """ def __init__(self, in_features, out_features, sigma_init=0.017, bias=True): + """ + Args: + in_features: number of inputs + out_features: number of outputs + sigma_init: initial fill value of noisy weights + bias: flag to include bias to linear layer + """ super(NoisyLinear, self).__init__(in_features, out_features, bias=bias) weights = torch.full((out_features, in_features), sigma_init) diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 64bf6e7ce7..f7ab1fc40b 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -33,7 +33,33 @@ class DQN(pl.LightningModule): - """ Basic DQN Model """ + """ + Basic DQN Model + + PyTorch Lightning implementation of `DQN `_ + Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, + Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller. + Model implemented by: + + - `Donal Byrne ` + + Example: + >>> from pl_bolts.models.rl.dqn_model import DQN + ... + >>> model = DQN("PongNoFrameskip-v4") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Note: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter06/02_dqn_pong.py + + Note: + Currently only supports CPU and single GPU training with `distributed_backend=dp` + """ def __init__( self, @@ -55,23 +81,6 @@ def __init__( **kwargs, ): """ - PyTorch Lightning implementation of `DQN `_ - Paper authors: Volodymyr Mnih, Koray Kavukcuoglu, David Silver, Alex Graves, - Ioannis Antonoglou, Daan Wierstra, Martin Riedmiller. - Model implemented by: - - - `Donal Byrne ` - - Example: - >>> from pl_bolts.models.rl.dqn_model import DQN - ... - >>> model = DQN("PongNoFrameskip-v4") - - Train:: - - trainer = Trainer() - trainer.fit(model) - Args: env: gym environment tag eps_start: starting value of epsilon for the epsilon-greedy exploration @@ -90,13 +99,6 @@ def __init__( seed: seed value for all RNG used batches_per_epoch: number of batches per epoch n_steps: size of n step look ahead - - Note: - This example is based on: - https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter06/02_dqn_pong.py - - Note: - Currently only supports CPU and single GPU training with `distributed_backend=dp` """ super().__init__() diff --git a/pl_bolts/models/rl/dueling_dqn_model.py b/pl_bolts/models/rl/dueling_dqn_model.py index 79afca2fc7..0c362fafa9 100644 --- a/pl_bolts/models/rl/dueling_dqn_model.py +++ b/pl_bolts/models/rl/dueling_dqn_model.py @@ -11,43 +11,43 @@ class DuelingDQN(DQN): """ - PyTorch Lightning implementation of `Dueling DQN `_ + PyTorch Lightning implementation of `Dueling DQN `_ - Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas + Paper authors: Ziyu Wang, Tom Schaul, Matteo Hessel, Hado van Hasselt, Marc Lanctot, Nando de Freitas - Model implemented by: + Model implemented by: - - `Donal Byrne ` + - `Donal Byrne ` - Example: + Example: - >>> from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN - ... - >>> model = DuelingDQN("PongNoFrameskip-v4") + >>> from pl_bolts.models.rl.dueling_dqn_model import DuelingDQN + ... + >>> model = DuelingDQN("PongNoFrameskip-v4") - Train:: + Train:: - trainer = Trainer() - trainer.fit(model) + trainer = Trainer() + trainer.fit(model) - Args: - env: gym environment tag - gpus: number of gpus being used - eps_start: starting value of epsilon for the epsilon-greedy exploration - eps_end: final value of epsilon for the epsilon-greedy exploration - eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end - sync_rate: the number of iterations between syncing up the target network with the train network - gamma: discount factor - lr: learning rate - batch_size: size of minibatch pulled from the DataLoader - replay_size: total capacity of the replay buffer - warm_start_size: how many random steps through the environment to be carried out at the start of - training to fill the buffer with a starting point - sample_len: the number of samples to pull from the dataset iterator and feed to the DataLoader + Args: + env: gym environment tag + gpus: number of gpus being used + eps_start: starting value of epsilon for the epsilon-greedy exploration + eps_end: final value of epsilon for the epsilon-greedy exploration + eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end + sync_rate: the number of iterations between syncing up the target network with the train network + gamma: discount factor + lr: learning rate + batch_size: size of minibatch pulled from the DataLoader + replay_size: total capacity of the replay buffer + warm_start_size: how many random steps through the environment to be carried out at the start of + training to fill the buffer with a starting point + sample_len: the number of samples to pull from the dataset iterator and feed to the DataLoader - .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` + .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` - """ + """ def build_networks(self) -> None: """Initializes the Dueling DQN train and target networks""" diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index 07ad80d564..9ffbae3cdd 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -37,26 +37,26 @@ class PERDQN(DQN): trainer = Trainer() trainer.fit(model) - Args: - env: gym environment tag - gpus: number of gpus being used - eps_start: starting value of epsilon for the epsilon-greedy exploration - eps_end: final value of epsilon for the epsilon-greedy exploration - eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end - sync_rate: the number of iterations between syncing up the target network with the train network - gamma: discount factor - learning_rate: learning rate - batch_size: size of minibatch pulled from the DataLoader - replay_size: total capacity of the replay buffer - warm_start_size: how many random steps through the environment to be carried out at the start of - training to fill the buffer with a starting point - num_samples: the number of samples to pull from the dataset iterator and feed to the DataLoader - - .. note:: - This example is based on: - https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter08/05_dqn_prio_replay.py - - .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` + Args: + env: gym environment tag + gpus: number of gpus being used + eps_start: starting value of epsilon for the epsilon-greedy exploration + eps_end: final value of epsilon for the epsilon-greedy exploration + eps_last_frame: the final frame in for the decrease of epsilon. At this frame espilon = eps_end + sync_rate: the number of iterations between syncing up the target network with the train network + gamma: discount factor + learning_rate: learning rate + batch_size: size of minibatch pulled from the DataLoader + replay_size: total capacity of the replay buffer + warm_start_size: how many random steps through the environment to be carried out at the start of + training to fill the buffer with a starting point + num_samples: the number of samples to pull from the dataset iterator and feed to the DataLoader + + .. note:: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter08/05_dqn_prio_replay.py + + .. note:: Currently only supports CPU and single GPU training with `distributed_backend=dp` """ diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index 0308067ccd..78a5fadd43 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -28,6 +28,32 @@ class Reinforce(pl.LightningModule): + """ + PyTorch Lightning implementation of `REINFORCE + `_ + Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour + Model implemented by: + + - `Donal Byrne ` + + Example: + >>> from pl_bolts.models.rl.reinforce_model import Reinforce + ... + >>> model = Reinforce("CartPole-v0") + + Train:: + + trainer = Trainer() + trainer.fit(model) + + Note: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/02_cartpole_reinforce.py + + Note: + Currently only supports CPU and single GPU training with `distributed_backend=dp` + """ def __init__( self, env: str, @@ -42,24 +68,6 @@ def __init__( **kwargs ) -> None: """ - PyTorch Lightning implementation of `REINFORCE - `_ - Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour - Model implemented by: - - - `Donal Byrne ` - - Example: - >>> from pl_bolts.models.rl.reinforce_model import Reinforce - ... - >>> model = Reinforce("CartPole-v0") - - Train:: - - trainer = Trainer() - trainer.fit(model) - Args: env: gym environment tag gamma: discount factor @@ -70,13 +78,6 @@ def __init__( epoch_len: how many batches before pseudo epoch num_batch_episodes: how many episodes to rollout for each batch of training avg_reward_len: how many episodes to take into account when calculating the avg reward - - Note: - This example is based on: - https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/02_cartpole_reinforce.py - - Note: - Currently only supports CPU and single GPU training with `distributed_backend=dp` """ super().__init__() diff --git a/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/pl_bolts/models/rl/vanilla_policy_gradient_model.py index c318d25917..89d9df8cd5 100644 --- a/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -27,6 +27,31 @@ class VanillaPolicyGradient(pl.LightningModule): + """ + PyTorch Lightning implementation of `Vanilla Policy Gradient + `_ + Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour + Model implemented by: + + - `Donal Byrne ` + + Example: + >>> from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient + ... + >>> model = VanillaPolicyGradient("CartPole-v0") + + Train:: + trainer = Trainer() + trainer.fit(model) + + Note: + This example is based on: + https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/04_cartpole_pg.py + + Note: + Currently only supports CPU and single GPU training with `distributed_backend=dp` + """ def __init__( self, env: str, @@ -40,23 +65,6 @@ def __init__( **kwargs ) -> None: """ - PyTorch Lightning implementation of `Vanilla Policy Gradient - `_ - Paper authors: Richard S. Sutton, David McAllester, Satinder Singh, Yishay Mansour - Model implemented by: - - - `Donal Byrne ` - - Example: - >>> from pl_bolts.models.rl.vanilla_policy_gradient_model import VanillaPolicyGradient - ... - >>> model = VanillaPolicyGradient("CartPole-v0") - - Train:: - trainer = Trainer() - trainer.fit(model) - Args: env: gym environment tag gamma: discount factor @@ -65,13 +73,6 @@ def __init__( batch_episodes: how many episodes to rollout for each batch of training entropy_beta: dictates the level of entropy per batch avg_reward_len: how many episodes to take into account when calculating the avg reward - - Note: - This example is based on: - https://github.com/PacktPublishing/Deep-Reinforcement-Learning-Hands-On-Second-Edition/blob/master/Chapter11/04_cartpole_pg.py - - Note: - Currently only supports CPU and single GPU training with `distributed_backend=dp` """ super().__init__() diff --git a/pl_bolts/models/self_supervised/amdim/amdim_module.py b/pl_bolts/models/self_supervised/amdim/amdim_module.py index c098801907..db851288f8 100644 --- a/pl_bolts/models/self_supervised/amdim/amdim_module.py +++ b/pl_bolts/models/self_supervised/amdim/amdim_module.py @@ -17,6 +17,28 @@ class AMDIM(pl.LightningModule): + """ + PyTorch Lightning implementation of + `Augmented Multiscale Deep InfoMax (AMDIM) `_ + + Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter. + + Model implemented by: `William Falcon `_ + + This code is adapted to Lightning using the original author repo + (`the original repo `_). + + Example: + + >>> from pl_bolts.models.self_supervised import AMDIM + ... + >>> model = AMDIM(encoder='resnet18') + + Train:: + + trainer = Trainer() + trainer.fit(model) + """ def __init__( self, @@ -37,27 +59,6 @@ def __init__( **kwargs, ): """ - PyTorch Lightning implementation of - `Augmented Multiscale Deep InfoMax (AMDIM) `_ - - Paper authors: Philip Bachman, R Devon Hjelm, William Buchwalter. - - Model implemented by: `William Falcon `_ - - This code is adapted to Lightning using the original author repo - (`the original repo `_). - - Example: - - >>> from pl_bolts.models.self_supervised import AMDIM - ... - >>> model = AMDIM(encoder='resnet18') - - Train:: - - trainer = Trainer() - trainer.fit(model) - Args: datamodule: A LightningDatamodule encoder: an encoder string or model diff --git a/pl_bolts/models/self_supervised/amdim/networks.py b/pl_bolts/models/self_supervised/amdim/networks.py index fee1bc6e06..66a7fe82c2 100644 --- a/pl_bolts/models/self_supervised/amdim/networks.py +++ b/pl_bolts/models/self_supervised/amdim/networks.py @@ -70,9 +70,9 @@ def __init__(self, dummy_batch, num_channels=3, encoder_feature_dim=64, embeddin ) def init_weights(self, init_scale=1.): - ''' + """ Run custom weight init for modules... - ''' + """ for layer in self.layer_list: if isinstance(layer, (ConvResNxN, ConvResBlock)): layer.init_weights(init_scale) @@ -83,9 +83,9 @@ def init_weights(self, init_scale=1.): layer.init_weights(init_scale) def _config_modules(self, x, output_widths, n_rkhs, use_bn): - ''' + """ Configure the modules for extracting fake rkhs embeddings for infomax. - ''' + """ # get activations from each block to see output dims enc_acts = self._forward_acts(x) @@ -110,9 +110,9 @@ def _config_modules(self, x, output_widths, n_rkhs, use_bn): self.rkhs_block_7 = FakeRKHSConvNet(ndf_7, n_rkhs, use_bn) def _forward_acts(self, x): - ''' + """ Return activations from all layers. - ''' + """ # run forward pass through all layers layer_acts = [x] for _, layer in enumerate(self.layer_list): @@ -177,9 +177,9 @@ def __init__(self, n_in, n_out, width, stride, pad, depth, use_bn): self.layer_list = nn.Sequential(*layer_list) def init_weights(self, init_scale=1.): - ''' + """ Do a fixup-ish init for each ConvResNxN in this block. - ''' + """ for m in self.layer_list: m.init_weights(init_scale) diff --git a/pl_bolts/models/self_supervised/amdim/transforms.py b/pl_bolts/models/self_supervised/amdim/transforms.py index bb146c338d..cd30aa3972 100644 --- a/pl_bolts/models/self_supervised/amdim/transforms.py +++ b/pl_bolts/models/self_supervised/amdim/transforms.py @@ -13,26 +13,27 @@ class AMDIMTrainTransformsCIFAR10: - def __init__(self): - """ - Transforms applied to AMDIM + """ + Transforms applied to AMDIM - Transforms:: + Transforms:: - img_jitter, - col_jitter, - rnd_gray, - transforms.ToTensor(), - normalize + img_jitter, + col_jitter, + rnd_gray, + transforms.ToTensor(), + normalize - Example:: + Example:: - x = torch.rand(5, 3, 32, 32) + x = torch.rand(5, 3, 32, 32) - transform = AMDIMTrainTransformsCIFAR10() - (view1, view2) = transform(x) + transform = AMDIMTrainTransformsCIFAR10() + (view1, view2) = transform(x) - """ + """ + + def __init__(self): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `transforms` from `torchvision` which is not installed yet.' @@ -64,22 +65,23 @@ def __call__(self, inp): class AMDIMEvalTransformsCIFAR10: - def __init__(self): - """ - Transforms applied to AMDIM + """ + Transforms applied to AMDIM - Transforms:: + Transforms:: - transforms.ToTensor(), - normalize + transforms.ToTensor(), + normalize - Example:: + Example:: - x = torch.rand(5, 3, 32, 32) + x = torch.rand(5, 3, 32, 32) - transform = AMDIMEvalTransformsCIFAR10() - (view1, view2) = transform(x) - """ + transform = AMDIMEvalTransformsCIFAR10() + (view1, view2) = transform(x) + """ + + def __init__(self): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `transforms` from `torchvision` which is not installed yet.' @@ -103,25 +105,26 @@ def __call__(self, inp): class AMDIMTrainTransformsSTL10: - def __init__(self, height=64): - """ - Transforms applied to AMDIM + """ + Transforms applied to AMDIM - Transforms:: + Transforms:: - img_jitter, - col_jitter, - rnd_gray, - transforms.ToTensor(), - normalize + img_jitter, + col_jitter, + rnd_gray, + transforms.ToTensor(), + normalize - Example:: + Example:: - x = torch.rand(5, 3, 64, 64) + x = torch.rand(5, 3, 64, 64) - transform = AMDIMTrainTransformsSTL10() - (view1, view2) = transform(x) - """ + transform = AMDIMTrainTransformsSTL10() + (view1, view2) = transform(x) + """ + + def __init__(self, height=64): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `transforms` from `torchvision` which is not installed yet.' @@ -152,24 +155,25 @@ def __call__(self, inp): class AMDIMEvalTransformsSTL10(object): - def __init__(self, height=64): - """ - Transforms applied to AMDIM + """ + Transforms applied to AMDIM - Transforms:: + Transforms:: - transforms.Resize(height + 6, interpolation=3), - transforms.CenterCrop(height), - transforms.ToTensor(), - normalize + transforms.Resize(height + 6, interpolation=3), + transforms.CenterCrop(height), + transforms.ToTensor(), + normalize + + Example:: - Example:: + x = torch.rand(5, 3, 64, 64) - x = torch.rand(5, 3, 64, 64) + transform = AMDIMTrainTransformsSTL10() + view1 = transform(x) + """ - transform = AMDIMTrainTransformsSTL10() - view1 = transform(x) - """ + def __init__(self, height=64): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `transforms` from `torchvision` which is not installed yet.' @@ -195,25 +199,26 @@ def __call__(self, inp): class AMDIMTrainTransformsImageNet128(object): - def __init__(self, height=128): - """ - Transforms applied to AMDIM + """ + Transforms applied to AMDIM - Transforms:: + Transforms:: - img_jitter, - col_jitter, - rnd_gray, - transforms.ToTensor(), - normalize + img_jitter, + col_jitter, + rnd_gray, + transforms.ToTensor(), + normalize + + Example:: - Example:: + x = torch.rand(5, 3, 128, 128) - x = torch.rand(5, 3, 128, 128) + transform = AMDIMTrainTransformsSTL10() + (view1, view2) = transform(x) + """ - transform = AMDIMTrainTransformsSTL10() - (view1, view2) = transform(x) - """ + def __init__(self, height=128): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `transforms` from `torchvision` which is not installed yet.' @@ -244,24 +249,25 @@ def __call__(self, inp): class AMDIMEvalTransformsImageNet128(object): - def __init__(self, height=128): - """ - Transforms applied to AMDIM + """ + Transforms applied to AMDIM - Transforms:: + Transforms:: - transforms.Resize(height + 6, interpolation=3), - transforms.CenterCrop(height), - transforms.ToTensor(), - normalize + transforms.Resize(height + 6, interpolation=3), + transforms.CenterCrop(height), + transforms.ToTensor(), + normalize - Example:: + Example:: - x = torch.rand(5, 3, 128, 128) + x = torch.rand(5, 3, 128, 128) - transform = AMDIMEvalTransformsImageNet128() - view1 = transform(x) - """ + transform = AMDIMEvalTransformsImageNet128() + view1 = transform(x) + """ + + def __init__(self, height=128): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover 'You want to use `transforms` from `torchvision` which is not installed yet.' diff --git a/pl_bolts/models/self_supervised/byol/byol_module.py b/pl_bolts/models/self_supervised/byol/byol_module.py index b983988ea7..ae3e644fd1 100644 --- a/pl_bolts/models/self_supervised/byol/byol_module.py +++ b/pl_bolts/models/self_supervised/byol/byol_module.py @@ -15,71 +15,72 @@ class BYOL(pl.LightningModule): - def __init__(self, - num_classes, - learning_rate: float = 0.2, - weight_decay: float = 1.5e-6, - input_height: int = 32, - batch_size: int = 32, - num_workers: int = 0, - warmup_epochs: int = 10, - max_epochs: int = 1000, - **kwargs): - """ - PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL) - `_ - - Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \ - Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \ - Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko. + """ + PyTorch Lightning implementation of `Bootstrap Your Own Latent (BYOL) + `_ - Model implemented by: - - `Annika Brundyn `_ + Paper authors: Jean-Bastien Grill, Florian Strub, Florent Altché, Corentin Tallec, Pierre H. Richemond, \ + Elena Buchatskaya, Carl Doersch, Bernardo Avila Pires, Zhaohan Daniel Guo, Mohammad Gheshlaghi Azar, \ + Bilal Piot, Koray Kavukcuoglu, Rémi Munos, Michal Valko. - .. warning:: Work in progress. This implementation is still being verified. + Model implemented by: + - `Annika Brundyn `_ - TODOs: - - verify on CIFAR-10 - - verify on STL-10 - - pre-train on imagenet + .. warning:: Work in progress. This implementation is still being verified. - Example:: + TODOs: + - verify on CIFAR-10 + - verify on STL-10 + - pre-train on imagenet - import pytorch_lightning as pl - from pl_bolts.models.self_supervised import BYOL - from pl_bolts.datamodules import CIFAR10DataModule - from pl_bolts.models.self_supervised.simclr.transforms import ( - SimCLREvalDataTransform, SimCLRTrainDataTransform) + Example:: - # model - model = BYOL(num_classes=10) + import pytorch_lightning as pl + from pl_bolts.models.self_supervised import BYOL + from pl_bolts.datamodules import CIFAR10DataModule + from pl_bolts.models.self_supervised.simclr.transforms import ( + SimCLREvalDataTransform, SimCLRTrainDataTransform) - # data - dm = CIFAR10DataModule(num_workers=0) - dm.train_transforms = SimCLRTrainDataTransform(32) - dm.val_transforms = SimCLREvalDataTransform(32) + # model + model = BYOL(num_classes=10) - trainer = pl.Trainer() - trainer.fit(model, dm) + # data + dm = CIFAR10DataModule(num_workers=0) + dm.train_transforms = SimCLRTrainDataTransform(32) + dm.val_transforms = SimCLREvalDataTransform(32) - Train:: + trainer = pl.Trainer() + trainer.fit(model, dm) - trainer = Trainer() - trainer.fit(model) + Train:: - CLI command:: + trainer = Trainer() + trainer.fit(model) - # cifar10 - python byol_module.py --gpus 1 + CLI command:: - # imagenet - python byol_module.py - --gpus 8 - --dataset imagenet2012 - --data_dir /path/to/imagenet/ - --meta_dir /path/to/folder/with/meta.bin/ - --batch_size 32 + # cifar10 + python byol_module.py --gpus 1 + # imagenet + python byol_module.py + --gpus 8 + --dataset imagenet2012 + --data_dir /path/to/imagenet/ + --meta_dir /path/to/folder/with/meta.bin/ + --batch_size 32 + """ + def __init__(self, + num_classes, + learning_rate: float = 0.2, + weight_decay: float = 1.5e-6, + input_height: int = 32, + batch_size: int = 32, + num_workers: int = 0, + warmup_epochs: int = 10, + max_epochs: int = 1000, + **kwargs): + """ Args: datamodule: The datamodule learning_rate: the learning rate diff --git a/pl_bolts/models/self_supervised/cpc/transforms.py b/pl_bolts/models/self_supervised/cpc/transforms.py index fc870a0251..706ad33765 100644 --- a/pl_bolts/models/self_supervised/cpc/transforms.py +++ b/pl_bolts/models/self_supervised/cpc/transforms.py @@ -13,35 +13,35 @@ class CPCTrainTransformsCIFAR10: + """ + Transforms used for CPC: - def __init__(self, patch_size=8, overlap=4): - """ - Transforms used for CPC: - - Args: - - patch_size: size of patches when cutting up the image into overlapping patches - overlap: how much to overlap patches + Transforms:: - Transforms:: + random_flip + img_jitter + col_jitter + rnd_gray + transforms.ToTensor() + normalize + Patchify(patch_size=patch_size, overlap_size=patch_size // 2) - random_flip - img_jitter - col_jitter - rnd_gray - transforms.ToTensor() - normalize - Patchify(patch_size=patch_size, overlap_size=patch_size // 2) + Example:: - Example:: + # in a regular dataset + CIFAR10(..., transforms=CPCTrainTransformsCIFAR10()) - # in a regular dataset - CIFAR10(..., transforms=CPCTrainTransformsCIFAR10()) + # in a DataModule + module = CIFAR10DataModule(PATH) + train_loader = module.train_dataloader(batch_size=32, transforms=CPCTrainTransformsCIFAR10()) - # in a DataModule - module = CIFAR10DataModule(PATH) - train_loader = module.train_dataloader(batch_size=32, transforms=CPCTrainTransformsCIFAR10()) + """ + def __init__(self, patch_size=8, overlap=4): + """ + Args: + patch_size: size of patches when cutting up the image into overlapping patches + overlap: how much to overlap patches """ if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -74,32 +74,32 @@ def __call__(self, inp): class CPCEvalTransformsCIFAR10: + """ + Transforms used for CPC: - def __init__(self, patch_size=8, overlap=4): - """ - Transforms used for CPC: + Transforms:: - Args: + random_flip + transforms.ToTensor() + normalize + Patchify(patch_size=patch_size, overlap_size=overlap) - patch_size: size of patches when cutting up the image into overlapping patches - overlap: how much to overlap patches + Example:: - Transforms:: - - random_flip - transforms.ToTensor() - normalize - Patchify(patch_size=patch_size, overlap_size=overlap) + # in a regular dataset + CIFAR10(..., transforms=CPCEvalTransformsCIFAR10()) - Example:: + # in a DataModule + module = CIFAR10DataModule(PATH) + train_loader = module.train_dataloader(batch_size=32, transforms=CPCEvalTransformsCIFAR10()) - # in a regular dataset - CIFAR10(..., transforms=CPCEvalTransformsCIFAR10()) - - # in a DataModule - module = CIFAR10DataModule(PATH) - train_loader = module.train_dataloader(batch_size=32, transforms=CPCEvalTransformsCIFAR10()) + """ + def __init__(self, patch_size=8, overlap=4): + """ + Args: + patch_size: size of patches when cutting up the image into overlapping patches + overlap: how much to overlap patches """ if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -126,36 +126,34 @@ def __call__(self, inp): class CPCTrainTransformsSTL10: + """ + Transforms used for CPC: - def __init__(self, patch_size=16, overlap=8): - """ - Transforms used for CPC: - - Args: - - patch_size: size of patches when cutting up the image into overlapping patches - overlap: how much to overlap patches - - Transforms:: + Transforms:: - random_flip - img_jitter - col_jitter - rnd_gray - transforms.ToTensor() - normalize - Patchify(patch_size=patch_size, overlap_size=patch_size // 2) + random_flip + img_jitter + col_jitter + rnd_gray + transforms.ToTensor() + normalize + Patchify(patch_size=patch_size, overlap_size=patch_size // 2) - Example:: + Example:: - # in a regular dataset - STL10(..., transforms=CPCTrainTransformsSTL10()) - - # in a DataModule - module = STL10DataModule(PATH) - train_loader = module.train_dataloader(batch_size=32, transforms=CPCTrainTransformsSTL10()) + # in a regular dataset + STL10(..., transforms=CPCTrainTransformsSTL10()) + # in a DataModule + module = STL10DataModule(PATH) + train_loader = module.train_dataloader(batch_size=32, transforms=CPCTrainTransformsSTL10()) + """ + def __init__(self, patch_size=16, overlap=8): + """ + Args: + patch_size: size of patches when cutting up the image into overlapping patches + overlap: how much to overlap patches """ if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -189,32 +187,32 @@ def __call__(self, inp): class CPCEvalTransformsSTL10: + """ + Transforms used for CPC: - def __init__(self, patch_size=16, overlap=8): - """ - Transforms used for CPC: - - Args: + Transforms:: - patch_size: size of patches when cutting up the image into overlapping patches - overlap: how much to overlap patches + random_flip + transforms.ToTensor() + normalize + Patchify(patch_size=patch_size, overlap_size=patch_size // 2) - Transforms:: + Example:: - random_flip - transforms.ToTensor() - normalize - Patchify(patch_size=patch_size, overlap_size=patch_size // 2) + # in a regular dataset + STL10(..., transforms=CPCEvalTransformsSTL10()) - Example:: + # in a DataModule + module = STL10DataModule(PATH) + train_loader = module.train_dataloader(batch_size=32, transforms=CPCEvalTransformsSTL10()) - # in a regular dataset - STL10(..., transforms=CPCEvalTransformsSTL10()) - - # in a DataModule - module = STL10DataModule(PATH) - train_loader = module.train_dataloader(batch_size=32, transforms=CPCEvalTransformsSTL10()) + """ + def __init__(self, patch_size=16, overlap=8): + """ + Args: + patch_size: size of patches when cutting up the image into overlapping patches + overlap: how much to overlap patches """ if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -241,30 +239,30 @@ def __call__(self, inp): class CPCTrainTransformsImageNet128: - def __init__(self, patch_size=32, overlap=16): - """ - Transforms used for CPC: + """ + Transforms used for CPC: - Args: + Transforms:: - patch_size: size of patches when cutting up the image into overlapping patches - overlap: how much to overlap patches + random_flip + transforms.ToTensor() + normalize + Patchify(patch_size=patch_size, overlap_size=patch_size // 2) - Transforms:: + Example:: - random_flip - transforms.ToTensor() - normalize - Patchify(patch_size=patch_size, overlap_size=patch_size // 2) + # in a regular dataset + Imagenet(..., transforms=CPCTrainTransformsImageNet128()) - Example:: - - # in a regular dataset - Imagenet(..., transforms=CPCTrainTransformsImageNet128()) - - # in a DataModule - module = ImagenetDataModule(PATH) - train_loader = module.train_dataloader(batch_size=32, transforms=CPCTrainTransformsImageNet128()) + # in a DataModule + module = ImagenetDataModule(PATH) + train_loader = module.train_dataloader(batch_size=32, transforms=CPCTrainTransformsImageNet128()) + """ + def __init__(self, patch_size=32, overlap=16): + """ + Args: + patch_size: size of patches when cutting up the image into overlapping patches + overlap: how much to overlap patches """ if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover @@ -300,30 +298,31 @@ def __call__(self, inp): class CPCEvalTransformsImageNet128: - def __init__(self, patch_size=32, overlap=16): - """ - Transforms used for CPC: + """ + Transforms used for CPC: - Args: + Transforms:: - patch_size: size of patches when cutting up the image into overlapping patches - overlap: how much to overlap patches + random_flip + transforms.ToTensor() + normalize + Patchify(patch_size=patch_size, overlap_size=patch_size // 2) - Transforms:: + Example:: - random_flip - transforms.ToTensor() - normalize - Patchify(patch_size=patch_size, overlap_size=patch_size // 2) + # in a regular dataset + Imagenet(..., transforms=CPCEvalTransformsImageNet128()) - Example:: + # in a DataModule + module = ImagenetDataModule(PATH) + train_loader = module.train_dataloader(batch_size=32, transforms=CPCEvalTransformsImageNet128()) + """ - # in a regular dataset - Imagenet(..., transforms=CPCEvalTransformsImageNet128()) - - # in a DataModule - module = ImagenetDataModule(PATH) - train_loader = module.train_dataloader(batch_size=32, transforms=CPCEvalTransformsImageNet128()) + def __init__(self, patch_size=32, overlap=16): + """ + Args: + patch_size: size of patches when cutting up the image into overlapping patches + overlap: how much to overlap patches """ if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index 8eeb72c5cf..c67ae835b0 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -32,6 +32,34 @@ class MocoV2(pl.LightningModule): + """ + PyTorch Lightning implementation of `Moco `_ + + Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He. + + Code adapted from `facebookresearch/moco `_ to Lightning by: + + - `William Falcon `_ + + Example:: + from pl_bolts.models.self_supervised import MocoV2 + model = MocoV2() + trainer = Trainer() + trainer.fit(model) + + CLI command:: + + # cifar10 + python moco2_module.py --gpus 1 + + # imagenet + python moco2_module.py + --gpus 8 + --dataset imagenet2012 + --data_dir /path/to/imagenet/ + --meta_dir /path/to/folder/with/meta.bin/ + --batch_size 32 + """ def __init__(self, base_encoder: Union[str, torch.nn.Module] = 'resnet18', @@ -49,33 +77,6 @@ def __init__(self, num_workers: int = 8, *args, **kwargs): """ - PyTorch Lightning implementation of `Moco `_ - - Paper authors: Xinlei Chen, Haoqi Fan, Ross Girshick, Kaiming He. - - Code adapted from `facebookresearch/moco `_ to Lightning by: - - - `William Falcon `_ - - Example:: - from pl_bolts.models.self_supervised import MocoV2 - model = MocoV2() - trainer = Trainer() - trainer.fit(model) - - CLI command:: - - # cifar10 - python moco2_module.py --gpus 1 - - # imagenet - python moco2_module.py - --gpus 8 - --dataset imagenet2012 - --data_dir /path/to/imagenet/ - --meta_dir /path/to/folder/with/meta.bin/ - --batch_size 32 - Args: base_encoder: torchvision model name or torch.nn.Module emb_dim: feature dimension (default: 128) diff --git a/pl_bolts/models/self_supervised/simclr/transforms.py b/pl_bolts/models/self_supervised/simclr/transforms.py index 55b975c66f..bf33c6c277 100644 --- a/pl_bolts/models/self_supervised/simclr/transforms.py +++ b/pl_bolts/models/self_supervised/simclr/transforms.py @@ -103,7 +103,10 @@ def __call__(self, sample): class GaussianBlur(object): - # Implements Gaussian blur as described in the SimCLR paper + """ + Implements Gaussian blur as described in the SimCLR paper + """ + def __init__(self, kernel_size, min=0.1, max=2.0): if not _TORCHVISION_AVAILABLE: raise ModuleNotFoundError( # pragma: no-cover diff --git a/pl_bolts/models/self_supervised/ssl_finetuner.py b/pl_bolts/models/self_supervised/ssl_finetuner.py index f07e697a42..32d022b913 100644 --- a/pl_bolts/models/self_supervised/ssl_finetuner.py +++ b/pl_bolts/models/self_supervised/ssl_finetuner.py @@ -7,38 +7,39 @@ class SSLFineTuner(pl.LightningModule): + """ + Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP + with 1024 units - def __init__(self, backbone, in_features, num_classes, hidden_dim=1024): - """ - Finetunes a self-supervised learning backbone using the standard evaluation protocol of a singler layer MLP - with 1024 units - - Example:: + Example:: - from pl_bolts.utils.self_supervised import SSLFineTuner - from pl_bolts.models.self_supervised import CPCV2 - from pl_bolts.datamodules import CIFAR10DataModule - from pl_bolts.models.self_supervised.cpc.transforms import CPCEvalTransformsCIFAR10, - CPCTrainTransformsCIFAR10 + from pl_bolts.utils.self_supervised import SSLFineTuner + from pl_bolts.models.self_supervised import CPCV2 + from pl_bolts.datamodules import CIFAR10DataModule + from pl_bolts.models.self_supervised.cpc.transforms import CPCEvalTransformsCIFAR10, + CPCTrainTransformsCIFAR10 - # pretrained model - backbone = CPCV2.load_from_checkpoint(PATH, strict=False) + # pretrained model + backbone = CPCV2.load_from_checkpoint(PATH, strict=False) - # dataset + transforms - dm = CIFAR10DataModule(data_dir='.') - dm.train_transforms = CPCTrainTransformsCIFAR10() - dm.val_transforms = CPCEvalTransformsCIFAR10() + # dataset + transforms + dm = CIFAR10DataModule(data_dir='.') + dm.train_transforms = CPCTrainTransformsCIFAR10() + dm.val_transforms = CPCEvalTransformsCIFAR10() - # finetuner - finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes) + # finetuner + finetuner = SSLFineTuner(backbone, in_features=backbone.z_dim, num_classes=backbone.num_classes) - # train - trainer = pl.Trainer() - trainer.fit(finetuner, dm) + # train + trainer = pl.Trainer() + trainer.fit(finetuner, dm) - # test - trainer.test(datamodule=dm) + # test + trainer.test(datamodule=dm) + """ + def __init__(self, backbone, in_features, num_classes, hidden_dim=1024): + """ Args: backbone: a pretrained model in_features: feature dim of backbone outputs diff --git a/pl_bolts/models/vision/image_gpt/gpt2.py b/pl_bolts/models/vision/image_gpt/gpt2.py index b1c98c9092..827fef0994 100644 --- a/pl_bolts/models/vision/image_gpt/gpt2.py +++ b/pl_bolts/models/vision/image_gpt/gpt2.py @@ -30,6 +30,28 @@ def forward(self, x): class GPT2(pl.LightningModule): + """ + GPT-2 from `language Models are Unsupervised Multitask Learners `_ + + Paper by: Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever + + Implementation contributed by: + + - `Teddy Koker `_ + + Example:: + + from pl_bolts.models import GPT2 + + seq_len = 17 + batch_size = 32 + vocab_size = 16 + x = torch.randint(0, vocab_size, (seq_len, batch_size)) + model = GPT2(embed_dim=32, heads=2, layers=2, num_positions=seq_len, vocab_size=vocab_size, num_classes=4) + results = model(x) + """ + def __init__( self, embed_dim: int, @@ -39,27 +61,6 @@ def __init__( vocab_size: int, num_classes: int, ): - """ - GPT-2 from `language Models are Unsupervised Multitask Learners `_ - - Paper by: Alec Radford, Jeffrey Wu, Rewon Child, David Luan, Dario Amodei, Ilya Sutskever - - Implementation contributed by: - - - `Teddy Koker `_ - - Example:: - - from pl_bolts.models import GPT2 - - seq_len = 17 - batch_size = 32 - vocab_size = 16 - x = torch.randint(0, vocab_size, (seq_len, batch_size)) - model = GPT2(embed_dim=32, heads=2, layers=2, num_positions=seq_len, vocab_size=vocab_size, num_classes=4) - results = model(x) - """ super(GPT2, self).__init__() self.save_hyperparameters() diff --git a/pl_bolts/models/vision/image_gpt/igpt_module.py b/pl_bolts/models/vision/image_gpt/igpt_module.py index 4df2e83466..45faf781fe 100644 --- a/pl_bolts/models/vision/image_gpt/igpt_module.py +++ b/pl_bolts/models/vision/image_gpt/igpt_module.py @@ -16,6 +16,86 @@ def _shape_input(x): class ImageGPT(pl.LightningModule): + """ + **Paper**: `Generative Pretraining from Pixels + `_ + [original paper `code `_]. + + **Paper by:** Mark Che, Alec Radford, Rewon Child, Jeff Wu, Heewoo Jun, + Prafulla Dhariwal, David Luan, Ilya Sutskever + + **Implementation contributed by**: + + - `Teddy Koker `_ + + **Original repo with results and more implementation details**: + + - `https://github.com/teddykoker/image-gpt `_ + + **Example Results (Photo credits: Teddy Koker)**: + + .. image:: https://raw.githubusercontent.com/teddykoker/image-gpt/master/figures/mnist.png + :width: 250 + :alt: credit-Teddy-Koker + + .. image:: https://raw.githubusercontent.com/teddykoker/image-gpt/master/figures/fmnist.png + :width: 250 + :alt: credit-Teddy-Koker + + **Default arguments:** + + .. list-table:: Argument Defaults + :widths: 50 25 25 + :header-rows: 1 + + * - Argument + - Default + - iGPT-S (`Chen et al. `_) + * - `--embed_dim` + - 16 + - 512 + * - `--heads` + - 2 + - 8 + * - `--layers` + - 8 + - 24 + * - `--pixels` + - 28 + - 32 + * - `--vocab_size` + - 16 + - 512 + * - `--num_classes` + - 10 + - 10 + * - `--batch_size` + - 64 + - 128 + * - `--learning_rate` + - 0.01 + - 0.01 + * - `--steps` + - 25000 + - 1000000 + + Example:: + + import pytorch_lightning as pl + from pl_bolts.models.vision import ImageGPT + + dm = MNISTDataModule('.') + model = ImageGPT(dm) + + pl.Trainer(gpu=4).fit(model) + + As script: + + .. code-block:: bash + + cd pl_bolts/models/vision/image_gpt + python igpt_module.py --learning_rate 1e-2 --batch_size 32 --gpus 4 + """ def __init__( self, datamodule: pl.LightningDataModule = None, @@ -34,87 +114,7 @@ def __init__( **kwargs, ): """ - **Paper**: `Generative Pretraining from Pixels - `_ - [original paper `code `_]. - - **Paper by:** Mark Che, Alec Radford, Rewon Child, Jeff Wu, Heewoo Jun, - Prafulla Dhariwal, David Luan, Ilya Sutskever - - **Implementation contributed by**: - - - `Teddy Koker `_ - - **Original repo with results and more implementation details**: - - - `https://github.com/teddykoker/image-gpt `_ - - **Example Results (Photo credits: Teddy Koker)**: - - .. image:: https://raw.githubusercontent.com/teddykoker/image-gpt/master/figures/mnist.png - :width: 250 - :alt: credit-Teddy-Koker - - .. image:: https://raw.githubusercontent.com/teddykoker/image-gpt/master/figures/fmnist.png - :width: 250 - :alt: credit-Teddy-Koker - - **Default arguments:** - - .. list-table:: Argument Defaults - :widths: 50 25 25 - :header-rows: 1 - - * - Argument - - Default - - iGPT-S (`Chen et al. `_) - * - `--embed_dim` - - 16 - - 512 - * - `--heads` - - 2 - - 8 - * - `--layers` - - 8 - - 24 - * - `--pixels` - - 28 - - 32 - * - `--vocab_size` - - 16 - - 512 - * - `--num_classes` - - 10 - - 10 - * - `--batch_size` - - 64 - - 128 - * - `--learning_rate` - - 0.01 - - 0.01 - * - `--steps` - - 25000 - - 1000000 - - Example:: - - import pytorch_lightning as pl - from pl_bolts.models.vision import ImageGPT - - dm = MNISTDataModule('.') - model = ImageGPT(dm) - - pl.Trainer(gpu=4).fit(model) - - As script: - - .. code-block:: bash - - cd pl_bolts/models/vision/image_gpt - python igpt_module.py --learning_rate 1e-2 --batch_size 32 --gpus 4 - Args: - datamodule: LightningDataModule embed_dim: the embedding dim heads: number of attention heads diff --git a/pl_bolts/models/vision/pixel_cnn.py b/pl_bolts/models/vision/pixel_cnn.py index d7f2341a4e..7c1f24840c 100644 --- a/pl_bolts/models/vision/pixel_cnn.py +++ b/pl_bolts/models/vision/pixel_cnn.py @@ -9,30 +9,30 @@ class PixelCNN(nn.Module): + """ + Implementation of `Pixel CNN `_. - def __init__(self, input_channels, hidden_channels=256, num_blocks=5): - """ - Implementation of `Pixel CNN `_. + Paper authors: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves, + Koray Kavukcuoglu - Paper authors: Aaron van den Oord, Nal Kalchbrenner, Oriol Vinyals, Lasse Espeholt, Alex Graves, - Koray Kavukcuoglu + Implemented by: - Implemented by: + - William Falcon - - William Falcon + Example:: - Example:: + >>> from pl_bolts.models.vision import PixelCNN + >>> import torch + ... + >>> model = PixelCNN(input_channels=3) + >>> x = torch.rand(5, 3, 64, 64) + >>> out = model(x) + ... + >>> out.shape + torch.Size([5, 3, 64, 64]) + """ - >>> from pl_bolts.models.vision import PixelCNN - >>> import torch - ... - >>> model = PixelCNN(input_channels=3) - >>> x = torch.rand(5, 3, 64, 64) - >>> out = model(x) - ... - >>> out.shape - torch.Size([5, 3, 64, 64]) - """ + def __init__(self, input_channels, hidden_channels=256, num_blocks=5): super().__init__() self.input_channels = input_channels self.hidden_channels = hidden_channels diff --git a/pl_bolts/models/vision/unet.py b/pl_bolts/models/vision/unet.py index 1f5bfed343..817e4aabec 100644 --- a/pl_bolts/models/vision/unet.py +++ b/pl_bolts/models/vision/unet.py @@ -15,12 +15,6 @@ class UNet(nn.Module): - `Akshay Kulkarni `_ .. warning:: Work in progress. This implementation is still being verified. - - Args: - num_classes: Number of output classes required - num_layers: Number of layers in each side of U-net (default 5) - features_start: Number of features in first layer (default 64) - bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. """ def __init__( @@ -30,6 +24,13 @@ def __init__( features_start: int = 64, bilinear: bool = False ): + """ + Args: + num_classes: Number of output classes required + num_layers: Number of layers in each side of U-net (default 5) + features_start: Number of features in first layer (default 64) + bilinear (bool): Whether to use bilinear interpolation or transposed convolutions (default) for upsampling. + """ super().__init__() self.num_layers = num_layers @@ -98,8 +99,7 @@ def forward(self, x): class Up(nn.Module): """ Upsampling (by either bilinear interpolation or transpose convolutions) - followed by concatenation of feature map from contracting path, - followed by DoubleConv. + followed by concatenation of feature map from contracting path, followed by DoubleConv. """ def __init__(self, in_ch: int, out_ch: int, bilinear: bool = False): diff --git a/pl_bolts/optimizers/lars_scheduling.py b/pl_bolts/optimizers/lars_scheduling.py index d78feeebcb..184589903e 100644 --- a/pl_bolts/optimizers/lars_scheduling.py +++ b/pl_bolts/optimizers/lars_scheduling.py @@ -9,10 +9,12 @@ class LARSWrapper(object): + """ + Wrapper that adds LARS scheduling to any optimizer. This helps stability with huge batch sizes. + """ + def __init__(self, optimizer, eta=0.02, clip=True, eps=1e-8): """ - Wrapper that adds LARS scheduling to any optimizer. This helps stability with huge batch sizes. - Args: optimizer: torch optimizer eta: LARS coefficient (trust) diff --git a/pl_bolts/optimizers/lr_scheduler.py b/pl_bolts/optimizers/lr_scheduler.py index a568592316..c4b0bb0442 100644 --- a/pl_bolts/optimizers/lr_scheduler.py +++ b/pl_bolts/optimizers/lr_scheduler.py @@ -25,14 +25,6 @@ class LinearWarmupCosineAnnealingLR(_LRScheduler): epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling train and validation methods. - Args: - optimizer (Optimizer): Wrapped optimizer. - warmup_epochs (int): Maximum number of iterations for linear warmup - max_epochs (int): Maximum number of iterations - warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. - eta_min (float): Minimum learning rate. Default: 0. - last_epoch (int): The index of last epoch. Default: -1. - Example: >>> layer = nn.Linear(10, 1) >>> optimizer = Adam(layer.parameters(), lr=0.02) @@ -60,7 +52,15 @@ def __init__( eta_min: float = 0.0, last_epoch: int = -1, ) -> None: - + """ + Args: + optimizer (Optimizer): Wrapped optimizer. + warmup_epochs (int): Maximum number of iterations for linear warmup + max_epochs (int): Maximum number of iterations + warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. + eta_min (float): Minimum learning rate. Default: 0. + last_epoch (int): The index of last epoch. Default: -1. + """ self.warmup_epochs = warmup_epochs self.max_epochs = max_epochs self.warmup_start_lr = warmup_start_lr diff --git a/pl_bolts/transforms/self_supervised/ssl_transforms.py b/pl_bolts/transforms/self_supervised/ssl_transforms.py index 70e6e762ea..1e51dda4fc 100644 --- a/pl_bolts/transforms/self_supervised/ssl_transforms.py +++ b/pl_bolts/transforms/self_supervised/ssl_transforms.py @@ -14,13 +14,13 @@ class RandomTranslateWithReflect: - ''' + """ Translate image randomly Translate vertically and horizontally by n pixels where n is integer drawn uniformly independently for each axis from [-max_translation, max_translation]. Fill the uncovered blank area with reflect padding. - ''' + """ def __init__(self, max_translation): self.max_translation = max_translation diff --git a/pl_bolts/utils/arguments.py b/pl_bolts/utils/arguments.py index fd4205cd5e..fa0c7e9ec1 100644 --- a/pl_bolts/utils/arguments.py +++ b/pl_bolts/utils/arguments.py @@ -18,20 +18,22 @@ class LitArg: class LightningArgumentParser(ArgumentParser): - def __init__(self, *args, ignore_required_init_args=True, **kwargs): - """Extension of argparse.ArgumentParser that lets you parse arbitrary object init args. - - Example:: + """ + Extension of argparse.ArgumentParser that lets you parse arbitrary object init args. - from pl_bolts.utils.arguments import LightningArgumentParser + Example:: - parser.add_object_args("data", MyDataModule) - parser.add_object_args("model", MyModel) - args = parser.parse_lit_args() + from pl_bolts.utils.arguments import LightningArgumentParser - # args.data -> data args - # args.model -> model args + parser.add_object_args("data", MyDataModule) + parser.add_object_args("model", MyModel) + args = parser.parse_lit_args() + # args.data -> data args + # args.model -> model args + """ + def __init__(self, *args, ignore_required_init_args=True, **kwargs): + """ Args: ignore_required_init_args (bool, optional): Whether to include positional args when adding object args. Defaults to True. From 5caac5b776f40ae9025b238d19e44085edcb58a3 Mon Sep 17 00:00:00 2001 From: deng-cy <40417707+deng-cy@users.noreply.github.com> Date: Thu, 15 Oct 2020 09:05:55 -0400 Subject: [PATCH 27/32] Removed datamodule from an input parameter (#270) * updated * removed datamodule in the model * add trainer to get datamodule.name * changed to lightning 1.0.0 format * iGPT * moco * dm Co-authored-by: Jirka Borovec --- .../self_supervised/moco/moco2_module.py | 26 ++++++++----------- .../models/vision/image_gpt/igpt_module.py | 10 +++---- tests/models/self_supervised/test_models.py | 4 +-- tests/models/test_vision.py | 10 +++---- 4 files changed, 21 insertions(+), 29 deletions(-) diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index c67ae835b0..bdc7247717 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -70,7 +70,6 @@ def __init__(self, learning_rate: float = 0.03, momentum: float = 0.9, weight_decay: float = 1e-4, - datamodule: pl.LightningDataModule = None, data_dir: str = './', batch_size: int = 256, use_mlp: bool = False, @@ -96,14 +95,6 @@ def __init__(self, super().__init__() self.save_hyperparameters() - # use CIFAR-10 by default if no datamodule passed in - # if datamodule is None: - # datamodule = CIFAR10DataModule(data_dir) - # datamodule.train_transforms = Moco2TrainCIFAR10Transforms() - # datamodule.val_transforms = Moco2EvalCIFAR10Transforms() - assert datamodule - self.datamodule = datamodule - # create the encoders # num_classes is the output fc dimension self.encoder_q, self.encoder_k = self.init_encoders(base_encoder) @@ -259,7 +250,7 @@ def forward(self, img_q, img_k): def training_step(self, batch, batch_idx): # in STL10 we pass in both lab+unl for online ft - if self.hparams.datamodule.name == 'stl10': + if self.trainer.datamodule.name == 'stl10': labeled_batch = batch[1] unlabeled_batch = batch[0] batch = unlabeled_batch @@ -276,11 +267,12 @@ def training_step(self, batch, batch_idx): 'train_acc1': acc1, 'train_acc5': acc5 } - return {'loss': loss, 'log': log, 'progress_bar': log} + self.log_dict(log) + return loss def validation_step(self, batch, batch_idx): # in STL10 we pass in both lab+unl for online ft - if self.hparams.datamodule.name == 'stl10': + if self.trainer.datamodule.name == 'stl10': labeled_batch = batch[1] unlabeled_batch = batch[0] batch = unlabeled_batch @@ -309,7 +301,7 @@ def validation_epoch_end(self, outputs): 'val_acc1': val_acc1, 'val_acc5': val_acc5 } - return {'val_loss': val_loss, 'log': log, 'progress_bar': log} + self.log_dict(log) def configure_optimizers(self): optimizer = torch.optim.SGD(self.parameters(), self.hparams.learning_rate, @@ -382,10 +374,14 @@ def cli_main(): datamodule.train_transforms = Moco2TrainImagenetTransforms() datamodule.val_transforms = Moco2EvalImagenetTransforms() - model = MocoV2(**args.__dict__, datamodule=datamodule) + else: + # replace with your own dataset, otherwise CIFAR-10 will be used by default if `None` passed in + datamodule = None + + model = MocoV2(**args.__dict__) trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) if __name__ == '__main__': diff --git a/pl_bolts/models/vision/image_gpt/igpt_module.py b/pl_bolts/models/vision/image_gpt/igpt_module.py index 45faf781fe..ad2d70c331 100644 --- a/pl_bolts/models/vision/image_gpt/igpt_module.py +++ b/pl_bolts/models/vision/image_gpt/igpt_module.py @@ -98,7 +98,6 @@ class ImageGPT(pl.LightningModule): """ def __init__( self, - datamodule: pl.LightningDataModule = None, embed_dim: int = 16, heads: int = 2, layers: int = 2, @@ -115,7 +114,6 @@ def __init__( ): """ Args: - datamodule: LightningDataModule embed_dim: the embedding dim heads: number of attention heads layers: number of layers @@ -129,7 +127,7 @@ def __init__( data_dir: where to store data num_workers: num_data workers """ - super(ImageGPT, self).__init__() + super().__init__() self.save_hyperparameters() # default to MNIST if no datamodule given @@ -139,8 +137,6 @@ def __init__( # ) # self.hparams.pixels = datamodule.size(1) # self.hparams.num_classes = datamodule.num_classes - assert datamodule - self.datamodule = datamodule self.gpt = GPT2( embed_dim=self.hparams.embed_dim, @@ -259,10 +255,10 @@ def cli_main(): elif args.dataset == "imagenet128": datamodule = ImagenetDataModule.from_argparse_args(args) - model = ImageGPT(**args.__dict__, datamodule=datamodule) + model = ImageGPT(**args.__dict__) trainer = pl.Trainer.from_argparse_args(args) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) if __name__ == '__main__': diff --git a/tests/models/self_supervised/test_models.py b/tests/models/self_supervised/test_models.py index b2df8af307..eb0d7401f8 100644 --- a/tests/models/self_supervised/test_models.py +++ b/tests/models/self_supervised/test_models.py @@ -63,9 +63,9 @@ def test_moco(tmpdir): datamodule.train_transforms = Moco2TrainCIFAR10Transforms() datamodule.val_transforms = Moco2EvalCIFAR10Transforms() - model = MocoV2(data_dir=tmpdir, batch_size=2, datamodule=datamodule, online_ft=True) + model = MocoV2(data_dir=tmpdir, batch_size=2, online_ft=True) trainer = pl.Trainer(fast_dev_run=True, max_epochs=1, default_root_dir=tmpdir, callbacks=[MocoLRScheduler()]) - trainer.fit(model) + trainer.fit(model, datamodule=datamodule) loss = trainer.progress_bar_dict['loss'] assert float(loss) > 0 diff --git a/tests/models/test_vision.py b/tests/models/test_vision.py index 73af207f1a..7b361af219 100644 --- a/tests/models/test_vision.py +++ b/tests/models/test_vision.py @@ -8,7 +8,7 @@ def test_igpt(tmpdir): pl.seed_everything(0) dm = MNISTDataModule(tmpdir, normalize=False) - model = ImageGPT(datamodule=dm) + model = ImageGPT() trainer = pl.Trainer( limit_train_batches=2, @@ -16,19 +16,19 @@ def test_igpt(tmpdir): limit_test_batches=2, max_epochs=1, ) - trainer.fit(model) - trainer.test() + trainer.fit(model, datamodule=dm) + trainer.test(datamodule=dm) assert trainer.callback_metrics["test_loss"] < 1.7 dm = FashionMNISTDataModule(tmpdir, num_workers=1) - model = ImageGPT(classify=True, datamodule=dm) + model = ImageGPT(classify=True) trainer = pl.Trainer( limit_train_batches=2, limit_val_batches=2, limit_test_batches=2, max_epochs=1, ) - trainer.fit(model) + trainer.fit(model, datamodule=dm) def test_gpt2(tmpdir): From 2ca4084459cd25234ada8a7fe5432afdb5dbdb40 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 15 Oct 2020 09:42:26 -0400 Subject: [PATCH 28/32] Update README.md --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 28f398f7cc..e9c5462d16 100644 --- a/README.md +++ b/README.md @@ -177,3 +177,14 @@ with your implementation. Bolts is supported by the PyTorch Lightning team and the PyTorch Lightning community! +## Citation +To cite bolts use: + +``` +@article{falcon2020framework, + title={A Framework For Contrastive Self-Supervised Learning And Designing A New Approach}, + author={Falcon, William and Cho, Kyunghyun}, + journal={arXiv preprint arXiv:2009.00104}, + year={2020} +} +``` From d8b5b0f15c60ca4e908d1170885b344138fbf5f1 Mon Sep 17 00:00:00 2001 From: William Falcon Date: Thu, 15 Oct 2020 10:20:53 -0400 Subject: [PATCH 29/32] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index e9c5462d16..203f50f0f5 100644 --- a/README.md +++ b/README.md @@ -188,3 +188,5 @@ To cite bolts use: year={2020} } ``` + +To cite other contributed models or modules, please cite the authors directly (if they don't have bibtex, ping the authors on a GH issue) From f48357be353b7acdd882379ac3308fbec95dc40d Mon Sep 17 00:00:00 2001 From: Ananya Harsh Jha Date: Thu, 15 Oct 2020 14:47:07 -0400 Subject: [PATCH 30/32] added Attribution-NonCommercial 4.0 International to MOCO (#282) * added Attribution-NonCommercial 4.0 International to MOCO * added link to license in the module file --- pl_bolts/models/self_supervised/moco/LICENSE | 399 ++++++++++++++++++ .../self_supervised/moco/moco2_module.py | 5 + 2 files changed, 404 insertions(+) create mode 100644 pl_bolts/models/self_supervised/moco/LICENSE diff --git a/pl_bolts/models/self_supervised/moco/LICENSE b/pl_bolts/models/self_supervised/moco/LICENSE new file mode 100644 index 0000000000..50f2e656c8 --- /dev/null +++ b/pl_bolts/models/self_supervised/moco/LICENSE @@ -0,0 +1,399 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More_considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. \ No newline at end of file diff --git a/pl_bolts/models/self_supervised/moco/moco2_module.py b/pl_bolts/models/self_supervised/moco/moco2_module.py index bdc7247717..03b50052b0 100644 --- a/pl_bolts/models/self_supervised/moco/moco2_module.py +++ b/pl_bolts/models/self_supervised/moco/moco2_module.py @@ -3,6 +3,11 @@ Original work is: Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved This implementation is: Copyright (c) PyTorch Lightning, Inc. and its affiliates. All Rights Reserved + +This implementation is licensed under Attribution-NonCommercial 4.0 International; +You may not use this file except in compliance with the License. + +You may obtain a copy of the License from the LICENSE file present in this folder. """ from argparse import ArgumentParser From 3187ebeccb61989f9adcb09c9f8bcdc973c61fb8 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 19 Oct 2020 01:25:43 +0900 Subject: [PATCH 31/32] Use `Optional` for arguments set to `None` by default (#283) * Create branch refactor/use-optional * Use `Optional` for variables set to `None` by default * Format to keep line length <=120 --- pl_bolts/callbacks/printing.py | 14 +++++++------- pl_bolts/callbacks/self_supervised.py | 14 +++++++++----- pl_bolts/datamodules/cifar10_datamodule.py | 2 +- pl_bolts/datamodules/imagenet_datamodule.py | 3 ++- pl_bolts/datamodules/stl10_datamodule.py | 3 ++- pl_bolts/datasets/cifar10_dataset.py | 4 ++-- pl_bolts/datasets/ssl_amdim_datasets.py | 8 ++++---- pl_bolts/models/rl/dqn_model.py | 4 ++-- pl_bolts/models/self_supervised/amdim/datasets.py | 5 +++-- pl_bolts/models/self_supervised/cpc/cpc_module.py | 5 +++-- 10 files changed, 35 insertions(+), 27 deletions(-) diff --git a/pl_bolts/callbacks/printing.py b/pl_bolts/callbacks/printing.py index d2364b4085..53cccf0f4f 100644 --- a/pl_bolts/callbacks/printing.py +++ b/pl_bolts/callbacks/printing.py @@ -1,6 +1,6 @@ import copy from itertools import zip_longest -from typing import List, Any, Dict, Callable +from typing import List, Any, Dict, Callable, Optional from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities import rank_zero_info @@ -43,13 +43,13 @@ def on_epoch_end(self, trainer, pl_module): def dicts_to_table(dicts: List[Dict], - keys: List[str] = None, - pads: List[str] = None, - fcodes: List[str] = None, - convert_headers: Dict[str, Callable] = None, - header_names: List[str] = None, + keys: Optional[List[str]] = None, + pads: Optional[List[str]] = None, + fcodes: Optional[List[str]] = None, + convert_headers: Optional[Dict[str, Callable]] = None, + header_names: Optional[List[str]] = None, skip_none_lines: bool = False, - replace_values: Dict[str, Any] = None): + replace_values: Optional[Dict[str, Any]] = None): """ Generate ascii table from dictionary Taken from (https://stackoverflow.com/questions/40056747/print-a-list-of-dictionaries-in-table-form) diff --git a/pl_bolts/callbacks/self_supervised.py b/pl_bolts/callbacks/self_supervised.py index 63c17e6d5d..c9f00fca13 100644 --- a/pl_bolts/callbacks/self_supervised.py +++ b/pl_bolts/callbacks/self_supervised.py @@ -1,4 +1,5 @@ import math +from typing import Optional import pytorch_lightning as pl import torch @@ -20,7 +21,13 @@ class SSLOnlineEvaluator(pl.Callback): # pragma: no-cover model.num_classes = ... # the num of classes in the model """ - def __init__(self, drop_p: float = 0.2, hidden_dim: int = 1024, z_dim: int = None, num_classes: int = None): + def __init__( + self, + drop_p: float = 0.2, + hidden_dim: int = 1024, + z_dim: Optional[int] = None, + num_classes: Optional[int] = None, + ): """ Args: drop_p: (0.2) dropout probability @@ -44,10 +51,7 @@ def on_pretrain_routine_start(self, trainer, pl_module): self.num_classes = pl_module.num_classes pl_module.non_linear_evaluator = SSLEvaluator( - n_input=self.z_dim, - n_classes=self.num_classes, - p=self.drop_p, - n_hidden=self.hidden_dim + n_input=self.z_dim, n_classes=self.num_classes, p=self.drop_p, n_hidden=self.hidden_dim ).to(pl_module.device) self.optimizer = torch.optim.SGD(pl_module.non_linear_evaluator.parameters(), lr=1e-3) diff --git a/pl_bolts/datamodules/cifar10_datamodule.py b/pl_bolts/datamodules/cifar10_datamodule.py index b6e80ef53a..29713cb68c 100644 --- a/pl_bolts/datamodules/cifar10_datamodule.py +++ b/pl_bolts/datamodules/cifar10_datamodule.py @@ -67,7 +67,7 @@ class CIFAR10DataModule(LightningDataModule): def __init__( self, - data_dir: str = None, + data_dir: Optional[str] = None, val_split: int = 5000, num_workers: int = 16, batch_size: int = 32, diff --git a/pl_bolts/datamodules/imagenet_datamodule.py b/pl_bolts/datamodules/imagenet_datamodule.py index 4ab505989e..588c57e46a 100644 --- a/pl_bolts/datamodules/imagenet_datamodule.py +++ b/pl_bolts/datamodules/imagenet_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Optional from warnings import warn from pytorch_lightning import LightningDataModule @@ -52,7 +53,7 @@ class ImagenetDataModule(LightningDataModule): def __init__( self, data_dir: str, - meta_dir: str = None, + meta_dir: Optional[str] = None, num_imgs_per_val_class: int = 50, image_size: int = 224, num_workers: int = 16, diff --git a/pl_bolts/datamodules/stl10_datamodule.py b/pl_bolts/datamodules/stl10_datamodule.py index df49164579..a6eda48c7e 100644 --- a/pl_bolts/datamodules/stl10_datamodule.py +++ b/pl_bolts/datamodules/stl10_datamodule.py @@ -1,4 +1,5 @@ import os +from typing import Optional from warnings import warn import torch @@ -56,7 +57,7 @@ class STL10DataModule(LightningDataModule): # pragma: no cover def __init__( self, - data_dir: str = None, + data_dir: Optional[str] = None, unlabeled_val_split: int = 5000, train_val_split: int = 500, num_workers: int = 16, diff --git a/pl_bolts/datasets/cifar10_dataset.py b/pl_bolts/datasets/cifar10_dataset.py index 51af47530e..cd6e50ddef 100644 --- a/pl_bolts/datasets/cifar10_dataset.py +++ b/pl_bolts/datasets/cifar10_dataset.py @@ -79,7 +79,7 @@ def __init__( self, data_dir: str = '.', train: bool = True, - transform: Callable = None, + transform: Optional[Callable] = None, download: bool = True ): super().__init__() @@ -182,7 +182,7 @@ def __init__( self, data_dir: str = '.', train: bool = True, - transform: Callable = None, + transform: Optional[Callable] = None, download: bool = False, num_samples: int = 100, labels: Optional[Sequence] = (1, 5, 8), diff --git a/pl_bolts/datasets/ssl_amdim_datasets.py b/pl_bolts/datasets/ssl_amdim_datasets.py index d9bb17942b..ca465c07ea 100644 --- a/pl_bolts/datasets/ssl_amdim_datasets.py +++ b/pl_bolts/datasets/ssl_amdim_datasets.py @@ -1,5 +1,5 @@ from abc import ABC -from typing import Callable +from typing import Callable, Optional from warnings import warn import numpy as np @@ -100,10 +100,10 @@ def __init__( self, root: str, split: str = 'val', - transform: Callable = None, - target_transform: Callable = None, + transform: Optional[Callable] = None, + target_transform: Optional[Callable] = None, download: bool = False, - nb_labeled_per_class: int = None, + nb_labeled_per_class: Optional[int] = None, val_pct: float = 0.10 ): diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index f7ab1fc40b..1d6110d2de 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -4,7 +4,7 @@ import argparse from collections import OrderedDict -from typing import Tuple, List, Dict +from typing import Tuple, List, Dict, Optional from warnings import warn import numpy as np @@ -343,7 +343,7 @@ def test_dataloader(self) -> DataLoader: return self._dataloader() @staticmethod - def make_environment(env_name: str, seed: int = None) -> gym.Env: + def make_environment(env_name: str, seed: Optional[int] = None) -> gym.Env: """ Initialise gym environment Args: diff --git a/pl_bolts/models/self_supervised/amdim/datasets.py b/pl_bolts/models/self_supervised/amdim/datasets.py index a4bc65e60b..b610d8cf0d 100644 --- a/pl_bolts/models/self_supervised/amdim/datasets.py +++ b/pl_bolts/models/self_supervised/amdim/datasets.py @@ -1,4 +1,5 @@ from warnings import warn +from typing import Optional from torch.utils.data import random_split @@ -52,7 +53,7 @@ def imagenet(dataset_root, nb_classes, split: str = 'train'): return dataset @staticmethod - def stl(dataset_root, split: str = None): + def stl(dataset_root, split: Optional[str] = None): dataset = STL10( root=dataset_root, split='unlabeled', @@ -93,7 +94,7 @@ def cifar10(dataset_root, patch_size, patch_overlap, split: str = 'train'): return dataset @staticmethod - def stl(dataset_root, patch_size, patch_overlap, split: str = None): + def stl(dataset_root, patch_size, patch_overlap, split: Optional[str] = None): train_transform = amdim_transforms.TransformsSTL10Patches( patch_size=patch_size, overlap=patch_overlap diff --git a/pl_bolts/models/self_supervised/cpc/cpc_module.py b/pl_bolts/models/self_supervised/cpc/cpc_module.py index 958f01978a..d173eebab0 100644 --- a/pl_bolts/models/self_supervised/cpc/cpc_module.py +++ b/pl_bolts/models/self_supervised/cpc/cpc_module.py @@ -4,6 +4,7 @@ """ import math from argparse import ArgumentParser +from typing import Optional import pytorch_lightning as pl import torch @@ -33,7 +34,7 @@ class CPCV2(pl.LightningModule): def __init__( self, - datamodule: pl.LightningDataModule = None, + datamodule: Optional[pl.LightningDataModule] = None, encoder_name: str = 'cpc_encoder', patch_size: int = 8, patch_overlap: int = 4, @@ -43,7 +44,7 @@ def __init__( learning_rate: int = 1e-4, data_dir: str = '', batch_size: int = 32, - pretrained: str = None, + pretrained: Optional[str] = None, **kwargs, ): """ From d32d3eb12809c0fe15aa31d67ef06969837fc7a9 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Mon, 19 Oct 2020 17:06:21 +0900 Subject: [PATCH 32/32] docs: Add empty lines before `Args:` in docstring (#284) * Add empty lines * Add empty lines above Args: --- pl_bolts/callbacks/self_supervised.py | 1 + pl_bolts/datamodules/experience_source.py | 12 +++++++++++ pl_bolts/datamodules/kitti_datamodule.py | 2 +- pl_bolts/losses/rl.py | 6 ++++++ pl_bolts/models/rl/common/agents.py | 21 +++++++++++++------ pl_bolts/models/rl/common/memory.py | 15 +++++++++++++ pl_bolts/models/rl/common/networks.py | 20 ++++++++++++++++++ pl_bolts/models/rl/double_dqn_model.py | 2 ++ pl_bolts/models/rl/dqn_model.py | 9 ++++++++ pl_bolts/models/rl/per_dqn_model.py | 2 ++ pl_bolts/models/rl/reinforce_model.py | 11 ++++++++++ .../rl/vanilla_policy_gradient_model.py | 7 +++++++ pl_bolts/models/self_supervised/resnets.py | 10 +++++++++ 13 files changed, 111 insertions(+), 7 deletions(-) diff --git a/pl_bolts/callbacks/self_supervised.py b/pl_bolts/callbacks/self_supervised.py index c9f00fca13..6babac6f39 100644 --- a/pl_bolts/callbacks/self_supervised.py +++ b/pl_bolts/callbacks/self_supervised.py @@ -59,6 +59,7 @@ def on_pretrain_routine_start(self, trainer, pl_module): def get_representations(self, pl_module, x): """ Override this to customize for the particular model + Args: pl_module: x: diff --git a/pl_bolts/datamodules/experience_source.py b/pl_bolts/datamodules/experience_source.py index bfdd66d529..e4df09c85d 100644 --- a/pl_bolts/datamodules/experience_source.py +++ b/pl_bolts/datamodules/experience_source.py @@ -83,8 +83,10 @@ def __init__(self, env, agent, n_steps: int = 1) -> None: def runner(self, device: torch.device) -> Tuple[Experience]: """Experience Source iterator yielding Tuple of experiences for n_steps. These come from the pool of environments provided by the user. + Args: device: current device to be used for executing experience steps + Returns: Tuple of Experiences """ @@ -113,6 +115,7 @@ def update_history_queue(self, env_idx, exp, history) -> None: Updates the experience history queue with the lastest experiences. In the event of an experience step is in the done state, the history will be incrementally appended to the queue, removing the tail of the history each time. + Args: env_idx: index of the environment exp: the current experience @@ -172,10 +175,12 @@ def env_actions(self, device) -> List[List[int]]: def env_step(self, env_idx: int, env: Env, action: List[int]) -> Experience: """ Carries out a step through the given environment using the given action + Args: env_idx: index of the current environment env: env at index env_idx action: action for this environment step + Returns: Experience tuple """ @@ -192,6 +197,7 @@ def update_env_stats(self, env_idx: int) -> None: """ To be called at the end of the history tail generation during the termination state. Updates the stats tracked for all environments + Args: env_idx: index of the environment used to update stats """ @@ -238,8 +244,10 @@ def __init__(self, env: Env, agent, n_steps: int = 1, gamma: float = 0.99): def runner(self, device: torch.device) -> Experience: """ Iterates through experience tuple and calculate discounted experience + Args: device: current device to be used for executing experience steps + Yields: Discounted Experience """ @@ -255,8 +263,10 @@ def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tup """ Takes in a tuple of experiences and returns the last state and tail experiences based on if the last state is the end of an episode + Args: experiences: Tuple of N Experience + Returns: last state (Array or None) and remaining Experience """ @@ -271,8 +281,10 @@ def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tup def discount_rewards(self, experiences: Tuple[Experience]) -> float: """ Calculates the discounted reward over N experiences + Args: experiences: Tuple of Experience + Returns: total discounted reward """ diff --git a/pl_bolts/datamodules/kitti_datamodule.py b/pl_bolts/datamodules/kitti_datamodule.py index 127fb2e26f..03348638e1 100644 --- a/pl_bolts/datamodules/kitti_datamodule.py +++ b/pl_bolts/datamodules/kitti_datamodule.py @@ -46,7 +46,7 @@ def __init__( Trainer().fit(model, dm) - Args:: + Args: data_dir: where to load the data from path, i.e. '/path/to/folder/with/data_semantics/' val_split: size of validation test (default 0.2) test_split: size of test set (default 0.1) diff --git a/pl_bolts/losses/rl.py b/pl_bolts/losses/rl.py index a4a974f7c6..12350648c2 100644 --- a/pl_bolts/losses/rl.py +++ b/pl_bolts/losses/rl.py @@ -13,11 +13,13 @@ def dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], net: nn.Module, target_net: nn.Module, gamma: float = 0.99) -> torch.Tensor: """ Calculates the mse loss using a mini batch from the replay buffer + Args: batch: current mini batch of replay data net: main training network target_net: target network of the main training network gamma: discount factor + Returns: loss """ @@ -45,11 +47,13 @@ def double_dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], net: nn.Module, Calculates the mse loss using a mini batch from the replay buffer. This uses an improvement to the original DQN loss by using the double dqn. This is shown by using the actions of the train network to pick the value from the target network. This code is heavily commented in order to explain the process clearly + Args: batch: current mini batch of replay data net: main training network target_net: target network of the main training network gamma: discount factor + Returns: loss """ @@ -89,12 +93,14 @@ def per_dqn_loss(batch: Tuple[torch.Tensor, torch.Tensor], batch_weights: List, target_net: nn.Module, gamma: float = 0.99) -> Tuple[torch.Tensor, np.ndarray]: """ Calculates the mse loss with the priority weights of the batch from the PER buffer + Args: batch: current mini batch of replay data batch_weights: how each of these samples are weighted in terms of priority net: main training network target_net: target network of the main training network gamma: discount factor + Returns: loss and batch_weights """ diff --git a/pl_bolts/models/rl/common/agents.py b/pl_bolts/models/rl/common/agents.py index 92c5fbb8fa..d9f4d9d063 100644 --- a/pl_bolts/models/rl/common/agents.py +++ b/pl_bolts/models/rl/common/agents.py @@ -20,9 +20,11 @@ def __init__(self, net: nn.Module): def __call__(self, state: torch.Tensor, device: str, *args, **kwargs) -> List[int]: """ Using the given network, decide what action to carry + Args: state: current state of the environment device: device used for current batch + Returns: action """ @@ -51,9 +53,11 @@ def __init__( def __call__(self, state: torch.Tensor, device: str) -> List[int]: """ Takes in the current state and returns the action based on the agents policy + Args: state: current state of the environment device: the device used for the current batch + Returns: action defined by policy """ @@ -79,12 +83,14 @@ def get_random_action(self, state: torch.Tensor) -> int: def get_action(self, state: torch.Tensor, device: torch.device): """ - Returns the best action based on the Q values of the network - Args: - state: current state of the environment - device: the device used for the current batch - Returns: - action defined by Q values + Returns the best action based on the Q values of the network + + Args: + state: current state of the environment + device: the device used for the current batch + + Returns: + action defined by Q values """ if not isinstance(state, torch.Tensor): state = torch.tensor(state, device=device) @@ -96,6 +102,7 @@ def get_action(self, state: torch.Tensor, device: torch.device): def update_epsilon(self, step: int) -> None: """ Updates the epsilon value based on the current step + Args: step: current global step """ @@ -109,9 +116,11 @@ class PolicyAgent(Agent): def __call__(self, states: torch.Tensor, device: str) -> List[int]: """ Takes in the current state and returns the action based on the agents policy + Args: states: current state of the environment device: the device used for the current batch + Returns: action defined by policy """ diff --git a/pl_bolts/models/rl/common/memory.py b/pl_bolts/models/rl/common/memory.py index 2bc704ee4d..7ecb5a7f86 100644 --- a/pl_bolts/models/rl/common/memory.py +++ b/pl_bolts/models/rl/common/memory.py @@ -30,6 +30,7 @@ def __len__(self) -> None: def append(self, experience: Experience) -> None: """ Add experience to the buffer + Args: experience: tuple (state, action, reward, done, new_state) """ @@ -65,8 +66,10 @@ class ReplayBuffer(Buffer): def sample(self, batch_size: int) -> Tuple: """ Takes a sample of the buffer + Args: batch_size: current batch_size + Returns: a batch of tuple np arrays of state, action, reward, done, next_state """ @@ -107,6 +110,7 @@ def __init__(self, capacity: int, n_steps: int = 1, gamma: float = 0.99) -> None def append(self, exp: Experience) -> None: """ Add experience to the buffer + Args: exp: tuple (state, action, reward, done, new_state) """ @@ -128,6 +132,7 @@ def update_history_queue(self, exp) -> None: Updates the experience history queue with the lastest experiences. In the event of an experience step is in the done state, the history will be incrementally appended to the queue, removing the tail of the history each time. + Args: env_idx: index of the environment exp: the current experience @@ -161,8 +166,10 @@ def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tup """ Takes in a tuple of experiences and returns the last state and tail experiences based on if the last state is the end of an episode + Args: experiences: Tuple of N Experience + Returns: last state (Array or None) and remaining Experience """ @@ -177,8 +184,10 @@ def split_head_tail_exp(self, experiences: Tuple[Experience]) -> Tuple[List, Tup def discount_rewards(self, experiences: Tuple[Experience]) -> float: """ Calculates the discounted reward over N experiences + Args: experiences: Tuple of Experience + Returns: total discounted reward """ @@ -233,8 +242,10 @@ def __init__(self, buffer_size, prob_alpha=0.6, beta_start=0.4, beta_frames=1000 def update_beta(self, step) -> float: """ Update the beta value which accounts for the bias in the PER + Args: step: current global step + Returns: beta value for this indexed experience """ @@ -246,6 +257,7 @@ def update_beta(self, step) -> float: def append(self, exp) -> None: """ Adds experiences from exp_source to the PER buffer + Args: exp: experience tuple being added to the buffer """ @@ -266,8 +278,10 @@ def append(self, exp) -> None: def sample(self, batch_size=32) -> Tuple: """ Takes a prioritized sample from the buffer + Args: batch_size: size of sample + Returns: sample of experiences chosen with ranked probability """ @@ -308,6 +322,7 @@ def update_priorities(self, batch_indices: List, batch_priorities: List) -> None """ Update the priorities from the last batch, this should be called after the loss for this batch has been calculated. + Args: batch_indices: index of each datum in the batch batch_priorities: priority of each datum in the batch diff --git a/pl_bolts/models/rl/common/networks.py b/pl_bolts/models/rl/common/networks.py index 1f7398055e..9c8a0c7431 100644 --- a/pl_bolts/models/rl/common/networks.py +++ b/pl_bolts/models/rl/common/networks.py @@ -41,6 +41,7 @@ def __init__(self, input_shape, n_actions): def _get_conv_out(self, shape) -> int: """ Calculates the output size of the last conv layer + Args: shape: input dimensions Returns: @@ -52,6 +53,7 @@ def _get_conv_out(self, shape) -> int: def forward(self, input_x) -> Tensor: """ Forward pass through network + Args: x: input to network Returns: @@ -83,8 +85,10 @@ def __init__(self, input_shape: Tuple, n_actions: int, hidden_size: int = 128): def forward(self, input_x): """ Forward pass through network + Args: x: input to network + Returns: output of network """ @@ -123,8 +127,10 @@ def __init__(self, input_shape: Tuple, n_actions: int, hidden_size: int = 128): def forward(self, input_x): """ Forward pass through network. Calculates the Q using the value and advantage + Args: x: input to network + Returns: Q value """ @@ -136,8 +142,10 @@ def adv_val(self, input_x) -> Tuple[Tensor, Tensor]: """ Gets the advantage and value by passing out of the base network through the value and advantage heads + Args: input_x: input to network + Returns: advantage, value """ @@ -184,8 +192,10 @@ def __init__(self, input_shape: Tuple, n_actions: int, _: int = 128): def _get_conv_out(self, shape) -> int: """ Calculates the output size of the last conv layer + Args: shape: input dimensions + Returns: size of the conv output """ @@ -195,8 +205,10 @@ def _get_conv_out(self, shape) -> int: def forward(self, input_x): """ Forward pass through network. Calculates the Q using the value and advantage + Args: input_x: input to network + Returns: Q value """ @@ -208,8 +220,10 @@ def adv_val(self, input_x): """ Gets the advantage and value by passing out of the base network through the value and advantage heads + Args: input_x: input to network + Returns: advantage, value """ @@ -248,8 +262,10 @@ def __init__(self, input_shape, n_actions): def _get_conv_out(self, shape) -> int: """ Calculates the output size of the last conv layer + Args: shape: input dimensions + Returns: size of the conv output """ @@ -259,8 +275,10 @@ def _get_conv_out(self, shape) -> int: def forward(self, input_x) -> Tensor: """ Forward pass through network + Args: x: input to network + Returns: output of network """ @@ -312,8 +330,10 @@ def reset_parameters(self) -> None: def forward(self, input_x: Tensor) -> Tensor: """ Forward pass of the layer + Args: input_x: input tensor + Returns: output of the layer """ diff --git a/pl_bolts/models/rl/double_dqn_model.py b/pl_bolts/models/rl/double_dqn_model.py index f31ae16c6d..f74f670b60 100644 --- a/pl_bolts/models/rl/double_dqn_model.py +++ b/pl_bolts/models/rl/double_dqn_model.py @@ -61,9 +61,11 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD """ Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch recieved + Args: batch: current mini batch of replay data _: batch number, not used + Returns: Training loss and log metrics """ diff --git a/pl_bolts/models/rl/dqn_model.py b/pl_bolts/models/rl/dqn_model.py index 1d6110d2de..24f7089397 100644 --- a/pl_bolts/models/rl/dqn_model.py +++ b/pl_bolts/models/rl/dqn_model.py @@ -161,6 +161,7 @@ def __init__( def run_n_episodes(self, env, n_epsiodes: int = 1, epsilon: float = 1.0) -> List[int]: """ Carries out N episodes of the environment with the current agent + Args: env: environment to use, either train environment or test environment n_epsiodes: number of episodes to run @@ -208,8 +209,10 @@ def build_networks(self) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: """ Passes in a state x through the network and gets the q_values of each action as an output + Args: x: environment state + Returns: q values """ @@ -266,9 +269,11 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD """ Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch recieved + Args: batch: current mini batch of replay data _: batch number, not used + Returns: Training loss and log metrics """ @@ -346,9 +351,11 @@ def test_dataloader(self) -> DataLoader: def make_environment(env_name: str, seed: Optional[int] = None) -> gym.Env: """ Initialise gym environment + Args: env_name: environment name or tag seed: value to seed the environment RNG for reproducibility + Returns: gym environment """ @@ -365,7 +372,9 @@ def add_model_specific_args( ) -> argparse.ArgumentParser: """ Adds arguments for DQN model + Note: these params are fine tuned for Pong env + Args: arg_parser: parent parser """ diff --git a/pl_bolts/models/rl/per_dqn_model.py b/pl_bolts/models/rl/per_dqn_model.py index 9ffbae3cdd..fc8787a67a 100644 --- a/pl_bolts/models/rl/per_dqn_model.py +++ b/pl_bolts/models/rl/per_dqn_model.py @@ -121,9 +121,11 @@ def training_step(self, batch, _) -> OrderedDict: """ Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch recieved + Args: batch: current mini batch of replay data _: batch number, not used + Returns: Training loss and log metrics """ diff --git a/pl_bolts/models/rl/reinforce_model.py b/pl_bolts/models/rl/reinforce_model.py index 78a5fadd43..897050d06a 100644 --- a/pl_bolts/models/rl/reinforce_model.py +++ b/pl_bolts/models/rl/reinforce_model.py @@ -119,8 +119,10 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ Passes in a state x through the network and gets the q_values of each action as an output + Args: x: environment state + Returns: q values """ @@ -129,8 +131,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: def calc_qvals(self, rewards: List[float]) -> List[float]: """Calculate the discounted rewards of all rewards in list + Args: rewards: list of rewards from latest batch + Returns: list of discounted rewards """ @@ -148,8 +152,10 @@ def calc_qvals(self, rewards: List[float]) -> List[float]: def discount_rewards(self, experiences: Tuple[Experience]) -> float: """ Calculates the discounted reward over N experiences + Args: experiences: Tuple of Experience + Returns: total discounted reward """ @@ -217,9 +223,11 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD """ Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch recieved + Args: batch: current mini batch of replay data _: batch number, not used + Returns: Training loss and log metrics """ @@ -265,9 +273,12 @@ def get_device(self, batch) -> str: def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: """ Adds arguments for DQN model + Note: these params are fine tuned for Pong env + Args: arg_parser: the current argument parser to add to + Returns: arg_parser with model specific cargs added """ diff --git a/pl_bolts/models/rl/vanilla_policy_gradient_model.py b/pl_bolts/models/rl/vanilla_policy_gradient_model.py index 89d9df8cd5..c3e1fa73fc 100644 --- a/pl_bolts/models/rl/vanilla_policy_gradient_model.py +++ b/pl_bolts/models/rl/vanilla_policy_gradient_model.py @@ -109,8 +109,10 @@ def __init__( def forward(self, x: torch.Tensor) -> torch.Tensor: """ Passes in a state x through the network and gets the q_values of each action as an output + Args: x: environment state + Returns: q values """ @@ -208,9 +210,11 @@ def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], _) -> OrderedD """ Carries out a single step through the environment to update the replay buffer. Then calculates loss based on the minibatch recieved + Args: batch: current mini batch of replay data _: batch number, not used + Returns: Training loss and log metrics """ @@ -255,9 +259,12 @@ def get_device(self, batch) -> str: def add_model_specific_args(arg_parser) -> argparse.ArgumentParser: """ Adds arguments for DQN model + Note: these params are fine tuned for Pong env + Args: arg_parser: the current argument parser to add to + Returns: arg_parser with model specific cargs added """ diff --git a/pl_bolts/models/self_supervised/resnets.py b/pl_bolts/models/self_supervised/resnets.py index 6cd3e5f683..601307a4d3 100644 --- a/pl_bolts/models/self_supervised/resnets.py +++ b/pl_bolts/models/self_supervised/resnets.py @@ -275,6 +275,7 @@ def _resnet(arch, block, layers, pretrained, progress, **kwargs): def resnet18(pretrained=False, progress=True, **kwargs): r"""ResNet-18 model from `"Deep Residual Learning for Image Recognition" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -286,6 +287,7 @@ def resnet18(pretrained=False, progress=True, **kwargs): def resnet34(pretrained=False, progress=True, **kwargs): r"""ResNet-34 model from `"Deep Residual Learning for Image Recognition" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -297,6 +299,7 @@ def resnet34(pretrained=False, progress=True, **kwargs): def resnet50(pretrained=False, progress=True, **kwargs): r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -308,6 +311,7 @@ def resnet50(pretrained=False, progress=True, **kwargs): def resnet50_bn(pretrained=False, progress=True, **kwargs): r"""ResNet-50 model from `"Deep Residual Learning for Image Recognition" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -319,6 +323,7 @@ def resnet50_bn(pretrained=False, progress=True, **kwargs): def resnet101(pretrained=False, progress=True, **kwargs): r"""ResNet-101 model from `"Deep Residual Learning for Image Recognition" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -330,6 +335,7 @@ def resnet101(pretrained=False, progress=True, **kwargs): def resnet152(pretrained=False, progress=True, **kwargs): r"""ResNet-152 model from `"Deep Residual Learning for Image Recognition" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -341,6 +347,7 @@ def resnet152(pretrained=False, progress=True, **kwargs): def resnext50_32x4d(pretrained=False, progress=True, **kwargs): r"""ResNeXt-50 32x4d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -354,6 +361,7 @@ def resnext50_32x4d(pretrained=False, progress=True, **kwargs): def resnext101_32x8d(pretrained=False, progress=True, **kwargs): r"""ResNeXt-101 32x8d model from `"Aggregated Residual Transformation for Deep Neural Networks" `_ + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -371,6 +379,7 @@ def wide_resnet50_2(pretrained=False, progress=True, **kwargs): which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr @@ -387,6 +396,7 @@ def wide_resnet101_2(pretrained=False, progress=True, **kwargs): which is twice larger in every block. The number of channels in outer 1x1 convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048 channels, and in Wide ResNet-50-2 has 2048-1024-2048. + Args: pretrained (bool): If True, returns a model pre-trained on ImageNet progress (bool): If True, displays a progress bar of the download to stderr