Skip to content

Commit

Permalink
Refactoring VOCDetectionDataModule
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Feb 7, 2021
1 parent 70ae932 commit 18fa51f
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 96 deletions.
2 changes: 1 addition & 1 deletion datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.

from .detection_datamodule import collate_fn, DetectionDataModule
from .voc import VOCDetectionDataModule
72 changes: 37 additions & 35 deletions datasets/detection_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import torch.utils.data
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset
import torchvision.transforms as T

from pytorch_lightning import LightningDataModule

from . import transforms as FT
from models.transform import nested_tensor_from_tensor_list
from .coco import build as build_coco
from .voc import build as build_voc

from typing import List, Any
from typing import List, Any, Optional


def collate_fn(batch):
Expand All @@ -28,22 +29,26 @@ def collate_fn(batch):
return samples, targets


def build_dataset(data_path, dataset_type, image_set, dataset_year):
def default_train_transforms():
scales = [384, 416, 448, 480, 512, 544, 576, 608, 640, 672]
scales_for_training = [(640, 640)]

datasets = []
for year in dataset_year:
if dataset_type == 'coco':
dataset = build_coco(data_path, image_set, year)
elif dataset_type == 'voc':
dataset = build_voc(data_path, image_set, year)
else:
raise ValueError(f'dataset {dataset_type} not supported')
datasets.append(dataset)
return T.Compose([
FT.RandomHorizontalFlip(),
FT.RandomSelect(
FT.RandomResize(scales_for_training),
T.Compose([
FT.RandomResize(scales),
FT.RandomSizeCrop(384, 480),
FT.RandomResize(scales_for_training),
])
),
FT.Compose([FT.ToTensor(), FT.Normalize()]),
])

if len(datasets) == 1:
return datasets[0]
else:
return torch.utils.data.ConcatDataset(datasets)

def default_val_transforms():
    """Default evaluation-time transforms: convert to tensor and normalize.

    Returns:
        An ``FT.Compose`` operating on ``(image, target)`` pairs.

    Fix: the dataset invokes transforms as ``transforms(img, target)`` (see
    ``VOCDetection.__getitem__``), but ``torchvision.transforms.Compose``
    accepts only a single argument, so wrapping the two-argument
    ``FT.Compose`` in ``T.Compose`` would raise ``TypeError`` at call time.
    Return the (image, target)-aware ``FT.Compose`` directly.
    """
    return FT.Compose([FT.ToTensor(), FT.Normalize()])


class DetectionDataModule(LightningDataModule):
Expand All @@ -52,18 +57,21 @@ class DetectionDataModule(LightningDataModule):
"""
def __init__(
self,
data_path: str,
dataset_type: str,
dataset_year: List[str],
num_workers: int = 4,
train_dataset: Optional[Dataset] = None,
val_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None,
batch_size: int = 1,
num_workers: int = 0,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)

self.data_path = data_path
self.dataset_type = dataset_type
self.dataset_year = dataset_year
self._train_dataset = train_dataset
self._val_dataset = val_dataset
self._test_dataset = test_dataset

self.batch_size = batch_size
self.num_workers = num_workers

def train_dataloader(self, batch_size: int = 16) -> None:
Expand All @@ -73,16 +81,12 @@ def train_dataloader(self, batch_size: int = 16) -> None:
batch_size: size of batch
transforms: custom transforms
"""
dataset = build_dataset(self.data_path, self.dataset_type, 'train', self.dataset_year)

# Creating data loaders
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(
sampler, batch_size, drop_last=True,
)
sampler = torch.utils.data.RandomSampler(self._train_dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)

