From ff21d8312c6e7f1af6305b7b2482d623773420e5 Mon Sep 17 00:00:00 2001
From: zhiqwang
Date: Mon, 1 Feb 2021 04:28:38 -0500
Subject: [PATCH 1/5] Rename lightning_wrapper to pl_wrapper

---
 models/__init__.py                             | 4 ++--
 models/{lightning_wrapper.py => pl_wrapper.py} | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)
 rename models/{lightning_wrapper.py => pl_wrapper.py} (98%)

diff --git a/models/__init__.py b/models/__init__.py
index 76a64349..b57ffc1c 100644
--- a/models/__init__.py
+++ b/models/__init__.py
@@ -1,9 +1,9 @@
-# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
+# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
 from torch import nn
 
 from .common import Conv
 from utils.activations import Hardswish
-from .lightning_wrapper import YOLOLitWrapper
+from .pl_wrapper import YOLOLitWrapper
 
 
 def yolov5s(**kwargs):
diff --git a/models/lightning_wrapper.py b/models/pl_wrapper.py
similarity index 98%
rename from models/lightning_wrapper.py
rename to models/pl_wrapper.py
index f5f0c542..d07b69f9 100644
--- a/models/lightning_wrapper.py
+++ b/models/pl_wrapper.py
@@ -10,7 +10,7 @@
 from . import yolo
 from .transform import nested_tensor_from_tensor_list
 
-from typing import Tuple, Any, List, Dict, Optional
+from typing import Any, List, Optional
 
 
 class YOLOLitWrapper(pl.LightningModule):
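
Since models/__init__.py re-exports the wrapper, the rename only affects code that
reaches into the module directly; the public import path is unchanged. A minimal
sketch of the two import styles after this patch:

    # Public path, unaffected by the rename
    from models import YOLOLitWrapper

    # Internal path, must now use the new module name
    from models.pl_wrapper import YOLOLitWrapper

    model = YOLOLitWrapper()
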
From 9ec3196c511dcdda98e7a8ea3ba5a34d3ec2ca67 Mon Sep 17 00:00:00 2001
From: zhiqwang
Date: Mon, 1 Feb 2021 12:48:05 -0500
Subject: [PATCH 2/5] Refactor datasets to follow LightningDataModule

---
 datasets/__init__.py   |  54 +------------------
 datasets/coco.py       |  10 ++--
 datasets/pl_wrapper.py | 114 +++++++++++++++++++++++++++++++++++++++++
 datasets/voc.py        |   4 +-
 main.py                |  39 ++------------
 test/test_engine.py    |   1 +
 6 files changed, 126 insertions(+), 96 deletions(-)
 create mode 100644 datasets/pl_wrapper.py

diff --git a/datasets/__init__.py b/datasets/__init__.py
index aaed14d4..6c46b12e 100644
--- a/datasets/__init__.py
+++ b/datasets/__init__.py
@@ -1,53 +1 @@
-# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-import torch.utils.data
-import torchvision
-
-from models.transform import nested_tensor_from_tensor_list
-
-from .coco import build as build_coco
-from .voc import build as build_voc
-
-
-def get_coco_api_from_dataset(dataset):
-    for _ in range(10):
-        # if isinstance(dataset, torchvision.datasets.CocoDetection):
-        #     break
-        if isinstance(dataset, torch.utils.data.Subset):
-            dataset = dataset.dataset
-    if isinstance(dataset, torchvision.datasets.CocoDetection):
-        return dataset.coco
-
-
-def collate_fn(batch):
-    batch = list(zip(*batch))
-    samples = nested_tensor_from_tensor_list(batch[0])
-
-    targets = []
-    for i, target in enumerate(batch[1]):
-        num_objects = len(target['labels'])
-        if num_objects > 0:
-            targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32)
-            targets_merged[:, 1] = target['labels']
-            targets_merged[:, 2:] = target['boxes']
-            targets.append(targets_merged)
-    targets = torch.cat(targets, dim=0)
-
-    return samples, targets
-
-
-def build_dataset(image_set, dataset_year, args):
-
-    datasets = []
-    for year in dataset_year:
-        if args.dataset_file == 'coco':
-            dataset = build_coco(image_set, year, args)
-        elif args.dataset_file == 'voc':
-            dataset = build_voc(image_set, year, args)
-        else:
-            raise ValueError(f'dataset {args.dataset_file} not supported')
-        datasets.append(dataset)
-
-    if len(datasets) == 1:
-        return datasets[0]
-    else:
-        return torch.utils.data.ConcatDataset(datasets)
+# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
diff --git a/datasets/coco.py b/datasets/coco.py
index 99b8cb64..ef831ce5 100644
--- a/datasets/coco.py
+++ b/datasets/coco.py
@@ -82,7 +82,7 @@ def __call__(self, image, target):
 
 
 class CocoDetection(torchvision.datasets.CocoDetection):
-    def __init__(self, img_folder, ann_file, transforms, return_masks):
+    def __init__(self, img_folder, ann_file, transforms, return_masks=False):
         super().__init__(img_folder, ann_file)
         self._transforms = transforms
 
@@ -158,19 +158,17 @@ def _has_valid_annotation(anno):
     return dataset
 
 
-def build(image_set, year, args):
-    root = Path(args.data_path)
+def build(data_path, image_set, year):
+    root = Path(data_path)
     assert root.exists(), f'provided COCO path {root} does not exist'
 
-    mode = args.dataset_mode
     img_folder = Path(root)
-    ann_file = img_folder.joinpath("annotations").joinpath(f"{mode}_{image_set}{year}.json")
+    ann_file = img_folder.joinpath("annotations").joinpath(f"instances_{image_set}{year}.json")
 
     dataset = CocoDetection(
         img_folder,
         ann_file,
         transforms=make_transforms(image_set=image_set),
-        return_masks=args.masks,
     )
 
     if image_set == 'train':
diff --git a/datasets/pl_wrapper.py b/datasets/pl_wrapper.py
new file mode 100644
index 00000000..a8bdf1b5
--- /dev/null
+++ b/datasets/pl_wrapper.py
@@ -0,0 +1,114 @@
+# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
+import torch.utils.data
+from torch.utils.data import DataLoader
+
+from pytorch_lightning import LightningDataModule
+
+from models.transform import nested_tensor_from_tensor_list
+from .coco import build as build_coco
+from .voc import build as build_voc
+
+from typing import List, Any
+
+
+def _collate_fn(batch):
+    batch = list(zip(*batch))
+    samples = nested_tensor_from_tensor_list(batch[0])
+
+    targets = []
+    for i, target in enumerate(batch[1]):
+        num_objects = len(target['labels'])
+        if num_objects > 0:
+            targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32)
+            targets_merged[:, 1] = target['labels']
+            targets_merged[:, 2:] = target['boxes']
+            targets.append(targets_merged)
+    targets = torch.cat(targets, dim=0)
+
+    return samples, targets
+
+
+def build_dataset(data_path, dataset_type, image_set, dataset_year):
+
+    datasets = []
+    for year in dataset_year:
+        if dataset_type == 'coco':
+            dataset = build_coco(data_path, image_set, year)
+        elif dataset_type == 'voc':
+            dataset = build_voc(data_path, image_set, year)
+        else:
+            raise ValueError(f'dataset {dataset_type} not supported')
+        datasets.append(dataset)
+
+    if len(datasets) == 1:
+        return datasets[0]
+    else:
+        return torch.utils.data.ConcatDataset(datasets)
+
+
+class DetectionDataModule(LightningDataModule):
+    """
+    Wrapper of Datasets in LightningDataModule
+    """
+    def __init__(
+        self,
+        data_path: str,
+        dataset_type: str,
+        dataset_year: List[str],
+        num_workers: int = 4,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(*args, **kwargs)
+
+        self.data_path = data_path
+        self.dataset_type = dataset_type
+        self.dataset_year = dataset_year
+        self.num_workers = num_workers
+
+    def train_dataloader(self, batch_size: int = 16) -> DataLoader:
+        """
+        Train dataloader over VOCDetection and CocoDetection.
+
+        Args:
+            batch_size: size of batch
+        """
+        dataset = build_dataset(self.data_path, self.dataset_type, 'train', self.dataset_year)
+
+        # Creating data loaders
+        sampler = torch.utils.data.RandomSampler(dataset)
+        batch_sampler = torch.utils.data.BatchSampler(
+            sampler, batch_size, drop_last=True,
+        )
+
+        loader = DataLoader(
+            dataset,
+            batch_sampler=batch_sampler,
+            collate_fn=_collate_fn,
+            num_workers=self.num_workers,
+        )
+        return loader
+
+
+    def val_dataloader(self, batch_size: int = 16) -> DataLoader:
+        """
+        Validation dataloader over VOCDetection and CocoDetection.
+
+        Args:
+            batch_size: size of batch
+        """
+        dataset = build_dataset(self.data_path, self.dataset_type, 'val', self.dataset_year)
+
+        # Creating data loaders
+        sampler = torch.utils.data.SequentialSampler(dataset)
+
+        loader = DataLoader(
+            dataset,
+            batch_size,
+            sampler=sampler,
+            drop_last=False,
+            collate_fn=_collate_fn,
+            num_workers=self.num_workers,
+        )
+        return loader
+
diff --git a/datasets/voc.py b/datasets/voc.py
index 813ff34b..5fdd50f2 100644
--- a/datasets/voc.py
+++ b/datasets/voc.py
@@ -72,10 +72,10 @@ def __getitem__(self, index):
     return img, target
 
 
-def build(image_set, year, args):
+def build(data_path, image_set, year):
 
     dataset = VOCDetection(
-        img_folder=args.data_path,
+        img_folder=data_path,
         year=year,
         image_set=image_set,
         transforms=make_transforms(image_set=image_set),
diff --git a/main.py b/main.py
index 33bec18e..e1d30917 100644
--- a/main.py
+++ b/main.py
@@ -1,9 +1,7 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 # Modified by Zhiqiang Wang (zhiqwang@foxmail.com)
 
-import datetime
 import argparse
-import time
 from pathlib import Path
 
 import torch
@@ -11,7 +9,7 @@
 
 import pytorch_lightning as pl
 
-from datasets import build_dataset, get_coco_api_from_dataset, collate_fn
+from datasets.pl_wrapper import DetectionDataModule
 from models import YOLOLitWrapper
 
 
@@ -20,7 +18,7 @@ def get_args_parser():
 
     parser.add_argument('--data_path', default='./data-bin',
                         help='dataset')
-    parser.add_argument('--dataset_file', default='coco',
+    parser.add_argument('--dataset_type', default='coco',
                         help='dataset')
     parser.add_argument('--dataset_mode', default='instances',
                         help='dataset mode')
@@ -44,47 +42,18 @@
 
 
 def main(args):
-
-    # Data loading code
-    print('Loading data')
-    dataset_train = build_dataset(args.train_set, args.dataset_year, args)
-    dataset_val = build_dataset(args.val_set, args.dataset_year, args)
-    base_ds = get_coco_api_from_dataset(dataset_val)
-
-    print('Creating data loaders')
-    sampler_train = torch.utils.data.RandomSampler(dataset_train)
-    sampler_val = torch.utils.data.SequentialSampler(dataset_val)
-
-    batch_sampler_train = torch.utils.data.BatchSampler(
-        sampler_train, args.batch_size, drop_last=True,
-    )
-
-    data_loader_train = DataLoader(
-        dataset_train,
-        batch_sampler=batch_sampler_train,
-        collate_fn=collate_fn,
-        num_workers=args.num_workers,
-    )
-    data_loader_val = DataLoader(
-        dataset_val,
-        args.batch_size,
-        sampler=sampler_val,
-        drop_last=False,
-        collate_fn=collate_fn,
-        num_workers=args.num_workers,
-    )
-
     print('Creating model, always set args.return_criterion be True')
     args.return_criterion = True
 
     # Load model
     model = YOLOLitWrapper()
     model.train()
+    datamodule = DetectionDataModule.from_argparse_args(args)
 
     # train
     # trainer = pl.Trainer().from_argparse_args(args)
     trainer = pl.Trainer(max_epochs=1, gpus=1)
-    trainer.fit(model, data_loader_train, data_loader_val)
+    trainer.fit(model, datamodule=datamodule)
 
 
 if __name__ == "__main__":
diff --git a/test/test_engine.py b/test/test_engine.py
index 288e1cf4..3f7d17e7 100644
--- a/test/test_engine.py
+++ b/test/test_engine.py
@@ -4,6 +4,7 @@
 
 from .torch_utils import image_preprocess
 from .dataset_utils import create_loaders, DummyDetectionDataset
+
 from models import YOLOLitWrapper
 
 from typing import Dict
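
With this patch the dataloader wiring lives in DetectionDataModule, so a training run
reduces to building two objects and calling fit. A minimal sketch (the data path and
year are placeholders for a locally prepared dataset):

    import pytorch_lightning as pl

    from datasets.pl_wrapper import DetectionDataModule
    from models import YOLOLitWrapper

    model = YOLOLitWrapper()
    # The datamodule builds the train/val dataloaders internally
    datamodule = DetectionDataModule(
        data_path='./data-bin',     # root containing the images and annotations
        dataset_type='coco',        # or 'voc'
        dataset_year=['2017'],      # several years are concatenated into one dataset
    )

    trainer = pl.Trainer(max_epochs=1, gpus=1)
    trainer.fit(model, datamodule=datamodule)
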
From fe65e3a661df3a0e825c55289a6b4f20b874341d Mon Sep 17 00:00:00 2001
From: zhiqwang
Date: Mon, 1 Feb 2021 12:52:07 -0500
Subject: [PATCH 3/5] Remove unused code

---
 main.py | 6 ------
 1 file changed, 6 deletions(-)

diff --git a/main.py b/main.py
index e1d30917..1f47fe24 100644
--- a/main.py
+++ b/main.py
@@ -1,12 +1,8 @@
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
 # Modified by Zhiqiang Wang (zhiqwang@foxmail.com)
-
 import argparse
 from pathlib import Path
 
-import torch
-from torch.utils.data import DataLoader, DistributedSampler
-
 import pytorch_lightning as pl
 
 from datasets.pl_wrapper import DetectionDataModule
@@ -42,8 +38,6 @@
 
 
 def main(args):
-    print('Creating model, always set args.return_criterion be True')
-    args.return_criterion = True
 
     # Load model
     model = YOLOLitWrapper()
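
The DetectionDataModule.from_argparse_args(args) call that replaces the removed code
relies on pytorch-lightning matching parser attributes to constructor parameters by
name. Roughly, it expands to the hand-written equivalent below (shown for
illustration only):

    datamodule = DetectionDataModule(
        data_path=args.data_path,
        dataset_type=args.dataset_type,
        dataset_year=args.dataset_year,
        num_workers=args.num_workers,
    )
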
From 777b4c50f1fb3fa11c8d778978b286df37b9fb24 Mon Sep 17 00:00:00 2001
From: zhiqwang
Date: Mon, 1 Feb 2021 13:02:02 -0500
Subject: [PATCH 4/5] Fix unit tests

---
 datasets/__init__.py   | 2 ++
 datasets/pl_wrapper.py | 6 +++---
 main.py                | 2 +-
 3 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/datasets/__init__.py b/datasets/__init__.py
index 6c46b12e..ea24b6a8 100644
--- a/datasets/__init__.py
+++ b/datasets/__init__.py
@@ -1 +1,3 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
+
+from .pl_wrapper import collate_fn, DetectionDataModule
diff --git a/datasets/pl_wrapper.py b/datasets/pl_wrapper.py
index a8bdf1b5..6662466a 100644
--- a/datasets/pl_wrapper.py
+++ b/datasets/pl_wrapper.py
@@ -11,7 +11,7 @@
 from typing import List, Any
 
 
-def _collate_fn(batch):
+def collate_fn(batch):
     batch = list(zip(*batch))
     samples = nested_tensor_from_tensor_list(batch[0])
 
@@ -84,7 +84,7 @@ def train_dataloader(self, batch_size: int = 16) -> DataLoader:
         loader = DataLoader(
             dataset,
             batch_sampler=batch_sampler,
-            collate_fn=_collate_fn,
+            collate_fn=collate_fn,
             num_workers=self.num_workers,
         )
         return loader
@@ -107,7 +107,7 @@ def val_dataloader(self, batch_size: int = 16) -> DataLoader:
             batch_size,
             sampler=sampler,
             drop_last=False,
-            collate_fn=_collate_fn,
+            collate_fn=collate_fn,
             num_workers=self.num_workers,
         )
         return loader
diff --git a/main.py b/main.py
index 1f47fe24..bc4b1618 100644
--- a/main.py
+++ b/main.py
@@ -5,7 +5,7 @@
 
 import pytorch_lightning as pl
 
-from datasets.pl_wrapper import DetectionDataModule
+from datasets import DetectionDataModule
 from models import YOLOLitWrapper
 
 
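
The newly exported collate_fn merges the per-image targets of a batch into one
(total_objects, 6) tensor: column 0 is the image index within the batch, column 1 the
class label, and columns 2-5 the four box coordinates. A small shape check with dummy
values (two images with 2 and 1 objects respectively):

    import torch

    from datasets import collate_fn

    batch = [
        (torch.rand(3, 320, 320), {'labels': torch.tensor([1, 2]), 'boxes': torch.rand(2, 4)}),
        (torch.rand(3, 320, 320), {'labels': torch.tensor([3]), 'boxes': torch.rand(1, 4)}),
    ]
    samples, targets = collate_fn(batch)
    # Rows for image 0 start with 0, rows for image 1 start with 1
    assert targets.shape == (3, 6)
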
From b29f94d30dfce1d5d3ecbc53e816c16878e85327 Mon Sep 17 00:00:00 2001
From: zhiqwang
Date: Mon, 1 Feb 2021 13:20:23 -0500
Subject: [PATCH 5/5] Clean up code

---
 datasets/coco.py | 8 ++------
 datasets/voc.py  | 2 +-
 2 files changed, 3 insertions(+), 7 deletions(-)

diff --git a/datasets/coco.py b/datasets/coco.py
index ef831ce5..9f172650 100644
--- a/datasets/coco.py
+++ b/datasets/coco.py
@@ -159,14 +159,10 @@ def _has_valid_annotation(anno):
 
 
 def build(data_path, image_set, year):
-    root = Path(data_path)
-    assert root.exists(), f'provided COCO path {root} does not exist'
-
-    img_folder = Path(root)
-    ann_file = img_folder.joinpath("annotations").joinpath(f"instances_{image_set}{year}.json")
+    ann_file = Path(data_path).joinpath("annotations").joinpath(f"instances_{image_set}{year}.json")
 
     dataset = CocoDetection(
-        img_folder,
+        data_path,
         ann_file,
         transforms=make_transforms(image_set=image_set),
     )
diff --git a/datasets/voc.py b/datasets/voc.py
index 5fdd50f2..5ce6e7ec 100644
--- a/datasets/voc.py
+++ b/datasets/voc.py
@@ -75,7 +75,7 @@ def __getitem__(self, index):
 def build(data_path, image_set, year):
 
     dataset = VOCDetection(
-        img_folder=data_path,
+        data_path,
         year=year,
         image_set=image_set,
         transforms=make_transforms(image_set=image_set),
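
After this cleanup the dataset builders are plain functions of (data_path, image_set,
year). A quick sketch of calling the COCO builder directly (the path is a
placeholder; it must contain the images plus annotations/instances_val2017.json):

    from datasets.coco import build

    # Returns a CocoDetection dataset with the split-appropriate transforms applied
    dataset = build('./data-bin/coco', image_set='val', year='2017')
    print(len(dataset))
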