Inherit Datasets from LightningDataModule #46

Merged · 5 commits · Feb 1, 2021
Changes from all commits
datasets/__init__.py (54 changes: 2 additions & 52 deletions)
@@ -1,53 +1,3 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import torch.utils.data
import torchvision
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.

from models.transform import nested_tensor_from_tensor_list

from .coco import build as build_coco
from .voc import build as build_voc


def get_coco_api_from_dataset(dataset):
for _ in range(10):
# if isinstance(dataset, torchvision.datasets.CocoDetection):
# break
if isinstance(dataset, torch.utils.data.Subset):
dataset = dataset.dataset
if isinstance(dataset, torchvision.datasets.CocoDetection):
return dataset.coco


def collate_fn(batch):
batch = list(zip(*batch))
samples = nested_tensor_from_tensor_list(batch[0])

targets = []
for i, target in enumerate(batch[1]):
num_objects = len(target['labels'])
if num_objects > 0:
targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32)
targets_merged[:, 1] = target['labels']
targets_merged[:, 2:] = target['boxes']
targets.append(targets_merged)
targets = torch.cat(targets, dim=0)

return samples, targets


def build_dataset(image_set, dataset_year, args):

datasets = []
for year in dataset_year:
if args.dataset_file == 'coco':
dataset = build_coco(image_set, year, args)
elif args.dataset_file == 'voc':
dataset = build_voc(image_set, year, args)
else:
raise ValueError(f'dataset {args.dataset_file} not supported')
datasets.append(dataset)

if len(datasets) == 1:
return datasets[0]
else:
return torch.utils.data.ConcatDataset(datasets)
from .pl_wrapper import collate_fn, DetectionDataModule
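
With the helpers moved out, the package's public surface is just this re-export. A minimal usage sketch of the new module (the path and year values are placeholder assumptions, not part of this diff; it expects a dataset laid out as the builders below assume):

```python
from datasets import DetectionDataModule

# Placeholder values; the real defaults live in main.py's argument parser.
datamodule = DetectionDataModule(
    data_path='./data-bin',
    dataset_type='coco',
    dataset_year=['2017'],
    num_workers=4,
)
train_loader = datamodule.train_dataloader(batch_size=16)
```
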
datasets/coco.py (14 changes: 4 additions & 10 deletions)
@@ -82,7 +82,7 @@ def __call__(self, image, target):


class CocoDetection(torchvision.datasets.CocoDetection):
def __init__(self, img_folder, ann_file, transforms, return_masks):
def __init__(self, img_folder, ann_file, transforms, return_masks=False):
super().__init__(img_folder, ann_file)
self._transforms = transforms

@@ -158,19 +158,13 @@ def _has_valid_annotation(anno):
return dataset


def build(image_set, year, args):
root = Path(args.data_path)
assert root.exists(), f'provided COCO path {root} does not exist'
mode = args.dataset_mode

img_folder = Path(root)
ann_file = img_folder.joinpath("annotations").joinpath(f"{mode}_{image_set}{year}.json")
def build(data_path, image_set, year):
ann_file = Path(data_path).joinpath("annotations").joinpath(f"instances_{image_set}{year}.json")

dataset = CocoDetection(
img_folder,
data_path,
ann_file,
transforms=make_transforms(image_set=image_set),
return_masks=args.masks,
)

if image_set == 'train':
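
The COCO builder now takes the data root directly instead of an args namespace. A hedged sketch of the new call signature (the path and year are placeholders):

```python
from datasets.coco import build as build_coco

# Assumes './data-bin/annotations/instances_train2017.json' exists.
dataset = build_coco('./data-bin', 'train', '2017')
```
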
datasets/pl_wrapper.py (114 changes: 114 additions & 0 deletions)
@@ -0,0 +1,114 @@
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
import torch.utils.data
from torch.utils.data import DataLoader

from pytorch_lightning import LightningDataModule

from models.transform import nested_tensor_from_tensor_list
from .coco import build as build_coco
from .voc import build as build_voc

from typing import List, Any


def collate_fn(batch):
batch = list(zip(*batch))
samples = nested_tensor_from_tensor_list(batch[0])

targets = []
for i, target in enumerate(batch[1]):
num_objects = len(target['labels'])
if num_objects > 0:
targets_merged = torch.full((num_objects, 6), i, dtype=torch.float32)
targets_merged[:, 1] = target['labels']
targets_merged[:, 2:] = target['boxes']
targets.append(targets_merged)
targets = torch.cat(targets, dim=0)

return samples, targets


def build_dataset(data_path, dataset_type, image_set, dataset_year):

datasets = []
for year in dataset_year:
if dataset_type == 'coco':
dataset = build_coco(data_path, image_set, year)
elif dataset_type == 'voc':
dataset = build_voc(data_path, image_set, year)
else:
raise ValueError(f'dataset {dataset_type} not supported')
datasets.append(dataset)

if len(datasets) == 1:
return datasets[0]
else:
return torch.utils.data.ConcatDataset(datasets)


class DetectionDataModule(LightningDataModule):
"""
Wrapper of Datasets in LightningDataModule
"""
def __init__(
self,
data_path: str,
dataset_type: str,
dataset_year: List[str],
num_workers: int = 4,
*args: Any,
**kwargs: Any,
) -> None:
super().__init__(*args, **kwargs)

self.data_path = data_path
self.dataset_type = dataset_type
self.dataset_year = dataset_year
self.num_workers = num_workers

    def train_dataloader(self, batch_size: int = 16) -> DataLoader:
        """
        Build the training DataLoader over VOCDetection / CocoDetection.

        Args:
            batch_size: number of samples per batch
        """
dataset = build_dataset(self.data_path, self.dataset_type, 'train', self.dataset_year)

# Creating data loaders
sampler = torch.utils.data.RandomSampler(dataset)
batch_sampler = torch.utils.data.BatchSampler(
sampler, batch_size, drop_last=True,
)

loader = DataLoader(
dataset,
batch_sampler=batch_sampler,
collate_fn=collate_fn,
num_workers=self.num_workers,
)

return loader

    def val_dataloader(self, batch_size: int = 16) -> DataLoader:
        """
        Build the validation DataLoader over VOCDetection / CocoDetection.

        Args:
            batch_size: number of samples per batch
        """
dataset = build_dataset(self.data_path, self.dataset_type, 'val', self.dataset_year)

# Creating data loaders
sampler = torch.utils.data.SequentialSampler(dataset)

loader = DataLoader(
dataset,
batch_size,
sampler=sampler,
drop_last=False,
collate_fn=collate_fn,
num_workers=self.num_workers,
)

return loader
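
For context, the collate_fn above flattens per-image targets into a single (num_objects, 6) tensor whose columns are [batch index, label, 4 box values]. A minimal sketch with dummy tensors (the image sizes and the 4-value box layout are assumptions carried over from the dataset transforms; run it inside this repo so models.transform resolves):

```python
import torch

from datasets import collate_fn

batch = [
    (torch.rand(3, 320, 320), {'labels': torch.tensor([1, 2]),
                               'boxes': torch.rand(2, 4)}),
    (torch.rand(3, 320, 320), {'labels': torch.tensor([3]),
                               'boxes': torch.rand(1, 4)}),
]
samples, targets = collate_fn(batch)
print(targets.shape)  # torch.Size([3, 6]); column 0 is the image index
```

Note that if every image in a batch has zero objects, torch.cat receives an empty list and raises; the COCO builder appears to guard against this for the train split by filtering out images without valid annotations.
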
datasets/voc.py (4 changes: 2 additions & 2 deletions)
Expand Up @@ -72,10 +72,10 @@ def __getitem__(self, index):
return img, target


def build(image_set, year, args):
def build(data_path, image_set, year):

dataset = VOCDetection(
img_folder=args.data_path,
data_path,
year=year,
image_set=image_set,
transforms=make_transforms(image_set=image_set),
main.py (45 changes: 4 additions & 41 deletions)
@@ -1,17 +1,11 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
# Modified by Zhiqiang Wang ([email protected])

import datetime
import argparse
import time
from pathlib import Path

import torch
from torch.utils.data import DataLoader, DistributedSampler

import pytorch_lightning as pl

from datasets import build_dataset, get_coco_api_from_dataset, collate_fn
from datasets import DetectionDataModule
from models import YOLOLitWrapper


@@ -20,7 +14,7 @@ def get_args_parser():

parser.add_argument('--data_path', default='./data-bin',
help='dataset')
parser.add_argument('--dataset_file', default='coco',
parser.add_argument('--dataset_type', default='coco',
help='dataset')
parser.add_argument('--dataset_mode', default='instances',
help='dataset mode')
@@ -45,46 +39,15 @@

def main(args):

# Data loading code
print('Loading data')
dataset_train = build_dataset(args.train_set, args.dataset_year, args)
dataset_val = build_dataset(args.val_set, args.dataset_year, args)
base_ds = get_coco_api_from_dataset(dataset_val)

print('Creating data loaders')
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)

batch_sampler_train = torch.utils.data.BatchSampler(
sampler_train, args.batch_size, drop_last=True,
)

data_loader_train = DataLoader(
dataset_train,
batch_sampler=batch_sampler_train,
collate_fn=collate_fn,
num_workers=args.num_workers,
)
data_loader_val = DataLoader(
dataset_val,
args.batch_size,
sampler=sampler_val,
drop_last=False,
collate_fn=collate_fn,
num_workers=args.num_workers,
)

print('Creating model, always set args.return_criterion to True')
args.return_criterion = True

# Load model
model = YOLOLitWrapper()
model.train()
datamodule = DetectionDataModule.from_argparse_args(args)

# train
# trainer = pl.Trainer().from_argparse_args(args)
trainer = pl.Trainer(max_epochs=1, gpus=1)
trainer.fit(model, data_loader_train, data_loader_val)
trainer.fit(model, datamodule=datamodule)


if __name__ == "__main__":
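
from_argparse_args is inherited from LightningDataModule: it inspects the DataModule's constructor and picks the matching entries out of the parsed namespace. A hand-rolled equivalent of the call above (the argument values are placeholders):

```python
from argparse import Namespace

from datasets import DetectionDataModule

args = Namespace(data_path='./data-bin', dataset_type='coco',
                 dataset_year=['2017'], num_workers=4)

# What from_argparse_args does, spelled out by hand:
datamodule = DetectionDataModule(
    data_path=args.data_path,
    dataset_type=args.dataset_type,
    dataset_year=args.dataset_year,
    num_workers=args.num_workers,
)
```
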
models/__init__.py (4 changes: 2 additions & 2 deletions)
@@ -1,9 +1,9 @@
# Copyright (c) 2020, Zhiqiang Wang. All Rights Reserved.
# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
from torch import nn
from .common import Conv
from utils.activations import Hardswish

from .lightning_wrapper import YOLOLitWrapper
from .pl_wrapper import YOLOLitWrapper


def yolov5s(**kwargs):
models/lightning_wrapper.py → models/pl_wrapper.py (2 changes: 1 addition & 1 deletion)
@@ -10,7 +10,7 @@
from . import yolo
from .transform import nested_tensor_from_tensor_list

from typing import Tuple, Any, List, Dict, Optional
from typing import Any, List, Optional


class YOLOLitWrapper(pl.LightningModule):
test/test_engine.py (1 change: 1 addition & 0 deletions)
@@ -4,6 +4,7 @@

from .torch_utils import image_preprocess
from .dataset_utils import create_loaders, DummyDetectionDataset

from models import YOLOLitWrapper

from typing import Dict