From 20291c95822c0ea2dd8fb3bbb738d41f1809503f Mon Sep 17 00:00:00 2001 From: Zhiqiang Wang Date: Wed, 16 Dec 2020 00:39:22 +0800 Subject: [PATCH] Fix loading with coco and voc datasets (#12) * Remove pretrained parameters of backbone * Add loss computating scripts * Add common gitattributes * Move test images to test directory * Add EngineTester * Rename to model * Add loss computation tools * Move FocalLoss to utils * Minor fixes * Minor fixes * Fixes loss computation * Fixes anchors targets matcher * Add model utils unittest * Fix datasets loading * Fix data transforms and loaders * Split engine unittest * Code cleanup * Ignore loss computation --- .gitattributes | 31 +++ README.md | 2 +- datasets/coco.py | 33 +-- datasets/coco_eval.py | 2 +- datasets/transforms.py | 66 ++++-- datasets/voc.py | 36 +-- detect.py | 2 +- engine.py | 19 +- main.py | 18 +- models/_utils.py | 32 ++- models/backbone.py | 2 +- models/box_head.py | 223 ++++++++++++++---- models/yolo.py | 26 +- .../export-onnx-inference-onnxruntime.ipynb | 4 +- .../inference-pytorch-export-libtorch.ipynb | 4 +- {notebooks => test}/assets/bus.jpg | Bin {notebooks => test}/assets/zidane.jpg | Bin test/test_engine.py | 52 ++++ test/test_models.py | 5 + test/test_models_utils.py | 34 +++ test/torch_utils.py | 11 + utils/box_ops.py | 48 ++++ utils/misc.py | 144 +---------- 23 files changed, 475 insertions(+), 319 deletions(-) rename {notebooks => test}/assets/bus.jpg (100%) rename {notebooks => test}/assets/zidane.jpg (100%) create mode 100644 test/test_engine.py create mode 100644 test/test_models_utils.py create mode 100644 test/torch_utils.py create mode 100644 utils/box_ops.py diff --git a/.gitattributes b/.gitattributes index 9f846e5b..92fdaff3 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,33 @@ # this drop notebooks from GitHub language stats *.ipynb linguist-documentation + +# Graphics +*.png binary +*.jpg binary +*.jpeg binary +*.gif binary +*.tif binary +*.tiff binary +*.ico binary +# SVG treated as an asset (binary) by default. +*.svg text +# If you want to treat it as binary, +# use the following line instead. +# *.svg binary +*.eps binary + +# Scripts +*.bash text eol=lf +*.fish text eol=lf +*.sh text eol=lf +# These are explicitly windows files and should use crlf +*.bat text eol=crlf +*.cmd text eol=crlf +*.ps1 text eol=crlf + +# Serialisation +*.json text +*.toml text +*.xml text +*.yaml text +*.yml text diff --git a/README.md b/README.md index 4fd11fb8..6d54e228 100644 --- a/README.md +++ b/README.md @@ -100,7 +100,7 @@ The module state of `yolov5rt` has some differences comparing to `ultralytics/yo To read a source image and detect its objects run: ```bash -python -m detect [--input_source YOUR_IMAGE_SOURCE_DIR] +python -m detect [--input_source ./test/assets/zidane.jpg] [--labelmap ./notebooks/assets/coco.names] [--output_dir ./data-bin/output] [--min_size 640] diff --git a/datasets/coco.py b/datasets/coco.py index e64f58bc..4fc40ef8 100644 --- a/datasets/coco.py +++ b/datasets/coco.py @@ -11,7 +11,7 @@ import torchvision from pycocotools import mask as coco_mask -from . 
import transforms as T +from .transforms import make_transforms class ConvertCocoPolysToMask(object): @@ -153,35 +153,6 @@ def _has_valid_annotation(anno): return dataset -def make_coco_transforms(image_set, image_size=300): - - normalize = T.Compose([ - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ]) - - if image_set == 'train' or image_set == 'trainval': - return T.Compose([ - T.RandomHorizontalFlip(), - T.RandomSelect( - T.Resize(image_size), - T.Compose([ - T.RandomResize([400, 500, 600]), - T.RandomSizeCrop(384, 600), - T.Resize(image_size), - ]) - ), - normalize, - ]) - elif image_set == 'val' or image_set == 'test': - return T.Compose([ - T.Resize(image_size), - normalize, - ]) - else: - raise ValueError(f'unknown {image_set}') - - def build(image_set, year, args): root = Path(args.data_path) assert root.exists(), f'provided COCO path {root} does not exist' @@ -194,7 +165,7 @@ def build(image_set, year, args): dataset = CocoDetection( img_folder, ann_file, - transforms=make_coco_transforms(image_set, image_size=args.image_size), + transforms=make_transforms(image_set=image_set), return_masks=args.masks, ) diff --git a/datasets/coco_eval.py b/datasets/coco_eval.py index d0ead3bf..1b8bfa1f 100644 --- a/datasets/coco_eval.py +++ b/datasets/coco_eval.py @@ -17,7 +17,7 @@ from pycocotools.cocoeval import COCOeval from pycocotools.coco import COCO -from util.misc import all_gather +from utils.misc import all_gather class CocoEvaluator(object): diff --git a/datasets/transforms.py b/datasets/transforms.py index d8366618..bc14a4f8 100644 --- a/datasets/transforms.py +++ b/datasets/transforms.py @@ -9,7 +9,7 @@ import torchvision.transforms as T import torchvision.transforms.functional as F -from util.misc import interpolate +from torchvision.ops.boxes import box_convert def crop(image, target, region): @@ -125,7 +125,7 @@ def get_size(image_size, size, max_size=None): target["size"] = torch.tensor([h, w]) if "masks" in target: - target['masks'] = interpolate( + target['masks'] = torch.nn.functional.interpolate( target['masks'][:, None].float(), size, mode="nearest")[:, 0] > 0.5 return rescaled_image, target @@ -138,7 +138,7 @@ def pad(image, target, padding): return padded_image, None target = target.copy() # should we do something wrt the original size? 
- target["size"] = torch.tensor(padded_image[::-1]) + target["size"] = torch.tensor(padded_image.size[::-1]) if "masks" in target: target['masks'] = torch.nn.functional.pad(target['masks'], (0, padding[0], 0, padding[1])) return padded_image, target @@ -159,14 +159,10 @@ def __init__(self, min_size: int, max_size: int): self.max_size = max_size def __call__(self, img: PIL.Image.Image, target: dict): - while True: - w = random.randint(self.min_size, min(img.width, self.max_size)) - h = random.randint(self.min_size, min(img.height, self.max_size)) - region = T.RandomCrop.get_params(img, [h, w]) - - croped_img, croped_target = crop(img, target, region) - if len(croped_target['labels']) > 0: - return croped_img, croped_target + w = random.randint(self.min_size, min(img.width, self.max_size)) + h = random.randint(self.min_size, min(img.height, self.max_size)) + region = T.RandomCrop.get_params(img, [h, w]) + return crop(img, target, region) class CenterCrop(object): @@ -202,20 +198,6 @@ def __call__(self, img, target=None): return resize(img, target, size, self.max_size) -class Resize(object): - def __init__(self, size): - if isinstance(size, tuple): - sizes = size - assert len(sizes) == 2, "The length of sizes must be 2" - else: - sizes = (size, size) - - self.sizes = sizes - - def __call__(self, img, target=None): - return resize(img, target, self.sizes) - - class RandomPad(object): def __init__(self, max_pad): self.max_pad = max_pad @@ -269,7 +251,7 @@ def __call__(self, image, target=None): h, w = image.shape[-2:] if "boxes" in target: boxes = target["boxes"] - # converted to XYXY_REL BoxMode + boxes = box_convert(boxes, in_fmt='xyxy', out_fmt='cxcywh') boxes = boxes / torch.tensor([w, h, w, h], dtype=torch.float32) target["boxes"] = boxes return image, target @@ -291,3 +273,35 @@ def __repr__(self): format_string += " {0}".format(t) format_string += "\n)" return format_string + + +def make_transforms(image_set='train'): + + normalize = Compose([ + ToTensor(), + Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), + ]) + + scales = [480, 512, 544, 576, 608, 640, 672, 704, 736, 768, 800] + + if image_set == 'train' or image_set == 'trainval': + return Compose([ + RandomHorizontalFlip(), + RandomSelect( + RandomResize(scales, max_size=1333), + Compose([ + RandomResize([400, 500, 600]), + RandomSizeCrop(384, 600), + RandomResize(scales, max_size=1333), + ]) + ), + normalize, + ]) + + if image_set == 'val' or image_set == 'test': + return Compose([ + RandomResize([800], max_size=1333), + normalize, + ]) + + raise ValueError(f'unknown {image_set}') diff --git a/datasets/voc.py b/datasets/voc.py index b36ed6fe..ab7b062a 100644 --- a/datasets/voc.py +++ b/datasets/voc.py @@ -1,7 +1,7 @@ import torch import torchvision -from . 
import transforms as T +from .transforms import make_transforms class ConvertVOCtoCOCO(object): @@ -73,45 +73,13 @@ def __getitem__(self, index): return img, target -def make_voc_transforms(image_set='train', image_size=300): - - normalize = T.Compose([ - T.ToTensor(), - T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]), - ]) - - if image_set == 'train' or image_set == 'trainval': - return T.Compose([ - T.RandomHorizontalFlip(), - T.RandomSelect( - T.Resize(image_size), - T.Compose([ - T.RandomResize([400, 500, 600]), - T.RandomSizeCrop(384, 600), - T.Resize(image_size), - ]) - ), - normalize, - ]) - elif image_set == 'val' or image_set == 'test': - return T.Compose([ - T.Resize(image_size), - normalize, - ]) - else: - raise ValueError(f'unknown {image_set}') - - def build(image_set, year, args): dataset = VOCDetection( img_folder=args.data_path, year=year, image_set=image_set, - transforms=make_voc_transforms( - image_set=image_set, - image_size=args.image_size, - ), + transforms=make_transforms(image_set=image_set), ) return dataset diff --git a/detect.py b/detect.py index 210ffb14..5f1a661d 100644 --- a/detect.py +++ b/detect.py @@ -134,7 +134,7 @@ def main(args): parser.add_argument('--labelmap', type=str, default='./notebooks/assets/coco.names', help='path where the coco category in') - parser.add_argument('--input_source', type=str, default='./notebooks/assets/zidane.jpg', + parser.add_argument('--input_source', type=str, default='./test/assets/zidane.jpg', help='path where the source images in') parser.add_argument('--output_dir', type=str, default='./data-bin/output', help='path where to save') diff --git a/engine.py b/engine.py index 3c21c3da..d12ec198 100644 --- a/engine.py +++ b/engine.py @@ -11,7 +11,7 @@ import utils.misc as utils -def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, print_freq): +def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): model.train() metric_logger = utils.MetricLogger(delimiter=' ') metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) @@ -24,12 +24,11 @@ def train_one_epoch(model, criterion, optimizer, data_loader, device, epoch, pri lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) - for samples, targets in metric_logger.log_every(data_loader, print_freq, header): - samples = samples.to(device) + for images, targets in metric_logger.log_every(data_loader, print_freq, header): + images = list(image.to(device) for image in images) targets = [{k: v.to(device) for k, v in t.items()} for t in targets] - outputs = model(samples) - loss_dict = criterion(outputs, targets) + loss_dict = model(images, targets) losses = sum(loss for loss in loss_dict.values()) @@ -70,7 +69,7 @@ def _get_iou_types(model): @torch.no_grad() -def evaluate(model, criterion, data_loader, base_ds, device): +def evaluate(model, data_loader, base_ds, device): model.eval() metric_logger = utils.MetricLogger(delimiter=' ') header = 'Test:' @@ -78,12 +77,12 @@ def evaluate(model, criterion, data_loader, base_ds, device): iou_types = _get_iou_types(model) coco_evaluator = CocoEvaluator(base_ds, iou_types) - for samples, targets in metric_logger.log_every(data_loader, 20, header): - samples = samples.to(device) + for images, targets in metric_logger.log_every(data_loader, 20, header): + images = images.to(device) model_time = time.time() target_sizes = torch.stack([t['orig_size'] for t in targets], dim=0).to(device) - results = model(samples, target_sizes=target_sizes) 
+ results = model(images, target_sizes=target_sizes) model_time = time.time() - model_time @@ -98,7 +97,7 @@ def evaluate(model, criterion, data_loader, base_ds, device): print(f'Averaged stats: {metric_logger}') coco_evaluator.synchronize_between_processes() - # accumulate predictions from all samples + # accumulate predictions from all images coco_evaluator.accumulate() coco_evaluator.summarize() return coco_evaluator diff --git a/main.py b/main.py index 0a2c46e7..a6d26cb2 100644 --- a/main.py +++ b/main.py @@ -2,7 +2,6 @@ # Modified by Zhiqiang Wang (zhiqwang@foxmail.com) import datetime -import os import argparse import time from pathlib import Path @@ -13,14 +12,14 @@ import utils.misc as utils from datasets import build_dataset, get_coco_api_from_dataset -from models import build_model +from models import yolov5s from engine import train_one_epoch, evaluate def get_args_parser(): parser = argparse.ArgumentParser('You only look once detector', add_help=False) - parser.add_argument('--arch', default='ssd_lite_mobilenet_v2', + parser.add_argument('--arch', default='yolov5s', help='model architecture') parser.add_argument('--return-criterion', action='store_true', help='Should be enabled in training mode') @@ -134,9 +133,8 @@ def main(args): print('Creating model, always set args.return_criterion be True') args.return_criterion = True - model, criterion = build_model(args) + model = yolov5s() model.to(device) - criterion.to(device) model_without_ddp = model if args.distributed: @@ -154,14 +152,14 @@ def main(args): weight_decay=args.weight_decay, ) - if args.lr_scheduler == 'multi-step': + if args.lr_scheduler == 'cosine': + lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.t_max) + elif args.lr_scheduler == 'multi-step': lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( optimizer, milestones=args.lr_steps, gamma=args.lr_gamma, ) - elif args.lr_scheduler == 'cosine': - lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.t_max) else: raise ValueError(f'scheduler {args.lr_scheduler} not supported') @@ -182,7 +180,7 @@ def main(args): for epoch in range(args.start_epoch, args.epochs): if args.distributed: sampler_train.set_epoch(epoch) - train_one_epoch(model, criterion, optimizer, data_loader_train, device, epoch, args.print_freq) + train_one_epoch(model, optimizer, data_loader_train, device, epoch, args.print_freq) lr_scheduler.step() if args.output_dir: @@ -194,7 +192,7 @@ def main(args): 'args': args, 'epoch': epoch, }, - os.path.join(output_dir, 'model_{}.pth'.format(epoch)), + output_dir.joinpath(f'model_{epoch}.pth'), ) # evaluate after every epoch diff --git a/models/_utils.py b/models/_utils.py index bdc197dc..0a0e4afc 100644 --- a/models/_utils.py +++ b/models/_utils.py @@ -1,7 +1,7 @@ import math import torch -from torch import Tensor +from torch import nn, Tensor from torch.jit.annotations import Tuple, List from torchvision.ops import box_convert @@ -311,3 +311,33 @@ def smooth_l1_loss(input, target, beta: float = 1. / 9, size_average: bool = Tru if size_average: return loss.mean() return loss.sum() + + +class FocalLoss(nn.Module): + # Wraps focal loss around existing loss_fcn(), i.e. 
criteria = FocalLoss(nn.BCEWithLogitsLoss(), gamma=1.5) + def __init__(self, loss_fcn, gamma=1.5, alpha=0.25): + super().__init__() + self.loss_fcn = loss_fcn # must be nn.BCEWithLogitsLoss() + self.gamma = gamma + self.alpha = alpha + self.reduction = loss_fcn.reduction + self.loss_fcn.reduction = 'none' # required to apply FL to each element + + def forward(self, pred, true): + loss = self.loss_fcn(pred, true) + # p_t = torch.exp(-loss) + # loss *= self.alpha * (1.000001 - p_t) ** self.gamma # non-zero power for gradient stability + + # TF implementation https://github.com/tensorflow/addons/blob/v0.7.1/tensorflow_addons/losses/focal_loss.py + pred_prob = torch.sigmoid(pred) # prob from logits + p_t = true * pred_prob + (1 - true) * (1 - pred_prob) + alpha_factor = true * self.alpha + (1 - true) * (1 - self.alpha) + modulating_factor = (1.0 - p_t) ** self.gamma + loss *= alpha_factor * modulating_factor + + if self.reduction == 'mean': + return loss.mean() + elif self.reduction == 'sum': + return loss.sum() + else: # 'none' + return loss diff --git a/models/backbone.py b/models/backbone.py index d3ab848c..e4a28b07 100644 --- a/models/backbone.py +++ b/models/backbone.py @@ -188,7 +188,7 @@ def forward(self, x): return out -def darknet(cfg_path='yolov5s.yaml', pretrained=False): +def darknet(cfg_path='yolov5s.yaml'): cfg_path = Path(__file__).parent.absolute().joinpath(cfg_path) with open(cfg_path) as f: model_dict = yaml.load(f, Loader=yaml.FullLoader) diff --git a/models/box_head.py b/models/box_head.py index aaea63f2..8fb1c96e 100644 --- a/models/box_head.py +++ b/models/box_head.py @@ -3,16 +3,11 @@ from torch import nn, Tensor from torch.jit.annotations import Tuple, List, Dict, Optional -from torchvision.ops import batched_nms, box_iou +from torchvision.ops import batched_nms from . import _utils as det_utils - - -def _sum(x: List[Tensor]) -> Tensor: - res = x[0] - for i in x[1:]: - res = res + i - return res +from ._utils import FocalLoss +from utils.box_ops import bbox_iou class YoloHead(nn.Module): @@ -59,6 +54,11 @@ def forward(self, x: List[Tensor]) -> Tensor: return torch.cat(all_pred_logits, dim=1) +def smooth_BCE(eps=0.1): # https://github.com/ultralytics/yolov3/issues/238#issuecomment-598028441 + # return positive, negative label smoothing BCE targets + return 1.0 - 0.5 * eps, 0.5 * eps + + class SetCriterion(nn.Module): """This class computes the loss for YOLOv5. Arguments: @@ -74,6 +74,14 @@ def __init__( weights: Tuple[float, float, float, float] = (1.0, 1.0, 1.0, 1.0), fg_iou_thresh: float = 0.5, bg_iou_thresh: float = 0.4, + box: float = 0.05, # box loss gain + cls: float = 0.5, # cls loss gain + cls_pw: float = 1.0, # cls BCELoss positive_weight + obj: float = 1.0, # obj loss gain (scale with pixels) + obj_pw: float = 1.0, # obj BCELoss positive_weight + anchor_t: Tuple[float] = (1.0, 2.0, 8.0), # anchor-multiple threshold + gr: float = 1.0, # iou loss ratio (obj_loss = 1.0 or iou) + fl_gamma: float = 0.0, # focal loss gamma allow_low_quality_matches: bool = True, ) -> None: """ @@ -95,31 +103,117 @@ def __init__( def forward( self, + head_outputs: Tensor, targets: List[Dict[str, Tensor]], - bbox_regression: Tensor, - anchors: Tensor, + anchors_tuple: Tuple[Tensor, Tensor, Tensor], ) -> Dict[str, Tensor]: + """ This performs the loss computation. + Parameters: + head_outputs: dict of tensors, see the output specification of the model for the format + targets: list of dicts, such that len(targets) == batch_size. 
+ The expected keys in each dict depends on the losses applied, see each loss' doc """ - Arguments: - targets (List[Dict[Tensor]]): ground-truth boxes present in the image - head_outputs (Dict[Tensor]) - anchor (List[Tensor]) - """ - matched_idxs = [] - for targets_per_image in targets: - if targets_per_image['boxes'].numel() == 0: - matched_idxs.append(torch.full((anchors.size(0),), -1, dtype=torch.int64)) - continue + regression_targets, labels = self.select_training_samples(targets, head_outputs, anchors_tuple) + losses = self.compute_loss(head_outputs, regression_targets, labels) - match_quality_matrix = box_iou(targets_per_image['boxes'], anchors) - matched_idxs.append(self.proposal_matcher(match_quality_matrix)) + return losses - return self.compute_loss(targets, bbox_regression, anchors, matched_idxs) + def select_training_samples( + self, + targets: List[Dict[str, Tensor]], + head_outputs: Tensor, + anchors_tuple: Tuple[Tensor, Tensor, Tensor], + ) -> Tuple[Tensor, Tensor]: + # get boxes indices for each anchors + boxes, labels = self.assign_targets_to_anchors(head_outputs, targets, anchors_tuple[0]) + + gt_locations = [] + for img_id in range(len(targets)): + locations = self.box_coder.encode(boxes[img_id], anchors_tuple[0]) + gt_locations.append(locations) + + regression_targets = torch.stack(gt_locations, 0) + labels = torch.stack(labels, 0) + + return regression_targets, labels + + def assign_targets_to_anchors( + self, + head_outputs: Tensor, + targets: List[Dict[str, Tensor]], + anchors: Tensor, + ) -> Tuple[List[Tensor], List[Tensor]]: + """Assign ground truth boxes and targets to anchors. + Args: + gt_boxes (List[Tensor]): with shape num_targets x 4, ground truth boxes + gt_labels (List[Tensor]): with shape num_targets, labels of targets + anchors (Tensor): with shape num_priors x 4, XYXY_REL BoxMode + Returns: + boxes (List[Tensor]): with shape num_priors x 4 real values for anchors. + labels (List[Tensor]): with shape num_priros, labels for anchors. + """ + device = anchors.device + num_layers = len(anchors) + # Build targets for compute_loss(), input targets(image,class,x,y,w,h) + num_anchors = anchors.shape[0] # number of anchors + num_targets = targets.shape[0] # number of targets + tcls, tbox, indices, anch = [], [], [], [] + gain = torch.ones(7, device=device) # normalized to gridspace gain + # same as .repeat_interleave(num_targets) + ai = torch.arange(num_anchors, device=device).float().view(num_anchors, 1).repeat(1, num_targets) + targets = torch.cat((targets.repeat(num_anchors, 1, 1), ai[:, :, None]), 2) # append anchor indices + + g = 0.5 # bias + off = torch.tensor([[0, 0], + [1, 0], [0, 1], [-1, 0], [0, -1], # j,k,l,m + # [1, 1], [1, -1], [-1, 1], [-1, -1], # jk,jm,lk,lm + ], device=targets.device).float() * g # offsets + + for i in range(num_layers): + anchors_per_layer = anchors[i] + gain[2:6] = torch.tensor(head_outputs[i].shape)[[3, 2, 3, 2]] # xyxy gain + + # Match targets to anchors + t = targets * gain + if num_targets: + # Matches + r = t[:, :, 4:6] / anchors[:, None] # wh ratio + j = torch.max(r, 1. / r).max(2)[0] < self.anchor_t # compare + # j = wh_iou(anchors, t[:, 4:6]) > self.iou_t # iou(3,n)=wh_iou(anchors(3,2), gwh(n,2)) + t = t[j] # filter + + # Offsets + gxy = t[:, 2:4] # grid xy + gxi = gain[[2, 3]] - gxy # inverse + j, k = ((gxy % 1. < g) & (gxy > 1.)).T + l, m = ((gxi % 1. 
< g) & (gxi > 1.)).T + j = torch.stack((torch.ones_like(j), j, k, l, m)) + t = t.repeat((5, 1, 1))[j] + offsets = (torch.zeros_like(gxy)[None] + off[:, None])[j] + else: + t = targets[0] + offsets = 0 + + # Define + b, c = t[:, :2].long().T # image, class + gxy = t[:, 2:4] # grid xy + gwh = t[:, 4:6] # grid wh + gij = (gxy - offsets).long() + gi, gj = gij.T # grid xy indices + + # Append + a = t[:, 6].long() # anchor indices + indices.append((b, a, gj.clamp_(0, gain[3] - 1), gi.clamp_(0, gain[2] - 1))) # image, anchor, grid indices + tbox.append(torch.cat((gxy - gij, gwh), 1)) # box + anch.append(anchors_per_layer[a]) # anchors + tcls.append(c) # class + + return tcls, tbox, indices, anch def compute_loss( self, + head_outputs: Tensor, targets: List[Dict[str, Tensor]], - bbox_regression: Tensor, anchors: Tensor, matched_idxs: List[Tensor], ) -> Dict[str, Tensor]: @@ -129,31 +223,66 @@ def compute_loss( targets: list of dicts, such that len(targets) == batch_size. The expected keys in each dict depends on the losses applied, see each loss' doc """ - losses = [] - - for targets_per_image, bbox_regression_per_image, matched_idxs_per_image in zip( - targets, bbox_regression, matched_idxs): - # determine only the foreground indices, ignore the rest - foreground_idxs_per_image = torch.where(matched_idxs_per_image >= 0)[0] - num_foreground = foreground_idxs_per_image.numel() - - # select only the foreground boxes - matched_gt_boxes_per_image = targets_per_image['boxes'][matched_idxs_per_image[foreground_idxs_per_image]] - bbox_regression_per_image = bbox_regression_per_image[foreground_idxs_per_image, :] - anchors = anchors[foreground_idxs_per_image, :] - - # compute the regression targets - target_regression = self.box_coder.encode_single(matched_gt_boxes_per_image, anchors) - - # compute the loss - losses.append(torch.nn.functional.l1_loss( - bbox_regression_per_image, - target_regression, - size_average=False - ) / max(1, num_foreground)) + device = anchors.device + num_classes = head_outputs.shape[2] - 5 + lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device) + + # Define criteria + BCEcls = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([self.cls_pw])).to(device) + BCEobj = nn.BCEWithLogitsLoss(pos_weight=torch.Tensor([self.obj_pw])).to(device) + + # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3 + cp, cn = smooth_BCE(eps=0.0) + + # Focal loss + g = self.fl_gamma # focal loss gamma + if g > 0: + BCEcls, BCEobj = FocalLoss(BCEcls, g), FocalLoss(BCEobj, g) + + # Losses + num_targets = 0 # number of targets + num_output = len(head_outputs) # number of outputs + balance = [4.0, 1.0, 0.4] if num_output == 3 else [4.0, 1.0, 0.4, 0.1] # P3-5 or P3-6 + for i, pi in enumerate(head_outputs): # layer index, layer predictions + b, a, gj, gi = matched_idxs[i] # image, anchor, gridy, gridx + tobj = torch.zeros_like(pi[..., 0], device=device) # target obj + + n = b.shape[0] # number of targets + if n: + num_targets += n # cumulative targets + ps = pi[b, a, gj, gi] # prediction subset corresponding to targets + + # Regression + pxy = ps[:, :2].sigmoid() * 2. 
- 0.5 + pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i] + pbox = torch.cat((pxy, pwh), 1).to(device) # predicted box + iou = bbox_iou(pbox.T, targets['boxes'][i], x1y1x2y2=False, CIoU=True) # iou(prediction, target) + lbox += (1.0 - iou).mean() # iou loss + + # Objectness + tobj[b, a, gj, gi] = (1.0 - self.gr) + self.gr * iou.detach().clamp(0).type(tobj.dtype) # iou ratio + + # Classification + if num_classes > 1: # cls loss (only if multiple classes) + t = torch.full_like(ps[:, 5:], cn, device=device) # targets + t[range(n), targets['labels'][i]] = cp + lcls += BCEcls(ps[:, 5:], t) # BCE + + # Append targets to text file + # with open('targets.txt', 'a') as file: + # [file.write('%11.5g ' * 4 % tuple(x) + '\n') for x in torch.cat((txy[i], twh[i]), 1)] + + lobj += BCEobj(pi[..., 4], tobj) * balance[i] # obj loss + + out_scaling = 3 / num_output # output count scaling + lbox *= self.box * out_scaling + lobj *= self.obj * out_scaling * (1.4 if num_output == 4 else 1.) + lcls *= self.cls * out_scaling return { - 'loss': _sum(losses) / max(1, len(targets)), + 'cls_logits': lcls, + 'bbox_regression': lbox, + 'objectness': lobj, } diff --git a/models/yolo.py b/models/yolo.py index e3e79d63..e728d45e 100644 --- a/models/yolo.py +++ b/models/yolo.py @@ -63,13 +63,13 @@ def __init__( anchor_generator = AnchorGenerator(strides, anchor_grids) self.anchor_generator = anchor_generator - if compute_loss is None: - compute_loss = SetCriterion( - weights=(1.0, 1.0, 1.0, 1.0), - fg_iou_thresh=fg_iou_thresh, - bg_iou_thresh=bg_iou_thresh, - ) - self.compute_loss = compute_loss + # if compute_loss is None: + # compute_loss = SetCriterion( + # weights=(1.0, 1.0, 1.0, 1.0), + # fg_iou_thresh=fg_iou_thresh, + # bg_iou_thresh=bg_iou_thresh, + # ) + self.compute_loss = None if head is None: head = YoloHead( @@ -140,13 +140,12 @@ def forward( # create the set of anchors anchors_tuple = self.anchor_generator(features) losses = {} - detections = torch.jit.annotate(List[Dict[str, Tensor]], []) + detections: List[Dict[str, Tensor]] = [] if self.training: assert targets is not None - # compute the losses - losses = self.compute_loss(targets, head_outputs, anchors_tuple[0]) + # losses = self.compute_loss(targets, head_outputs, anchors_tuple) else: # compute the detections detections = self.postprocess_detections(head_outputs, anchors_tuple, images.image_sizes) @@ -166,7 +165,6 @@ def yolov5( pretrained: bool = False, progress: bool = True, num_classes: int = 80, - pretrained_backbone: bool = True, **kwargs: Any, ) -> YOLO: """ @@ -205,11 +203,7 @@ def yolov5( pretrained (bool): If True, returns a model pre-trained on COCO train2017 progress (bool): If True, displays a progress bar of the download to stderr """ - if pretrained: - # no need to download the backbone if pretrained is set - pretrained_backbone = False - # skip P2 because it generates too many anchors (according to their paper) - backbone, anchor_grids = darknet(cfg_path=cfg_path, pretrained=pretrained_backbone) + backbone, anchor_grids = darknet(cfg_path=cfg_path) model = YOLO(backbone, num_classes, anchor_grids, **kwargs) if pretrained: state_dict = load_state_dict_from_url(model_urls[Path(cfg_path).stem], progress=progress) diff --git a/notebooks/export-onnx-inference-onnxruntime.ipynb b/notebooks/export-onnx-inference-onnxruntime.ipynb index c96d54f0..f2c35d4a 100644 --- a/notebooks/export-onnx-inference-onnxruntime.ipynb +++ b/notebooks/export-onnx-inference-onnxruntime.ipynb @@ -77,8 +77,8 @@ "metadata": {}, "outputs": [], "source": [ - "path0 = 
'./notebooks/assets/bus.jpg'\n", - "path1 = './notebooks/assets/zidane.jpg'\n", + "path0 = './test/assets/bus.jpg'\n", + "path1 = './test/assets/zidane.jpg'\n", "\n", "img_test0 = read_image(path0, is_half=False)\n", "img_test0 = img_test0.to(device)\n", diff --git a/notebooks/inference-pytorch-export-libtorch.ipynb b/notebooks/inference-pytorch-export-libtorch.ipynb index f518e954..567f8755 100644 --- a/notebooks/inference-pytorch-export-libtorch.ipynb +++ b/notebooks/inference-pytorch-export-libtorch.ipynb @@ -88,8 +88,8 @@ "outputs": [], "source": [ "opt = DotDict({\n", - " 'input_source': './notebooks/assets/bus.jpg',\n", - " 'output_dir': './notebooks/assets/output',\n", + " 'input_source': './test/assets/bus.jpg',\n", + " 'output_dir': './test/assets/output',\n", " 'save_txt': False,\n", " 'save_img': True,\n", "})" diff --git a/notebooks/assets/bus.jpg b/test/assets/bus.jpg similarity index 100% rename from notebooks/assets/bus.jpg rename to test/assets/bus.jpg diff --git a/notebooks/assets/zidane.jpg b/test/assets/zidane.jpg similarity index 100% rename from notebooks/assets/zidane.jpg rename to test/assets/zidane.jpg diff --git a/test/test_engine.py b/test/test_engine.py new file mode 100644 index 00000000..ae771fda --- /dev/null +++ b/test/test_engine.py @@ -0,0 +1,52 @@ +import unittest +import torch + +from typing import Dict + +from models import yolov5s +from .torch_utils import image_preprocess + + +class EngineTester(unittest.TestCase): + @unittest.skip("Current it isn't well implemented") + def test_train(self): + # Read Image using TorchVision.io Here + # Do forward over image + img_name = "test/assets/zidane.jpg" + img_tensor = image_preprocess(img_name) + self.assertEqual(img_tensor.ndim, 3) + + images = [img_tensor] + boxes = torch.tensor([[0.3790, 0.5487, 0.3220, 0.2047], + [0.2680, 0.5386, 0.2200, 0.1779], + [0.1720, 0.5403, 0.1960, 0.1409], + [0.2240, 0.4547, 0.1520, 0.0705]], dtype=torch.float) + labels = torch.tensor([7, 2, 3, 4], dtype=torch.int64) + targets = [{"boxes": boxes, "labels": labels}] + + model = yolov5s(num_classes=12) + out = model(images, targets) + self.assertIsInstance(out, Dict) + self.assertIsInstance(out["loss_classifier"], torch.Tensor) + self.assertIsInstance(out["loss_box_reg"], torch.Tensor) + self.assertIsInstance(out["loss_objectness"], torch.Tensor) + + def test_inference(self): + # Infer over an image + img_name = "test/assets/zidane.jpg" + img_input = image_preprocess(img_name) + self.assertEqual(img_input.ndim, 3) + + model = yolov5s(pretrained=True) + model.eval() + + out = model([img_input]) + self.assertIsInstance(out, list) + self.assertIsInstance(out[0], Dict) + self.assertIsInstance(out[0]["boxes"], torch.Tensor) + self.assertIsInstance(out[0]["labels"], torch.Tensor) + self.assertIsInstance(out[0]["scores"], torch.Tensor) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_models.py b/test/test_models.py index 7b44c5ff..268c01ee 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,3 +1,4 @@ +import unittest import torch from models.anchor_utils import AnchorGenerator @@ -36,3 +37,7 @@ def test_anchor_generator(self): self.assertEqual(anchors[0], anchor_output) self.assertEqual(anchors[1], wh_output) self.assertEqual(anchors[2], xy_output) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/test_models_utils.py b/test/test_models_utils.py new file mode 100644 index 00000000..73c2b724 --- /dev/null +++ b/test/test_models_utils.py @@ -0,0 +1,34 @@ +import unittest +import copy + 
+import torch +from torchvision.models.detection.transform import GeneralizedRCNNTransform + +from models import _utils as det_utils + + +class UtilsTester(unittest.TestCase): + def test_balanced_positive_negative_sampler(self): + sampler = det_utils.BalancedPositiveNegativeSampler(4, 0.25) + # keep all 6 negatives first, then add 3 positives, last two are ignore + matched_idxs = [torch.tensor([0, 0, 0, 0, 0, 0, 1, 1, 1, -1, -1])] + pos, neg = sampler(matched_idxs) + # we know the number of elements that should be sampled for the positive (1) + # and the negative (3), and their location. Let's make sure that they are there + self.assertEqual(pos[0].sum(), 1) + self.assertEqual(pos[0][6:9].sum(), 1) + self.assertEqual(neg[0].sum(), 3) + self.assertEqual(neg[0][0:6].sum(), 3) + + def test_transform_copy_targets(self): + transform = GeneralizedRCNNTransform(300, 500, torch.zeros(3), torch.ones(3)) + image = [torch.rand(3, 200, 300), torch.rand(3, 200, 200)] + targets = [{'boxes': torch.rand(3, 4)}, {'boxes': torch.rand(2, 4)}] + targets_copy = copy.deepcopy(targets) + out = transform(image, targets) # noqa: F841 + self.assertTrue(torch.equal(targets[0]['boxes'], targets_copy[0]['boxes'])) + self.assertTrue(torch.equal(targets[1]['boxes'], targets_copy[1]['boxes'])) + + +if __name__ == '__main__': + unittest.main() diff --git a/test/torch_utils.py b/test/torch_utils.py new file mode 100644 index 00000000..7b389409 --- /dev/null +++ b/test/torch_utils.py @@ -0,0 +1,11 @@ +from torchvision.io import read_image + +__all__ = ["image_preprocess"] + + +def image_preprocess(img_name, is_half=False): + img = read_image(img_name) + img = img.half() if is_half else img.float() # uint8 to fp16/32 + img /= 255.0 # 0 - 255 to 0.0 - 1.0 + + return img diff --git a/utils/box_ops.py b/utils/box_ops.py new file mode 100644 index 00000000..1f132396 --- /dev/null +++ b/utils/box_ops.py @@ -0,0 +1,48 @@ +import math + +import torch + + +def bbox_iou(box1, box2, x1y1x2y2=True, GIoU=False, DIoU=False, CIoU=False, eps=1e-9): + # Returns the IoU of box1 to box2. 
box1 is 4, box2 is nx4 + box2 = box2.T + + # Get the coordinates of bounding boxes + if x1y1x2y2: # x1, y1, x2, y2 = box1 + b1_x1, b1_y1, b1_x2, b1_y2 = box1[0], box1[1], box1[2], box1[3] + b2_x1, b2_y1, b2_x2, b2_y2 = box2[0], box2[1], box2[2], box2[3] + else: # transform from xywh to xyxy + b1_x1, b1_x2 = box1[0] - box1[2] / 2, box1[0] + box1[2] / 2 + b1_y1, b1_y2 = box1[1] - box1[3] / 2, box1[1] + box1[3] / 2 + b2_x1, b2_x2 = box2[0] - box2[2] / 2, box2[0] + box2[2] / 2 + b2_y1, b2_y2 = box2[1] - box2[3] / 2, box2[1] + box2[3] / 2 + + # Intersection area + inter = (torch.min(b1_x2, b2_x2) - torch.max(b1_x1, b2_x1)).clamp(0) * \ + (torch.min(b1_y2, b2_y2) - torch.max(b1_y1, b2_y1)).clamp(0) + + # Union Area + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + union = w1 * h1 + w2 * h2 - inter + eps + + iou = inter / union + if GIoU or DIoU or CIoU: + cw = torch.max(b1_x2, b2_x2) - torch.min(b1_x1, b2_x1) # convex (smallest enclosing box) width + ch = torch.max(b1_y2, b2_y2) - torch.min(b1_y1, b2_y1) # convex height + if CIoU or DIoU: # Distance or Complete IoU https://arxiv.org/abs/1911.08287v1 + c2 = cw ** 2 + ch ** 2 + eps # convex diagonal squared + # center distance squared + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 + if DIoU: + return iou - rho2 / c2 # DIoU + elif CIoU: # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = (4 / math.pi ** 2) * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2) + with torch.no_grad(): + alpha = v / ((1 + eps) - iou + v) + return iou - (rho2 / c2 + v * alpha) # CIoU + else: # GIoU https://arxiv.org/pdf/1902.09630.pdf + c_area = cw * ch + eps # convex area + return iou - (c_area - union) / c_area # GIoU + else: + return iou # IoU diff --git a/utils/misc.py b/utils/misc.py index 1d4e5eb1..0a34d21f 100644 --- a/utils/misc.py +++ b/utils/misc.py @@ -5,22 +5,13 @@ Mostly copy-paste from torchvision references. 
""" import os -import subprocess import time from collections import defaultdict, deque import datetime import pickle -from typing import Optional, List import torch import torch.distributed as dist -from torch import Tensor - -# needed due to empty tensor bug in pytorch and torchvision 0.5 -import torchvision -if float(torchvision.__version__[:3]) < 0.7: - from torchvision.ops import _new_empty_tensor - from torchvision.ops.misc import _output_size class SmoothedValue(object): @@ -245,118 +236,19 @@ def log_every(self, iterable, print_freq, header=None): header, total_time_str, total_time / len(iterable))) -def get_sha(): - cwd = os.path.dirname(os.path.abspath(__file__)) - - def _run(command): - return subprocess.check_output(command, cwd=cwd).decode('ascii').strip() - sha = 'N/A' - diff = "clean" - branch = 'N/A' - try: - sha = _run(['git', 'rev-parse', 'HEAD']) - subprocess.check_output(['git', 'diff'], cwd=cwd) - diff = _run(['git', 'diff-index', 'HEAD']) - diff = "has uncommited changes" if diff else "clean" - branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD']) - except Exception: - pass - message = f"sha: {sha}, status: {diff}, branch: {branch}" - return message - - def collate_fn(batch): - batch = list(zip(*batch)) - batch[0] = nested_tensor_from_tensor_list(batch[0]) - return tuple(batch) - - -def _max_by_axis(the_list): - # type: (List[List[int]]) -> List[int] - maxes = the_list[0] - for sublist in the_list[1:]: - for index, item in enumerate(sublist): - maxes[index] = max(maxes[index], item) - return maxes - - -class NestedTensor(object): - def __init__(self, tensors, mask: Optional[Tensor]): - self.tensors = tensors - self.mask = mask - - def to(self, device): - # type: (Device) -> NestedTensor # noqa - cast_tensor = self.tensors.to(device) - mask = self.mask - if mask is not None: - assert mask is not None - cast_mask = mask.to(device) - else: - cast_mask = None - return NestedTensor(cast_tensor, cast_mask) - - def decompose(self): - return self.tensors, self.mask - - def __repr__(self): - return str(self.tensors) - - -def nested_tensor_from_tensor_list(tensor_list: List[Tensor]): - # TODO make this more general - if tensor_list[0].ndim == 3: - if torchvision._is_tracing(): - # nested_tensor_from_tensor_list() does not export well to ONNX - # call _onnx_nested_tensor_from_tensor_list() instead - return _onnx_nested_tensor_from_tensor_list(tensor_list) - - # TODO make it support different-sized images - max_size = _max_by_axis([list(img.shape) for img in tensor_list]) - # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list])) - batch_shape = [len(tensor_list)] + max_size - b, c, h, w = batch_shape - dtype = tensor_list[0].dtype - device = tensor_list[0].device - tensor = torch.zeros(batch_shape, dtype=dtype, device=device) - mask = torch.ones((b, h, w), dtype=torch.bool, device=device) - for img, pad_img, m in zip(tensor_list, tensor, mask): - pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - m[: img.shape[1], :img.shape[2]] = False - else: - raise ValueError('not supported') - return NestedTensor(tensor, mask) - - -# _onnx_nested_tensor_from_tensor_list() is an implementation of -# nested_tensor_from_tensor_list() that is supported by ONNX tracing. 
-@torch.jit.unused -def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor: - max_size = [] - for i in range(tensor_list[0].dim()): - max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64) - max_size.append(max_size_i) - max_size = tuple(max_size) + return list(zip(*batch)) - # work around for - # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img) - # m[: img.shape[1], :img.shape[2]] = False - # which is not yet supported in onnx - padded_imgs = [] - padded_masks = [] - for img in tensor_list: - padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))] - padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0])) - padded_imgs.append(padded_img) - m = torch.zeros_like(img[0], dtype=torch.int, device=img.device) - padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1) - padded_masks.append(padded_mask.to(torch.bool)) +def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): - tensor = torch.stack(padded_imgs) - mask = torch.stack(padded_masks) + def f(x): + if x >= warmup_iters: + return 1 + alpha = float(x) / warmup_iters + return warmup_factor * (1 - alpha) + alpha - return NestedTensor(tensor, mask=mask) + return torch.optim.lr_scheduler.LambdaLR(optimizer, f) def setup_for_distributed(is_master): @@ -445,23 +337,3 @@ def accuracy(output, target, topk=(1,)): correct_k = correct[:k].view(-1).float().sum(0) res.append(correct_k.mul_(100.0 / batch_size)) return res - - -def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None): - # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor - """ - Equivalent to nn.functional.interpolate, but with support for empty batch sizes. - This will eventually be supported natively by PyTorch, and this - class can go away. - """ - if float(torchvision.__version__[:3]) < 0.7: - if input.numel() > 0: - return torch.nn.functional.interpolate( - input, size, scale_factor, mode, align_corners - ) - - output_shape = _output_size(2, input, size, scale_factor) - output_shape = list(input.shape[:-2]) + list(output_shape) - return _new_empty_tensor(input, output_shape) - else: - return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
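
Usage sketch (appended after the patch, not part of the diff above): a minimal, hedged example of how the reworked API is exercised once this change lands, mirroring test/test_engine.py::test_inference and the relocated test assets. Only the eval-mode path is shown because the training loss path is still disabled here (self.compute_loss is commented out in models/yolo.py and EngineTester.test_train is skipped); the image path and the printed fields are taken directly from the tests, everything else is just illustrative glue.

    import torch
    from torchvision.io import read_image

    from models import yolov5s

    # Build the detector; pretrained=True loads the released COCO weights,
    # as done in EngineTester.test_inference.
    model = yolov5s(pretrained=True)
    model.eval()

    # Same preprocessing as test/torch_utils.image_preprocess:
    # uint8 CHW image -> float tensor scaled to [0, 1].
    img = read_image("test/assets/zidane.jpg").float() / 255.0

    with torch.no_grad():
        detections = model([img])  # one dict per input image

    # Each dict carries "boxes", "labels" and "scores" tensors.
    print(detections[0]["boxes"].shape, detections[0]["scores"][:5])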