Fix loading with coco and voc datasets (#12)
* Remove pretrained parameters of backbone

* Add loss computing scripts

* Add common gitattributes

* Move test images to test directory

* Add EngineTester

* Rename to model

* Add loss computation tools

* Move FocalLoss to utils

* Minor fixes

* Minor fixes

* Fixes loss computation

* Fixes anchor-target matcher

* Add model utils unittest

* Fix datasets loading

* Fix data transforms and loaders

* Split engine unittest

* Code cleanup

* Ignore loss computation
zhiqwang authored Dec 15, 2020
1 parent 34aaaad commit 20291c9
Showing 23 changed files with 475 additions and 319 deletions.
31 changes: 31 additions & 0 deletions .gitattributes
@@ -1,2 +1,33 @@
# this drops notebooks from GitHub language stats
*.ipynb linguist-documentation
+
+# Graphics
+*.png binary
+*.jpg binary
+*.jpeg binary
+*.gif binary
+*.tif binary
+*.tiff binary
+*.ico binary
+# SVG treated as an asset (binary) by default.
+*.svg text
+# If you want to treat it as binary,
+# use the following line instead.
+# *.svg binary
+*.eps binary
+
+# Scripts
+*.bash text eol=lf
+*.fish text eol=lf
+*.sh text eol=lf
+# These are explicitly windows files and should use crlf
+*.bat text eol=crlf
+*.cmd text eol=crlf
+*.ps1 text eol=crlf
+
+# Serialisation
+*.json text
+*.toml text
+*.xml text
+*.yaml text
+*.yml text
2 changes: 1 addition & 1 deletion README.md
@@ -100,7 +100,7 @@ The module state of `yolov5rt` has some differences comparing to `ultralytics/yo
To read a source image and detect its objects run:

```bash
-python -m detect [--input_source YOUR_IMAGE_SOURCE_DIR]
+python -m detect [--input_source ./test/assets/zidane.jpg]
[--labelmap ./notebooks/assets/coco.names]
[--output_dir ./data-bin/output]
[--min_size 640]
33 changes: 2 additions & 31 deletions datasets/coco.py
@@ -11,7 +11,7 @@
import torchvision
from pycocotools import mask as coco_mask

-from . import transforms as T
+from .transforms import make_transforms


class ConvertCocoPolysToMask(object):
@@ -153,35 +153,6 @@ def _has_valid_annotation(anno):
    return dataset


-def make_coco_transforms(image_set, image_size=300):
-
-    normalize = T.Compose([
-        T.ToTensor(),
-        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-    ])
-
-    if image_set == 'train' or image_set == 'trainval':
-        return T.Compose([
-            T.RandomHorizontalFlip(),
-            T.RandomSelect(
-                T.Resize(image_size),
-                T.Compose([
-                    T.RandomResize([400, 500, 600]),
-                    T.RandomSizeCrop(384, 600),
-                    T.Resize(image_size),
-                ])
-            ),
-            normalize,
-        ])
-    elif image_set == 'val' or image_set == 'test':
-        return T.Compose([
-            T.Resize(image_size),
-            normalize,
-        ])
-    else:
-        raise ValueError(f'unknown {image_set}')


def build(image_set, year, args):
root = Path(args.data_path)
assert root.exists(), f'provided COCO path {root} does not exist'
@@ -194,7 +165,7 @@ def build(image_set, year, args):
    dataset = CocoDetection(
        img_folder,
        ann_file,
-        transforms=make_coco_transforms(image_set, image_size=args.image_size),
+        transforms=make_transforms(image_set=image_set),
        return_masks=args.masks,
    )

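With this change the dataset builder goes through the shared `make_transforms` pipeline and no longer needs `args.image_size`. A minimal sketch of the new entry point; the `Namespace` fields are the ones `build` reads in this diff, while the concrete `data_path` and `year` values are illustrative assumptions:

```python
# Hypothetical usage of datasets.coco.build after this commit; data_path and
# year are assumed values, masks=False skips the polygon-to-mask conversion.
from argparse import Namespace

from datasets.coco import build

args = Namespace(data_path='data-bin/coco', masks=False)
dataset = build(image_set='train', year='2017', args=args)
img, target = dataset[0]  # transformed image tensor plus a target dict
```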
2 changes: 1 addition & 1 deletion datasets/coco_eval.py
@@ -17,7 +17,7 @@
from pycocotools.cocoeval import COCOeval
from pycocotools.coco import COCO

-from util.misc import all_gather
+from utils.misc import all_gather


class CocoEvaluator(object):
66 changes: 40 additions & 26 deletions datasets/transforms.py
@@ -9,7 +9,7 @@
import torchvision.transforms as T
import torchvision.transforms.functional as F

-from util.misc import interpolate
+from torchvision.ops.boxes import box_convert


def crop(image, target, region):
@@ -125,7 +125,7 @@ def get_size(image_size, size, max_size=None):
target["size"] = torch.tensor([h, w])

if "masks" in target:
target['masks'] = interpolate(
target['masks'] = torch.nn.functional.interpolate(
target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5

return rescaled_image, target
@@ -138,7 +138,7 @@ def pad(image, target, padding):
        return padded_image, None
    target = target.copy()
    # should we do something wrt the original size?
-    target["size"] = torch.tensor(padded_image[::-1])
+    target["size"] = torch.tensor(padded_image.size[::-1])
    if "masks" in target:
        target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1]))
    return padded_image, target
@@ -159,14 +159,10 @@ def __init__(self, min_size: int, max_size: int):
        self.max_size = max_size

    def __call__(self, img: PIL.Image.Image, target: dict):
-        while True:
-            w = random.randint(self.min_size, min(img.width, self.max_size))
-            h = random.randint(self.min_size, min(img.height, self.max_size))
-            region = T.RandomCrop.get_params(img, [h, w])
-
-            croped_img, croped_target = crop(img, target, region)
-            if len(croped_target['labels']) > 0:
-                return croped_img, croped_target
+        w = random.randint(self.min_size, min(img.width, self.max_size))
+        h = random.randint(self.min_size, min(img.height, self.max_size))
+        region = T.RandomCrop.get_params(img, [h, w])
+        return crop(img, target, region)


class CenterCrop(object):
@@ -202,20 +198,6 @@ def __call__(self, img, target=None):
        return resize(img, target, size, self.max_size)


-class Resize(object):
-    def __init__(self, size):
-        if isinstance(size, tuple):
-            sizes = size
-            assert len(sizes) == 2, "The length of sizes must be 2"
-        else:
-            sizes = (size, size)
-
-        self.sizes = sizes
-
-    def __call__(self, img, target=None):
-        return resize(img, target, self.sizes)


class RandomPad(object):
    def __init__(self, max_pad):
        self.max_pad = max_pad
@@ -269,7 +251,7 @@ def __call__(self, image, target=None):
        h, w = image.shape[-2:]
        if "boxes" in target:
            boxes = target["boxes"]
-            # converted to XYXY_REL BoxMode
+            boxes = box_convert(boxes, in_fmt='xyxy', out_fmt='cxcywh')
            boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32)
            target["boxes"] = boxes
        return image, target
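The hand-written `XYXY_REL BoxMode` comment gives way to torchvision's `box_convert`, which handles the xyxy to cxcywh format change before the division rescales to relative coordinates. A quick self-contained check of what that call computes:

```python
import torch
from torchvision.ops.boxes import box_convert

boxes = torch.tensor([[10., 20., 50., 80.]])  # xyxy: x1, y1, x2, y2
print(box_convert(boxes, in_fmt='xyxy', out_fmt='cxcywh'))
# tensor([[30., 50., 40., 60.]])  -> center x, center y, width, height
```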
@@ -291,3 +273,35 @@ def __repr__(self):
format_string += " {0}".format(t)
format_string += "\n)"
return format_string


def make_transforms(image_set='train'):

normalize = Compose([
ToTensor(),
Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800]

if image_set == 'train' or image_set == 'trainval':
return Compose([
RandomHorizontalFlip(),
RandomSelect(
RandomResize(scales, max_size=1333),
Compose([
RandomResize([400, 500, 600]),
RandomSizeCrop(384, 600),
RandomResize(scales, max_size=1333),
])
),
normalize,
])

if image_set == 'val' or image_set == 'test':
return Compose([
RandomResize([800], max_size=1333),
normalize,
])

raise ValueError(f'unknown {image_set}')
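A short sketch of the new pipeline end to end; the dummy image size and target are illustrative, with the target following the convention the dataset wrappers produce (absolute xyxy boxes plus labels):

```python
import torch
from PIL import Image

from datasets.transforms import make_transforms

transforms = make_transforms(image_set='val')

img = Image.new('RGB', (640, 480))  # dummy image, arbitrary size
target = {'boxes': torch.tensor([[100., 80., 300., 240.]]),
          'labels': torch.tensor([1])}

img, target = transforms(img, target)
# img: normalized CHW float tensor, shorter side resized to 800 (capped at 1333)
# target['boxes']: now cxcywh, rescaled to [0, 1] by Normalize
```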
36 changes: 2 additions & 34 deletions datasets/voc.py
@@ -1,7 +1,7 @@
import torch
import torchvision

-from . import transforms as T
+from .transforms import make_transforms


class ConvertVOCtoCOCO(object):
@@ -73,45 +73,13 @@ def __getitem__(self, index):
        return img, target


-def make_voc_transforms(image_set='train', image_size=300):
-
-    normalize = T.Compose([
-        T.ToTensor(),
-        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
-    ])
-
-    if image_set == 'train' or image_set == 'trainval':
-        return T.Compose([
-            T.RandomHorizontalFlip(),
-            T.RandomSelect(
-                T.Resize(image_size),
-                T.Compose([
-                    T.RandomResize([400, 500, 600]),
-                    T.RandomSizeCrop(384, 600),
-                    T.Resize(image_size),
-                ])
-            ),
-            normalize,
-        ])
-    elif image_set == 'val' or image_set == 'test':
-        return T.Compose([
-            T.Resize(image_size),
-            normalize,
-        ])
-    else:
-        raise ValueError(f'unknown {image_set}')


def build(image_set, year, args):

    dataset = VOCDetection(
        img_folder=args.data_path,
        year=year,
        image_set=image_set,
-        transforms=make_voc_transforms(
-            image_set=image_set,
-            image_size=args.image_size,
-        ),
+        transforms=make_transforms(image_set=image_set),
    )

    return dataset
2 changes: 1 addition & 1 deletion detect.py
@@ -134,7 +134,7 @@ def main(args):

    parser.add_argument('--labelmap', type=str, default='./notebooks/assets/coco.names',
                        help='path where the coco category in')
-    parser.add_argument('--input_source', type=str, default='./notebooks/assets/zidane.jpg',
+    parser.add_argument('--input_source', type=str, default='./test/assets/zidane.jpg',
                        help='path where the source images in')
    parser.add_argument('--output_dir', type=str, default='./data-bin/output',
                        help='path where to save')
19 changes: 9 additions & 10 deletions engine.py
@@ -11,7 +11,7 @@
import utils.misc as utils


-def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq):
+def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq):
    model.train()
    metric_logger = utils.MetricLogger(delimiter=' ')
    metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}'))
@@ -24,12 +24,11 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri

        lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor)

-    for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
-        samples = samples.to(device)
+    for images, targets in metric_logger.log_every(data_loader, print_freq, header):
+        images = list(image.to(device) for image in images)
        targets = [{k: v.to(device) for k, v in t.items()} for t in targets]

-        outputs = model(samples)
-        loss_dict = criterion(outputs, targets)
+        loss_dict = model(images, targets)

        losses = sum(loss for loss in loss_dict.values())

@@ -70,20 +69,20 @@ def _get_iou_types(model):


@torch.no_grad()
-def evaluate(model, criterion, data_loader, base_ds, device):
+def evaluate(model, data_loader, base_ds, device):
    model.eval()
    metric_logger = utils.MetricLogger(delimiter=' ')
    header = 'Test:'

    iou_types = _get_iou_types(model)
    coco_evaluator = CocoEvaluator(base_ds, iou_types)

-    for samples, targets in metric_logger.log_every(data_loader, 20, header):
-        samples = samples.to(device)
+    for images, targets in metric_logger.log_every(data_loader, 20, header):
+        images = images.to(device)

        model_time = time.time()
        target_sizes = torch.stack([t['orig_size'] for t in targets], dim=0).to(device)
-        results = model(samples, target_sizes=target_sizes)
+        results = model(images, target_sizes=target_sizes)

        model_time = time.time() - model_time

@@ -98,7 +97,7 @@ def evaluate(model, criterion, data_loader, base_ds, device):
    print(f'Averaged stats: {metric_logger}')
    coco_evaluator.synchronize_between_processes()

-    # accumulate predictions from all samples
+    # accumulate predictions from all images
    coco_evaluator.accumulate()
    coco_evaluator.summarize()
    return coco_evaluator
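Dropping the external `criterion` moves the engine to the torchvision detection contract: the model computes its own losses from `(images, targets)` in train mode and returns post-processed detections in eval mode. A minimal sketch of that contract, with illustrative loss-dict keys:

```python
import torch


def training_step(model, images, targets, optimizer):
    # Train mode: the model maps (images, targets) to a dict of losses;
    # the keys are model-specific (e.g. classification / box regression).
    model.train()
    loss_dict = model(images, targets)
    losses = sum(loss for loss in loss_dict.values())

    optimizer.zero_grad()
    losses.backward()
    optimizer.step()
    return {k: v.item() for k, v in loss_dict.items()}


@torch.no_grad()
def inference_step(model, images, target_sizes):
    # Eval mode: the same model returns detections rescaled to the original
    # image sizes, matching what `evaluate` above feeds to CocoEvaluator.
    model.eval()
    return model(images, target_sizes=target_sizes)
```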