Refactoring VOCDetectionDataModule
zhiqwang committed Feb 4, 2021
1 parent 70ae932 commit 1568dad
Showing 5 changed files with 52 additions and 49 deletions.
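In short: DetectionDataModule no longer builds its datasets from data_path/dataset_type/dataset_year (the build_dataset helper is deleted). It now receives train/val/test Dataset objects directly, with batch_size and num_workers as plain constructor arguments. The VOC-specific logic moves into the new VOCDetectionDataModule in datasets/voc.py, and train.py switches to driving training through it.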
2 changes: 1 addition & 1 deletion datasets/__init__.py
@@ -1,3 +1,3 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
 
-from .detection_datamodule import collate_fn, DetectionDataModule
+from .voc import VOCDetectionDataModule
56 changes: 17 additions & 39 deletions datasets/detection_datamodule.py
@@ -3,12 +3,11 @@
 from torch.utils.data import DataLoader
 
 from pytorch_lightning import LightningDataModule
+from torch.utils.data.dataset import Dataset
 
 from models.transform import nested_tensor_from_tensor_list
-from .coco import build as build_coco
-from .voc import build as build_voc
 
-from typing import List, Any
+from typing import List, Any, Optional
 
 
 def collate_fn(batch):
@@ -28,42 +27,27 @@ def collate_fn(batch):
     return samples, targets
 
 
-def build_dataset(data_path, dataset_type, image_set, dataset_year):
-
-    datasets = []
-    for year in dataset_year:
-        if dataset_type == 'coco':
-            dataset = build_coco(data_path, image_set, year)
-        elif dataset_type == 'voc':
-            dataset = build_voc(data_path, image_set, year)
-        else:
-            raise ValueError(f'dataset {dataset_type} not supported')
-        datasets.append(dataset)
-
-    if len(datasets) == 1:
-        return datasets[0]
-    else:
-        return torch.utils.data.ConcatDataset(datasets)
-
-
 class DetectionDataModule(LightningDataModule):
     """
     Wrapper of Datasets in LightningDataModule
     """
     def __init__(
         self,
-        data_path: str,
-        dataset_type: str,
-        dataset_year: List[str],
-        num_workers: int = 4,
+        train_dataset: Optional[Dataset] = None,
+        val_dataset: Optional[Dataset] = None,
+        test_dataset: Optional[Dataset] = None,
+        batch_size: int = 1,
+        num_workers: int = 0,
         *args: Any,
         **kwargs: Any,
     ) -> None:
         super().__init__(*args, **kwargs)
 
-        self.data_path = data_path
-        self.dataset_type = dataset_type
-        self.dataset_year = dataset_year
+        self._train_dataset = train_dataset
+        self._val_dataset = val_dataset
+        self._test_dataset = test_dataset
+
+        self.batch_size = batch_size
         self.num_workers = num_workers
 
     def train_dataloader(self, batch_size: int = 16) -> None:
Expand All @@ -73,16 +57,12 @@ def train_dataloader(self, batch_size: int = 16) -> None:
batch_size: size of batch
transforms: custom transforms
"""
dataset = build_dataset(self.data_path, self.dataset_type, 'train', self.dataset_year)

# Creating data loaders
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(
sampler, batch_size, drop_last=True,
)
sampler = torch.utils.data.RandomSampler(self._train_dataset)
batch_sampler = torch.utils.data.BatchSampler(sampler, batch_size, drop_last=True)

loader = DataLoader(
dataset,
self._train_dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=self.num_workers,
@@ -97,13 +77,11 @@ def val_dataloader(self, batch_size: int = 16) -> None:
             batch_size: size of batch
             transforms: custom transforms
         """
-        dataset = build_dataset(self.data_path, self.dataset_type, 'val', self.dataset_year)
 
-        # Creating data loaders
-        sampler = torch.utils.data.SequentialSampler(dataset)
+        sampler = torch.utils.data.SequentialSampler(self._val_dataset)
 
         loader = DataLoader(
-            dataset,
+            self._val_dataset,
             batch_size,
             sampler=sampler,
             drop_last=False,
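With build_dataset gone, callers construct the datasets themselves and hand them to the datamodule. A minimal usage sketch under that contract; ToyDetectionDataset is a hypothetical stand-in for any map-style Dataset yielding (image, target) pairs and is not part of this commit:

import torch
from torch.utils.data import Dataset

from datasets.detection_datamodule import DetectionDataModule


class ToyDetectionDataset(Dataset):
    """Hypothetical stand-in: random images with one dummy box each."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        image = torch.rand(3, 320, 320)
        target = {
            "boxes": torch.tensor([[10.0, 10.0, 100.0, 100.0]]),
            "labels": torch.tensor([1]),
        }
        return image, target


datamodule = DetectionDataModule(
    train_dataset=ToyDetectionDataset(),
    val_dataset=ToyDetectionDataset(),
    num_workers=0,
)

# collate_fn packs each batch's images into a nested tensor (see the
# nested_tensor_from_tensor_list import above); targets remain a list of dicts.
loader = datamodule.train_dataloader(batch_size=4)
samples, targets = next(iter(loader))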
20 changes: 20 additions & 0 deletions datasets/voc.py
@@ -1,8 +1,28 @@
 import torch
 import torchvision
 
+from .detection_datamodule import DetectionDataModule
 from .transforms import make_transforms
 
 from typing import List, Any
+
+
+class VOCDetectionDataModule(DetectionDataModule):
+    def __init__(
+        self,
+        num_workers: int,
+        *args: Any,
+        **kwargs: Any,
+    ) -> None:
+        super().__init__(num_workers=num_workers, *args, **kwargs)
+
+    def prepare_data(self) -> None:
+        """
+        Saves VOCDetection files to data_dir
+        """
+        VOCDetection(self.data_dir, year=self.year, image_set="train", download=True)
+        VOCDetection(self.data_dir, year=self.year, image_set="val", download=True)
+
+
 class ConvertVOCtoCOCO(object):
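Note that prepare_data references a bare VOCDetection plus self.data_dir and self.year, none of which appear in the added lines; the class would resolve via the torchvision import, and the attributes presumably come from the part of __init__ this hunk truncates. A minimal runnable sketch of that wiring, with data_dir and year as assumed parameter names:

import torchvision

from datasets.detection_datamodule import DetectionDataModule

from typing import Any


class VOCDetectionDataModuleSketch(DetectionDataModule):
    """Hedged sketch; the real __init__ is truncated in the diff above."""

    def __init__(self, data_dir: str, year: str = "2012", num_workers: int = 4,
                 *args: Any, **kwargs: Any) -> None:
        super().__init__(num_workers=num_workers, *args, **kwargs)
        self.data_dir = data_dir
        self.year = year

    def prepare_data(self) -> None:
        # Download both splits once; Lightning invokes this hook before training.
        for image_set in ("train", "val"):
            torchvision.datasets.VOCDetection(
                self.data_dir, year=self.year, image_set=image_set, download=True,
            )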
2 changes: 1 addition & 1 deletion models/pl_wrapper.py
@@ -7,7 +7,7 @@
 import pytorch_lightning as pl
 
 from . import yolo
-from .transform import GeneralizedYOLOTransform, nested_tensor_from_tensor_list
+from .transform import GeneralizedYOLOTransform
 
 from typing import Any, List, Dict, Tuple, Optional
21 changes: 13 additions & 8 deletions train.py
@@ -5,7 +5,7 @@
 
 import pytorch_lightning as pl
 
-from datasets import DetectionDataModule
+from datasets import VOCDetectionDataModule
 from models import YOLOLitWrapper
 
@@ -38,15 +38,19 @@ def get_args_parser():
 
 
 def main(args):
+    # Load the data
+    datamodule = VOCDetectionDataModule.from_folders(
+        train_folder="data-bin/coco128/train/",
+        valid_folder="data-bin/coco128/val/",
+    )
+
-    # Load model
-    model = YOLOLitWrapper()
-    model.train()
-    datamodule = DetectionDataModule.from_argparse_args(args)
-
+    # Build the model
+    model = YOLOLitWrapper(arch="yolov5_darknet_pan_s_r31", num_classes=datamodule.num_classes)
+
-    # train
-    # trainer = pl.Trainer().from_argparse_args(args)
-    trainer = pl.Trainer(max_epochs=1, gpus=1)
+    # Create the trainer. Run twice on data
+    trainer = pl.Trainer(max_epochs=2, gpus=1)
 
+    # Train the model
     trainer.fit(model, datamodule=datamodule)
 
@@ -55,4 +59,5 @@ def main(args):
     args = parser.parse_args()
     if args.output_dir:
         Path(args.output_dir).mkdir(parents=True, exist_ok=True)
+
     main(args)
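train.py now calls VOCDetectionDataModule.from_folders and reads datamodule.num_classes, but neither is defined in any hunk of this commit. A speculative sketch of one plausible shape for such a constructor on top of the refactored base class; the class name, the folder-parsing helper, and the num_classes value are all assumptions:

from typing import Any

from torch.utils.data.dataset import Dataset

from datasets.detection_datamodule import DetectionDataModule


def build_folder_dataset(folder: str) -> Dataset:
    """Hypothetical helper: parse images and labels found under `folder`."""
    raise NotImplementedError


class FolderDetectionDataModule(DetectionDataModule):
    num_classes = 80  # assumed: coco128 uses the 80-class COCO label set

    @classmethod
    def from_folders(cls, train_folder: str, valid_folder: str, **kwargs: Any):
        # Build one Dataset per split, then reuse the refactored constructor.
        return cls(
            train_dataset=build_folder_dataset(train_folder),
            val_dataset=build_folder_dataset(valid_folder),
            **kwargs,
        )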
