From d756a41e0bbd28028362e5f218b13f949157e91a Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Mon, 8 Feb 2021 21:01:33 +0800 Subject: [PATCH] Refactor Lightning DetectionDataModule (#48) * Switch on the PyTorchLightning unittest * Rename to detection_datamodule.py * Rename to train.py and predict.py * Add lightning_logs to gitignore * Refactoring VOCDetectionDataModule * Bug fixes in VOCDetectionDataModule * Fix targets batch in collate_fn * Fix device inconsistency problem * Fix unittest of engine * Rearrange VOCDetectionDataModule location * Add CocoDetectionDataModule --- .gitignore | 1 + datasets/__init__.py | 3 +- datasets/coco.py | 17 ----- datasets/pl_datamodule.py | 149 ++++++++++++++++++++++++++++++++++++++ datasets/pl_wrapper.py | 114 ----------------------------- datasets/transforms.py | 59 +++++++-------- datasets/voc.py | 14 ---- models/pl_wrapper.py | 12 +-- models/transform.py | 8 +- detect.py => predict.py | 4 +- test/dataset_utils.py | 29 +------- test/test_engine.py | 18 ++--- main.py => train.py | 28 ++++--- 13 files changed, 218 insertions(+), 238 deletions(-) create mode 100644 datasets/pl_datamodule.py delete mode 100644 datasets/pl_wrapper.py rename detect.py => predict.py (97%) rename main.py => train.py (69%) diff --git a/.gitignore b/.gitignore index 69a79eff..252eaa13 100644 --- a/.gitignore +++ b/.gitignore @@ -2,6 +2,7 @@ data-bin checkpoints logs +lightning_logs *.ipynb runs yolov5s.pt diff --git a/datasets/__init__.py b/datasets/__init__.py index ea24b6a8..ff7c7605 100644 --- a/datasets/__init__.py +++ b/datasets/__init__.py @@ -1,3 +1,2 @@ # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. - -from .pl_wrapper import collate_fn, DetectionDataModule +from .pl_datamodule import DetectionDataModule, VOCDetectionDataModule, CocoDetectionDataModule diff --git a/datasets/coco.py b/datasets/coco.py index 9f172650..243c787a 100644 --- a/datasets/coco.py +++ b/datasets/coco.py @@ -10,8 +10,6 @@ import torchvision from pycocotools import mask as coco_mask -from .transforms import make_transforms - class ConvertCocoPolysToMask(object): def __init__(self, json_category_id_maps, return_masks=False): @@ -156,18 +154,3 @@ def _has_valid_annotation(anno): dataset = torch.utils.data.Subset(dataset, ids) return dataset - - -def build(data_path, image_set, year): - ann_file = Path(data_path).joinpath("annotations").joinpath(f"instances_{image_set}{year}.json") - - dataset = CocoDetection( - data_path, - ann_file, - transforms=make_transforms(image_set=image_set), - ) - - if image_set == 'train': - dataset = _coco_remove_images_without_annotations(dataset) - - return dataset diff --git a/datasets/pl_datamodule.py b/datasets/pl_datamodule.py new file mode 100644 index 00000000..d8f2c000 --- /dev/null +++ b/datasets/pl_datamodule.py @@ -0,0 +1,149 @@ +# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. +from pathlib import Path + +import torch.utils.data +from torch.utils.data import DataLoader +from torch.utils.data.dataset import Dataset + +from pytorch_lightning import LightningDataModule + +from .transforms import collate_fn, default_train_transforms, default_val_transforms +from .voc import VOCDetection +from .coco import CocoDetection + +from typing import Callable, List, Any, Optional + + +class DetectionDataModule(LightningDataModule): + """ + Wrapper of Datasets in LightningDataModule + """ + def __init__( + self, + train_dataset: Optional[Dataset] = None, + val_dataset: Optional[Dataset] = None, + test_dataset: Optional[Dataset] = None, + batch_size: int = 1, + num_workers: int = 0, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__(*args, **kwargs) + + self._train_dataset = train_dataset + self._val_dataset = val_dataset + self._test_dataset = test_dataset + + self.batch_size = batch_size + self.num_workers = num_workers + + def train_dataloader(self, batch_size: int = 16) -> None: + """ + VOCDetection and CocoDetection + Args: + batch_size: size of batch + transforms: custom transforms + """ + # Creating data loaders + sampler = torch.utils.data.RandomSampler(self._train_dataset) + batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True) + + loader = DataLoader( + self._train_dataset, + batch_sampler=batch_sampler, + collate_fn=collate_fn, + num_workers=self.num_workers, + ) + + return loader + + def val_dataloader(self, batch_size: int = 16) -> None: + """ + VOCDetection and CocoDetection + Args: + batch_size: size of batch + transforms: custom transforms + """ + # Creating data loaders + sampler = torch.utils.data.SequentialSampler(self._val_dataset) + + loader = DataLoader( + self._val_dataset, + batch_size, + sampler=sampler, + drop_last=False, + collate_fn=collate_fn, + num_workers=self.num_workers, + ) + + return loader + + +class VOCDetectionDataModule(DetectionDataModule): + def __init__( + self, + data_path: str, + years: List[str] = ["2007", "2012"], + train_transform: Optional[Callable] = default_train_transforms, + val_transform: Optional[Callable] = default_val_transforms, + batch_size: int = 1, + num_workers: int = 0, + *args: Any, + **kwargs: Any, + ) -> None: + train_dataset, num_classes = self.build_datasets( + data_path, image_set='train', years=years, transforms=train_transform) + val_dataset, _ = self.build_datasets( + data_path, image_set='val', years=years, transforms=val_transform) + + super().__init__(train_dataset=train_dataset, val_dataset=val_dataset, + batch_size=batch_size, num_workers=num_workers, *args, **kwargs) + + self.num_classes = num_classes + + @staticmethod + def build_datasets(data_path, image_set, years, transforms): + datasets = [] + for year in years: + dataset = VOCDetection( + data_path, + year=year, + image_set=image_set, + transforms=transforms(), + ) + datasets.append(dataset) + + num_classes = len(datasets[0].prepare.CLASSES) + + if len(datasets) == 1: + return datasets[0], num_classes + else: + return torch.utils.data.ConcatDataset(datasets), num_classes + + +class CocoDetectionDataModule(DetectionDataModule): + def __init__( + self, + data_path: str, + year: str = "2017", + train_transform: Optional[Callable] = default_train_transforms, + val_transform: Optional[Callable] = default_val_transforms, + batch_size: int = 1, + num_workers: int = 0, + *args: Any, + **kwargs: Any, + ) -> None: + train_dataset = self.build_datasets( + data_path, image_set='train', year=year, transforms=train_transform) + val_dataset = self.build_datasets( + data_path, image_set='val', year=year, transforms=val_transform) + + super().__init__(train_dataset=train_dataset, val_dataset=val_dataset, + batch_size=batch_size, num_workers=num_workers, *args, **kwargs) + + self.num_classes = 80 + + @staticmethod + def build_datasets(data_path, image_set, year, transforms): + ann_file = Path(data_path).joinpath('annotations').joinpath(f"instances_{image_set}{year}.json") + return CocoDetection(data_path, ann_file, transforms()) diff --git a/datasets/pl_wrapper.py b/datasets/pl_wrapper.py deleted file mode 100644 index 6662466a..00000000 --- a/datasets/pl_wrapper.py +++ /dev/null @@ -1,114 +0,0 @@ -# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved. -import torch.utils.data -from torch.utils.data import DataLoader - -from pytorch_lightning import LightningDataModule - -from models.transform import nested_tensor_from_tensor_list -from .coco import build as build_coco -from .voc import build as build_voc - -from typing import List, Any - - -def collate_fn(batch): - batch = list(zip(*batch)) - samples = nested_tensor_from_tensor_list(batch[0]) - - targets = [] - for i, target in enumerate(batch[1]): - num_objects = len(target['labels']) - if num_objects > 0: - targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32) - targets_merged[:, 1] = target['labels'] - targets_merged[:, 2:] = target['boxes'] - targets.append(targets_merged) - targets = torch.cat(targets, dim=0) - - return samples, targets - - -def build_dataset(data_path, dataset_type, image_set, dataset_year): - - datasets = [] - for year in dataset_year: - if dataset_type == 'coco': - dataset = build_coco(data_path, image_set, year) - elif dataset_type == 'voc': - dataset = build_voc(data_path, image_set, year) - else: - raise ValueError(f'dataset {dataset_type} not supported') - datasets.append(dataset) - - if len(datasets) == 1: - return datasets[0] - else: - return torch.utils.data.ConcatDataset(datasets) - - -class DetectionDataModule(LightningDataModule): - """ - Wrapper of Datasets in LightningDataModule - """ - def __init__( - self, - data_path: str, - dataset_type: str, - dataset_year: List[str], - num_workers: int = 4, - *args: Any, - **kwargs: Any, - ) -> None: - super().__init__(*args, **kwargs) - - self.data_path = data_path - self.dataset_type = dataset_type - self.dataset_year = dataset_year - self.num_workers = num_workers - - def train_dataloader(self, batch_size: int = 16) -> None: - """ - VOCDetection and CocoDetection - Args: - batch_size: size of batch - transforms: custom transforms - """ - dataset = build_dataset(self.data_path, self.dataset_type, 'train', self.dataset_year) - - # Creating data loaders - sampler = torch.utils.data.RandomSampler(dataset) - batch_sampler = torch.utils.data.BatchSampler( - sampler, batch_size, drop_last=True, - ) - - loader = DataLoader( - dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - num_workers=self.num_workers, - ) - - return loader - - def val_dataloader(self, batch_size: int = 16) -> None: - """ - VOCDetection and CocoDetection - Args: - batch_size: size of batch - transforms: custom transforms - """ - dataset = build_dataset(self.data_path, self.dataset_type, 'val', self.dataset_year) - - # Creating data loaders - sampler = torch.utils.data.SequentialSampler(dataset) - - loader = DataLoader( - dataset, - batch_size, - sampler=sampler, - drop_last=False, - collate_fn=collate_fn, - num_workers=self.num_workers, - ) - - return loader diff --git a/datasets/transforms.py b/datasets/transforms.py index 956a2b5d..fe4582c3 100644 --- a/datasets/transforms.py +++ b/datasets/transforms.py @@ -12,6 +12,32 @@ from torchvision.ops.boxes import box_convert +def collate_fn(batch): + return tuple(zip(*batch)) + + +def default_train_transforms(): + scales = [384, 416, 448, 480, 512, 544, 576, 608, 640, 672] + scales_for_training = [(640, 640)] + + return Compose([ + RandomHorizontalFlip(), + RandomSelect( + RandomResize(scales_for_training), + Compose([ + RandomResize(scales), + RandomSizeCrop(384, 480), + RandomResize(scales_for_training), + ]) + ), + Compose([ToTensor(), Normalize()]), + ]) + + +def default_val_transforms(): + return Compose([ToTensor(), Normalize()]) + + def crop(image, target, region): cropped_image = F.crop(image, *region) @@ -271,36 +297,3 @@ def __repr__(self): format_string += " {0}".format(t) format_string += "\n)" return format_string - - -def make_transforms(image_set='train'): - - normalize = Compose([ - ToTensor(), - Normalize(), - ]) - - scales = [384, 416, 448, 480, 512, 544, 576, 608, 640, 672] - scales_for_training = [(640, 640)] - - if image_set == 'train' or image_set == 'trainval': - return Compose([ - RandomHorizontalFlip(), - RandomSelect( - RandomResize(scales_for_training), - Compose([ - RandomResize(scales), - RandomSizeCrop(384, 480), - RandomResize(scales_for_training), - ]) - ), - normalize, - ]) - - if image_set == 'val' or image_set == 'test': - return Compose([ - RandomResize([800], max_size=1333), - normalize, - ]) - - raise ValueError(f'unknown {image_set}') diff --git a/datasets/voc.py b/datasets/voc.py index 5ce6e7ec..bb9db6d2 100644 --- a/datasets/voc.py +++ b/datasets/voc.py @@ -1,8 +1,6 @@ import torch import torchvision -from .transforms import make_transforms - class ConvertVOCtoCOCO(object): @@ -70,15 +68,3 @@ def __getitem__(self, index): img, target = self._transforms(img, target) return img, target - - -def build(data_path, image_set, year): - - dataset = VOCDetection( - data_path, - year=year, - image_set=image_set, - transforms=make_transforms(image_set=image_set), - ) - - return dataset diff --git a/models/pl_wrapper.py b/models/pl_wrapper.py index bfd16a3b..9aff2cf2 100644 --- a/models/pl_wrapper.py +++ b/models/pl_wrapper.py @@ -7,14 +7,14 @@ import pytorch_lightning as pl from . import yolo -from .transform import GeneralizedYOLOTransform, nested_tensor_from_tensor_list +from .transform import GeneralizedYOLOTransform from typing import Any, List, Dict, Tuple, Optional class YOLOLitWrapper(pl.LightningModule): """ - PyTorch Lightning implementation of `YOLO` + PyTorch Lightning wrapper of `YOLO` """ def __init__( self, @@ -78,9 +78,11 @@ def forward( return detections def training_step(self, batch, batch_idx): - - samples, targets = batch - + """ + The training step. + """ + # Transform the input + samples, targets = self.transform(*batch) # yolov5 takes both images and targets for training, returns loss_dict = self.model(samples.tensors, targets) loss = sum(loss for loss in loss_dict.values()) diff --git a/models/transform.py b/models/transform.py index c9ad13f7..47c463a4 100644 --- a/models/transform.py +++ b/models/transform.py @@ -2,7 +2,7 @@ # Modified by Zhiqiang Wang (zhiqwang@outlook.com) import math import torch -from torch import nn, Tensor +from torch import device, nn, Tensor import torch.nn.functional as F import torchvision @@ -57,7 +57,7 @@ def forward( images: List[Tensor], targets: Optional[List[Dict[str, Tensor]]], ) -> Tuple[NestedTensor, Optional[Tensor]]: - + device = images[0].device images = [img for img in images] if targets is not None: # make a copy of targets to avoid modifying it in-place @@ -68,7 +68,7 @@ def forward( for t in targets: data: Dict[str, Tensor] = {} for k, v in t.items(): - data[k] = v + data[k] = v.to(device) targets_copy.append(data) targets = targets_copy @@ -99,7 +99,7 @@ def forward( for i, target in enumerate(targets): num_objects = len(target['labels']) if num_objects > 0: - targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32) + targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32, device=device) targets_merged[:, 1] = target['labels'] targets_merged[:, 2:] = target['boxes'] targets_batched.append(targets_merged) diff --git a/detect.py b/predict.py similarity index 97% rename from detect.py rename to predict.py index cbad208a..f20fa67d 100644 --- a/detect.py +++ b/predict.py @@ -7,7 +7,7 @@ import torch from utils.image_utils import read_image, load_names, overlay_boxes -from hubconf import yolov5_darknet_pan_s_v31 +from hubconf import yolov5s @torch.no_grad() @@ -27,7 +27,7 @@ def main(args): print(args) device = torch.device("cuda") if torch.cuda.is_available() and args.gpu else torch.device("cpu") - model = yolov5_darknet_pan_s_v31( + model = yolov5s( pretrained=True, min_size=args.min_size, max_size=args.max_size, diff --git a/test/dataset_utils.py b/test/dataset_utils.py index 9b0ebeb7..45a22bbc 100644 --- a/test/dataset_utils.py +++ b/test/dataset_utils.py @@ -1,15 +1,14 @@ import random import torch -from torch.utils.data import Dataset, DataLoader +from torch.utils.data import Dataset from torchvision import ops -from datasets import collate_fn -__all__ = ["create_loaders", "DummyDetectionDataset"] +__all__ = ["DummyDetectionDataset"] -class DummyDetectionDataset(Dataset): +class DummyCOCODetectionDataset(Dataset): """ Generate a dummy dataset for detection Example:: @@ -76,25 +75,3 @@ def __getitem__(self, idx: int): boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) image_id = torch.tensor([idx]) return img, {"image_id": image_id, "boxes": boxes, "labels": labels} - - -def create_loaders(dataset, batch_size=16, num_workers=0): - """ - Creates train loader and test loader from train and test dataset - Args: - dataset: Torchvision dataset. - batch_size (int) : Default 16, Batch size - num_workers (int) : Defualt 0, Number of workers for training and validation. - """ - sampler = torch.utils.data.RandomSampler(dataset) - - batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True) - - dataloader = DataLoader( - dataset, - batch_sampler=batch_sampler, - collate_fn=collate_fn, - num_workers=num_workers, - ) - - return dataloader diff --git a/test/test_engine.py b/test/test_engine.py index 0d3e21bb..542a40d0 100644 --- a/test/test_engine.py +++ b/test/test_engine.py @@ -2,12 +2,13 @@ import torch import pytorch_lightning as pl -from .torch_utils import image_preprocess -from .dataset_utils import create_loaders, DummyDetectionDataset - from models import YOLOLitWrapper from models.yolo import yolov5_darknet_pan_s_r31 from models.transform import nested_tensor_from_tensor_list +from datasets import DetectionDataModule + +from .torch_utils import image_preprocess +from .dataset_utils import DummyCOCODetectionDataset from typing import Dict @@ -36,20 +37,17 @@ def test_train(self): self.assertIsInstance(out["bbox_regression"], torch.Tensor) self.assertIsInstance(out["objectness"], torch.Tensor) - @unittest.skip("Current it isn't well implemented") def test_train_one_step(self): # Load model model = YOLOLitWrapper() model.train() - # Datasets - datasets = DummyDetectionDataset(num_samples=200) - data_loader_train = create_loaders(datasets) - data_loader_val = create_loaders(datasets) - + # Setup the DataModule + train_dataset = DummyCOCODetectionDataset(num_samples=128) + datamodule = DetectionDataModule(train_dataset, batch_size=16) # Trainer trainer = pl.Trainer(max_epochs=1) - trainer.fit(model, data_loader_train, data_loader_val) + trainer.fit(model, datamodule) def test_inference(self): # Infer over an image diff --git a/main.py b/train.py similarity index 69% rename from main.py rename to train.py index bc4b1618..f478c015 100644 --- a/main.py +++ b/train.py @@ -5,20 +5,22 @@ import pytorch_lightning as pl -from datasets import DetectionDataModule -from models import YOLOLitWrapper +from datasets import VOCDetectionDataModule +import models def get_args_parser(): parser = argparse.ArgumentParser('You only look once detector', add_help=False) + parser.add_argument('--arch', default='yolov5s', + help='model structure to train') parser.add_argument('--data_path', default='./data-bin', help='dataset') parser.add_argument('--dataset_type', default='coco', help='dataset') parser.add_argument('--dataset_mode', default='instances', help='dataset mode') - parser.add_argument('--dataset_year', default=['2017'], nargs='+', + parser.add_argument('--years', default=['2017'], nargs='+', help='dataset year') parser.add_argument('--train_set', default='train', help='set of train') @@ -26,8 +28,10 @@ def get_args_parser(): help='set of val') parser.add_argument('--batch_size', default=32, type=int, help='images per gpu, the total batch size is $NGPU x batch_size') - parser.add_argument('--epochs', default=26, type=int, metavar='N', + parser.add_argument('--max_epochs', default=1, type=int, metavar='N', help='number of total epochs to run') + parser.add_argument('--num_gpus', default=1, type=int, metavar='N', + help='number of gpu utilizing (default: 1)') parser.add_argument('--num_workers', default=4, type=int, metavar='N', help='number of data loading workers (default: 4)') parser.add_argument('--print_freq', default=20, type=int, @@ -38,15 +42,16 @@ def get_args_parser(): def main(args): + # Load the data + datamodule = VOCDetectionDataModule.from_argparse_args(args) - # Load model - model = YOLOLitWrapper() - model.train() - datamodule = DetectionDataModule.from_argparse_args(args) + # Build the model + model = models.__dict__[args.arch](num_classes=datamodule.num_classes) - # train - # trainer = pl.Trainer().from_argparse_args(args) - trainer = pl.Trainer(max_epochs=1, gpus=1) + # Create the trainer. Run twice on data + trainer = pl.Trainer(max_epochs=args.max_epochs, gpus=args.num_gpus) + + # Train the model trainer.fit(model, datamodule=datamodule) @@ -55,4 +60,5 @@ def main(args): args = parser.parse_args() if args.output_dir: Path(args.output_dir).mkdir(parents=True, exist_ok=True) + main(args)