loader = DataLoader(
dataset,
self._train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=self.num_workers,
Expand All @@ -97,13 +101,11 @@ def val_dataloader(self, batch_size: int = 16) -> None:
batch_size: size of batch
transforms: custom transforms
"""
dataset = build_dataset(self.data_path, self.dataset_type, 'val', self.dataset_year)

# Creating data loaders
sampler = torch.utils.data.SequentialSampler(dataset)
sampler = torch.utils.data.SequentialSampler(self._val_dataset)

loader = DataLoader(
dataset,
self._val_dataset,
batch_size,
sampler=sampler,
drop_last=False,
Expand Down
33 changes: 0 additions & 33 deletions datasets/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,36 +271,3 @@ def __repr__(self):
format_string += " {0}".format(t)
format_string += "\n)"
return format_string


def make_transforms(image_set='train'):
    """Build the transform pipeline for a dataset split.

    Args:
        image_set: one of ``'train'``, ``'trainval'``, ``'val'`` or ``'test'``.

    Returns:
        A ``Compose`` transform operating on ``(image, target)`` pairs.

    Raises:
        ValueError: if ``image_set`` is not a recognized split name.
    """
    normalize = Compose([
        ToTensor(),
        Normalize(),
    ])

    # Multi-scale resize candidates used only for training-time augmentation.
    scales = [384, 416, 448, 480, 512, 544, 576, 608, 640, 672]
    scales_for_training = [(640, 640)]

    if image_set in ('train', 'trainval'):
        return Compose([
            RandomHorizontalFlip(),
            # Either resize straight to the training size, or go through an
            # intermediate random scale + crop first (random 50/50 choice).
            RandomSelect(
                RandomResize(scales_for_training),
                Compose([
                    RandomResize(scales),
                    RandomSizeCrop(384, 480),
                    RandomResize(scales_for_training),
                ])
            ),
            normalize,
        ])

    if image_set in ('val', 'test'):
        return Compose([
            RandomResize([800], max_size=1333),
            normalize,
        ])

    raise ValueError(f'unknown {image_set}')
62 changes: 49 additions & 13 deletions datasets/voc.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,55 @@
import torch
import torchvision

from .transforms import make_transforms
from .detection_datamodule import (
DetectionDataModule,
default_train_transforms,
default_val_transforms,
)

from typing import Callable, List, Any, Optional


class VOCDetectionDataModule(DetectionDataModule):
    """LightningDataModule wrapping Pascal VOC detection datasets.

    Builds train/val datasets for one or more VOC years at construction
    time and derives ``num_classes`` from the dataset class list.
    """
    def __init__(
        self,
        data_path: str,
        years: Optional[List[str]] = None,
        train_transform: Optional[Callable] = default_train_transforms,
        val_transform: Optional[Callable] = default_val_transforms,
        batch_size: int = 1,
        num_workers: int = 0,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        # Fix: avoid a mutable default argument; default to both VOC years.
        if years is None:
            years = ["2007", "2012"]

        train_dataset, num_classes = self.build_datasets(
            data_path, image_set='train', years=years, transforms=train_transform)
        val_dataset, _ = self.build_datasets(
            data_path, image_set='val', years=years, transforms=val_transform)

        super().__init__(train_dataset=train_dataset, val_dataset=val_dataset,
                         batch_size=batch_size, num_workers=num_workers, *args, **kwargs)

        # Number of object categories, taken from the first year's dataset.
        self.num_classes = num_classes

    @staticmethod
    def build_datasets(data_path, image_set, years, transforms):
        """Build one ``VOCDetection`` per year, concatenating when needed.

        Returns:
            A ``(dataset, num_classes)`` tuple; ``dataset`` is a single
            ``VOCDetection`` or a ``ConcatDataset`` over all years.

        NOTE(review): ``transforms`` is forwarded to ``VOCDetection``
        uncalled — the class default is the factory *function* (e.g.
        ``default_train_transforms``), not the transform it builds.
        Confirm whether ``transforms()`` should be invoked here.
        """
        datasets = []
        for year in years:
            dataset = VOCDetection(
                data_path,
                year=year,
                image_set=image_set,
                transforms=transforms,
            )
            datasets.append(dataset)

        # All years share the same class list, so read it from the first.
        num_classes = len(datasets[0].prepare.CLASSES)

        if len(datasets) == 1:
            return datasets[0], num_classes
        else:
            return torch.utils.data.ConcatDataset(datasets), num_classes


class ConvertVOCtoCOCO(object):
Expand Down Expand Up @@ -70,15 +118,3 @@ def __getitem__(self, index):
img, target = self._transforms(img, target)

return img, target


def build(data_path, image_set, year):
    """Create a VOC detection dataset for a single split/year pair."""
    split_transforms = make_transforms(image_set=image_set)
    return VOCDetection(
        data_path,
        year=year,
        image_set=image_set,
        transforms=split_transforms,
    )
8 changes: 5 additions & 3 deletions models/pl_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@
import pytorch_lightning as pl

from . import yolo
from .transform import GeneralizedYOLOTransform, nested_tensor_from_tensor_list
from .transform import GeneralizedYOLOTransform

from typing import Any, List, Dict, Tuple, Optional


class YOLOLitWrapper(pl.LightningModule):
"""
PyTorch Lightning implementation of `YOLO`
PyTorch Lightning wrapper of `YOLO`
"""
def __init__(
self,
Expand Down Expand Up @@ -78,7 +78,9 @@ def forward(
return detections

def training_step(self, batch, batch_idx):

"""
The training step.
"""
samples, targets = batch

# yolov5 takes both images and targets for training, returns
Expand Down
28 changes: 17 additions & 11 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,29 +5,33 @@

import pytorch_lightning as pl

from datasets import DetectionDataModule
from models import YOLOLitWrapper
from datasets import VOCDetectionDataModule
import models


def get_args_parser():
parser = argparse.ArgumentParser('You only look once detector', add_help=False)

parser.add_argument('--arch', default='yolov5s',
help='model structure to train')
parser.add_argument('--data_path', default='./data-bin',
help='dataset')
parser.add_argument('--dataset_type', default='coco',
help='dataset')
parser.add_argument('--dataset_mode', default='instances',
help='dataset mode')
parser.add_argument('--dataset_year', default=['2017'], nargs='+',
parser.add_argument('--years', default=['2017'], nargs='+',
help='dataset year')
parser.add_argument('--train_set', default='train',
help='set of train')
parser.add_argument('--val_set', default='val',
help='set of val')
parser.add_argument('--batch_size', default=32, type=int,
help='images per gpu, the total batch size is $NGPU x batch_size')
parser.add_argument('--epochs', default=26, type=int, metavar='N',
parser.add_argument('--max_epochs', default=1, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('--num_gpus', default=1, type=int, metavar='N',
help='number of gpu utilizing (default: 1)')
parser.add_argument('--num_workers', default=4, type=int, metavar='N',
help='number of data loading workers (default: 4)')
parser.add_argument('--print_freq', default=20, type=int,
Expand All @@ -38,15 +42,16 @@ def get_args_parser():


def main(args):
    """Entry point: assemble data, model and trainer, then run training."""
    # Data module first — the model needs its num_classes below.
    datamodule = VOCDetectionDataModule.from_argparse_args(args)

    # Look up the requested architecture by name and instantiate it.
    model_factory = models.__dict__[args.arch]
    model = model_factory(num_classes=datamodule.num_classes)

    # Configure the Lightning trainer and launch the fit loop.
    trainer = pl.Trainer(max_epochs=args.max_epochs, gpus=args.num_gpus)
    trainer.fit(model, datamodule=datamodule)


Expand All @@ -55,4 +60,5 @@ def main(args):
args = parser.parse_args()
if args.output_dir:
Path(args.output_dir).mkdir(parents=True, exist_ok=True)

main(args)

0 comments on commit 18fa51f

Please sign in to comment.