Refactor Lightning DetectionDataModule (#48)
* Enable the PyTorch Lightning unit test

* Rename to detection_datamodule.py

* Rename to train.py and predict.py

* Add lightning_logs to gitignore

* Refactor VOCDetectionDataModule

* Bug fixes in VOCDetectionDataModule

* Fix targets batch in collate_fn

* Fix device inconsistency problem

* Fix unittest of engine

* Rearrange VOCDetectionDataModule location

* Add CocoDetectionDataModule
zhiqwang authored Feb 8, 2021
1 parent 38bdf06 commit d756a41
Showing 13 changed files with 218 additions and 238 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -2,6 +2,7 @@
data-bin
checkpoints
logs
lightning_logs
*.ipynb
runs
yolov5s.pt
3 changes: 1 addition & 2 deletions datasets/__init__.py
@@ -1,3 +1,2 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.

from .pl_wrapper import collate_fn, DetectionDataModule
from .pl_datamodule import DetectionDataModule, VOCDetectionDataModule, CocoDetectionDataModule
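With the rename, downstream imports change: the package root now re-exports the three data modules, and collate_fn moves to datasets.transforms (from which pl_datamodule imports it, as the new file further down shows). A sketch of the updated imports:

from datasets import DetectionDataModule, VOCDetectionDataModule, CocoDetectionDataModule
from datasets.transforms import collate_fn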
17 changes: 0 additions & 17 deletions datasets/coco.py
@@ -10,8 +10,6 @@
import torchvision
from pycocotools import mask as coco_mask

from .transforms import make_transforms


class ConvertCocoPolysToMask(object):
def __init__(self, json_category_id_maps, return_masks=False):
@@ -156,18 +154,3 @@ def _has_valid_annotation(anno):

dataset = torch.utils.data.Subset(dataset, ids)
return dataset


def build(data_path, image_set, year):
ann_file = Path(data_path).joinpath("annotations").joinpath(f"instances_{image_set}{year}.json")

dataset = CocoDetection(
data_path,
ann_file,
transforms=make_transforms(image_set=image_set),
)

if image_set == 'train':
dataset = _coco_remove_images_without_annotations(dataset)

return dataset
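The deleted build() helper is superseded by CocoDetectionDataModule.build_datasets, added in datasets/pl_datamodule.py below. One behavioral difference worth noting: the new helper does not call _coco_remove_images_without_annotations on the train split. A rough sketch of the replacement call, with a placeholder data path:

from datasets.pl_datamodule import CocoDetectionDataModule
from datasets.transforms import default_train_transforms

# Placeholder path; build_datasets resolves
# annotations/instances_train2017.json under it.
dataset = CocoDetectionDataModule.build_datasets(
    'data-bin/coco', image_set='train', year='2017',
    transforms=default_train_transforms,
)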
149 changes: 149 additions & 0 deletions datasets/pl_datamodule.py
@@ -0,0 +1,149 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from pathlib import Path

import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

from pytorch_lightning import LightningDataModule

from .transforms import collate_fn, default_train_transforms, default_val_transforms
from .voc import VOCDetection
from .coco import CocoDetection

from typing import Callable, List, Any, Optional


class DetectionDataModule(LightningDataModule):
"""
Wraps detection datasets in a LightningDataModule.
"""
def __init__(
self,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
batch_size: int = 1,
num_workers: int = 0,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)

self._train_dataset = train_dataset
self._val_dataset = val_dataset
self._test_dataset = test_dataset

self.batch_size = batch_size
self.num_workers = num_workers

    def train_dataloader(self, batch_size: int = 16) -> DataLoader:
        """
        Build the training DataLoader for VOCDetection and CocoDetection.

        Args:
            batch_size: size of each training batch
        """
# Creating data loaders
sampler = torch.utils.data.RandomSampler(self._train_dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)

loader = DataLoader(
self._train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=self.num_workers,
)

return loader

    def val_dataloader(self, batch_size: int = 16) -> DataLoader:
        """
        Build the validation DataLoader for VOCDetection and CocoDetection.

        Args:
            batch_size: size of each validation batch
        """
# Creating data loaders
sampler = torch.utils.data.SequentialSampler(self._val_dataset)

loader = DataLoader(
self._val_dataset,
batch_size,
sampler=sampler,
drop_last=False,
collate_fn=collate_fn,
num_workers=self.num_workers,
)

return loader


class VOCDetectionDataModule(DetectionDataModule):
def __init__(
self,
data_path: str,
years: List[str] = ["2007", "2012"],
train_transform: Optional[Callable] = default_train_transforms,
val_transform: Optional[Callable] = default_val_transforms,
batch_size: int = 1,
num_workers: int = 0,
*args: Any,
**kwargs: Any,
) -> None:
train_dataset, num_classes = self.build_datasets(
data_path, image_set='train', years=years, transforms=train_transform)
val_dataset, _ = self.build_datasets(
data_path, image_set='val', years=years, transforms=val_transform)

super().__init__(train_dataset=train_dataset, val_dataset=val_dataset,
batch_size=batch_size, num_workers=num_workers, *args, **kwargs)

self.num_classes = num_classes

@staticmethod
def build_datasets(data_path, image_set, years, transforms):
datasets = []
for year in years:
dataset = VOCDetection(
data_path,
year=year,
image_set=image_set,
transforms=transforms(),
)
datasets.append(dataset)

num_classes = len(datasets[0].prepare.CLASSES)

if len(datasets) == 1:
return datasets[0], num_classes
else:
return torch.utils.data.ConcatDataset(datasets), num_classes


class CocoDetectionDataModule(DetectionDataModule):
def __init__(
self,
data_path: str,
year: str = "2017",
train_transform: Optional[Callable] = default_train_transforms,
val_transform: Optional[Callable] = default_val_transforms,
batch_size: int = 1,
num_workers: int = 0,
*args: Any,
**kwargs: Any,
) -> None:
train_dataset = self.build_datasets(
data_path, image_set='train', year=year, transforms=train_transform)
val_dataset = self.build_datasets(
data_path, image_set='val', year=year, transforms=val_transform)

super().__init__(train_dataset=train_dataset, val_dataset=val_dataset,
batch_size=batch_size, num_workers=num_workers, *args, **kwargs)

self.num_classes = 80

@staticmethod
def build_datasets(data_path, image_set, year, transforms):
ann_file = Path(data_path).joinpath('annotations').joinpath(f"instances_{image_set}{year}.json")
return CocoDetection(data_path, ann_file, transforms())
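A minimal usage sketch of the refactored data modules with a Lightning Trainer; the data path is a placeholder and the model is assumed to be defined elsewhere in the repo, neither is part of this diff:

from pytorch_lightning import Trainer
from datasets.pl_datamodule import VOCDetectionDataModule

datamodule = VOCDetectionDataModule(
    'data-bin/voc',  # placeholder data path
    years=['2007'],
    batch_size=16,
    num_workers=4,
)
print(datamodule.num_classes)  # derived from the VOC label set in build_datasets

trainer = Trainer(max_epochs=1)
# trainer.fit(model, datamodule=datamodule)  # model defined elsewhere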
114 changes: 0 additions & 114 deletions datasets/pl_wrapper.py

This file was deleted.

59 changes: 26 additions & 33 deletions datasets/transforms.py
@@ -12,6 +12,32 @@
from torchvision.ops.boxes import box_convert


def collate_fn(batch):
return tuple(zip(*batch))


def default_train_transforms():
scales = [384, 416, 448, 480, 512, 544, 576, 608, 640, 672]
scales_for_training = [(640, 640)]

return Compose([
RandomHorizontalFlip(),
RandomSelect(
RandomResize(scales_for_training),
Compose([
RandomResize(scales),
RandomSizeCrop(384, 480),
RandomResize(scales_for_training),
])
),
Compose([ToTensor(), Normalize()]),
])


def default_val_transforms():
return Compose([ToTensor(), Normalize()])


def crop(image, target, region):
cropped_image = F.crop(image, *region)

@@ -271,36 +297,3 @@ def __repr__(self):
format_string += " {0}".format(t)
format_string += "\n)"
return format_string


def make_transforms(image_set='train'):

normalize = Compose([
ToTensor(),
Normalize(),
])

scales = [384, 416, 448, 480, 512, 544, 576, 608, 640, 672]
scales_for_training = [(640, 640)]

if image_set == 'train' or image_set == 'trainval':
return Compose([
RandomHorizontalFlip(),
RandomSelect(
RandomResize(scales_for_training),
Compose([
RandomResize(scales),
RandomSizeCrop(384, 480),
RandomResize(scales_for_training),
])
),
normalize,
])

if image_set == 'val' or image_set == 'test':
return Compose([
RandomResize([800], max_size=1333),
normalize,
])

raise ValueError(f'unknown {image_set}')
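The new collate_fn keeps variable-sized detection samples as tuples instead of stacking them into one tensor, which the default collate would attempt and fail at for images of different shapes. A quick illustration with dummy tensors:

import torch
from datasets.transforms import collate_fn

# Two samples with different image sizes and box counts.
batch = [
    (torch.rand(3, 480, 640), {'boxes': torch.rand(2, 4)}),
    (torch.rand(3, 640, 640), {'boxes': torch.rand(5, 4)}),
]
images, targets = collate_fn(batch)
# images: tuple of 2 tensors; targets: tuple of 2 dicts.
# Per-image shapes are preserved for the model to resize or pad later.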