From 1804ab73dd17ee9967d1e301454317285b8feb8e Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Thu, 7 Apr 2022 09:21:15 -0700 Subject: [PATCH 1/6] add samples for object detection models --- PyTorch/.gitignore | 6 +- PyTorch/data/cifar.py | 24 -- PyTorch/data/dataset.py | 47 +++ PyTorch/objectDetection/maskrcnn/README.md | 51 +++ PyTorch/objectDetection/maskrcnn/coco_eval.py | 352 ++++++++++++++++++ .../objectDetection/maskrcnn/coco_utils.py | 252 +++++++++++++ PyTorch/objectDetection/maskrcnn/engine.py | 109 ++++++ PyTorch/objectDetection/maskrcnn/maskrcnn.py | 172 +++++++++ .../objectDetection/maskrcnn/requirements.txt | 3 + .../objectDetection/maskrcnn/transforms.py | 49 +++ PyTorch/objectDetection/maskrcnn/utils.py | 324 ++++++++++++++++ PyTorch/objectDetection/objectDetection.py | 72 ++++ PyTorch/resnet50/README.md | 4 +- PyTorch/squeezenet/README.md | 4 +- PyTorch/torchvision_classification/README.md | 39 +- 15 files changed, 1454 insertions(+), 54 deletions(-) delete mode 100644 PyTorch/data/cifar.py create mode 100644 PyTorch/data/dataset.py create mode 100644 PyTorch/objectDetection/maskrcnn/README.md create mode 100644 PyTorch/objectDetection/maskrcnn/coco_eval.py create mode 100644 PyTorch/objectDetection/maskrcnn/coco_utils.py create mode 100644 PyTorch/objectDetection/maskrcnn/engine.py create mode 100644 PyTorch/objectDetection/maskrcnn/maskrcnn.py create mode 100644 PyTorch/objectDetection/maskrcnn/requirements.txt create mode 100644 PyTorch/objectDetection/maskrcnn/transforms.py create mode 100644 PyTorch/objectDetection/maskrcnn/utils.py create mode 100644 PyTorch/objectDetection/objectDetection.py diff --git a/PyTorch/.gitignore b/PyTorch/.gitignore index 5a0f2bd3..c3c84194 100644 --- a/PyTorch/.gitignore +++ b/PyTorch/.gitignore @@ -6,4 +6,8 @@ traces/ *.xlsx data/cifar-10-python checkpoints -coco128 \ No newline at end of file +coco128 +data/PennFudanPed +data/PennFudanPed.zip +objectDetection/PennFudanPed +checkpoints \ No newline at end of file diff --git a/PyTorch/data/cifar.py b/PyTorch/data/cifar.py deleted file mode 100644 index aaa2d709..00000000 --- a/PyTorch/data/cifar.py +++ /dev/null @@ -1,24 +0,0 @@ -#!/usr/bin/env python -# Copyright (c) Microsoft Corporation. All rights reserved. - -import argparse -import pathlib -import os -import subprocess -from torchvision import datasets -from torchvision.transforms import ToTensor, Lambda, Compose, transforms - -def get_training_path(args): - if (os.path.isabs(args.path)): - return args.path - else: - return str(os.path.join(pathlib.Path(__file__).parent.resolve(), args.path)) - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(__doc__) - parser.add_argument("-path", help="Path to cifar dataset.", default="cifar-10-python") - args = parser.parse_args() - - path = get_training_path(args) - datasets.CIFAR10(root=path, download=True) \ No newline at end of file diff --git a/PyTorch/data/dataset.py b/PyTorch/data/dataset.py new file mode 100644 index 00000000..faf55b5d --- /dev/null +++ b/PyTorch/data/dataset.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# Copyright (c) Microsoft Corporation. All rights reserved. 
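+#
+# Downloads the datasets used by these samples: CIFAR-10 for the image
+# classification models and PennFudanPed for the Mask R-CNN object detection sample.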
+ +import argparse +import pathlib +import os +from torchvision import datasets +import wget +import zipfile + +def get_current_dir(): + return str(pathlib.Path(__file__).parent.resolve()) + +def download_cifar_dataset(): + path = os.path.join(get_current_dir(), 'cifar-10-python') + datasets.CIFAR10(root=path, download=True) + +def download_pennfudanped_dataset(): + path = get_current_dir() + if (os.path.exists(os.path.join(path, 'PennFudanPed'))): + print ("PennFundaPed dataset already downloaded and verified") + return + + url='https://www.cis.upenn.edu/~jshi/ped_html/PennFudanPed.zip' + print("Downloading PennFundaPed dataset\n") + dataset_path = wget.download(url, out=path) + try: + with zipfile.ZipFile(os.path.join(path, dataset_path)) as z: + z.extractall(path=path) + print("\nExtracted PennFundaPed dataset") + except: + print("Invalid file") + +if __name__ == "__main__": + parser = argparse.ArgumentParser(__doc__) + parser.add_argument("--dataset", help="datasets: cifar10 or pennfudanped.", default="all") + args = parser.parse_args() + + if args.dataset.lower() == 'all': + download_cifar_dataset() + download_pennfudanped_dataset() + elif args.dataset.lower() == 'cifar10': + download_cifar_dataset() + elif args.dataset.lower() == 'pennfudanped': + download_pennfudanped_dataset() + else: + raise Exception(f"Model {args.dataset} is not supported yet!") \ No newline at end of file diff --git a/PyTorch/objectDetection/maskrcnn/README.md b/PyTorch/objectDetection/maskrcnn/README.md new file mode 100644 index 00000000..f92031da --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/README.md @@ -0,0 +1,51 @@ +# maskrcnn Model + +Sample scripts for training the [Mask R-CNN](https://arxiv.org/abs/1703.06870) model in the [Penn-Fudan Database for Pedestrian Detection and Segmentation](https://www.cis.upenn.edu/~jshi/ped_html/) using PyTorch on DirectML + +These scripts are collected from the tutorial [here](https://pytorch.org/tutorials/intermediate/torchvision_tutorial.html) + +- [Setup](#setup) +- [Prepare Data](#prepare-data) +- [Training](#training) + +## Setup +Install the following prerequisites: +``` +pip install -r pytorch\objectDetection\maskrcnn\requirements.txt +``` + +## Prepare Data + +After installing the PyTorch on DirectML package (see [GPU accelerated ML training](..\readme.md)), open a console to the `root` directory and run the setup script to download and convert data: + +``` +python pytorch\data\dataset.py +``` + +Running `setup.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset. The output of running it should look similar to the following: + +``` +>python pytorch\data\dataset.py +Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz +Failed download. Trying https -> http instead. 
Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz +170499072it [00:32, 5250164.09it/s] +Extracting E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python +``` + +## Training + +A helper script exists to train Mask R-CNN with PennFudanPed data: + +``` + +python .\maskrcnn.py +``` + +The first few lines of output should look similar to the following (exact numbers may change): +``` +>python .\maskrcnn.py +python .\maskrcnn.py +Epoch: [0] [ 0/60] eta: 0:38:26 lr: 0.000090 loss: 2.9777 (2.9777) loss_classifier: 0.7217 (0.7217) loss_box_reg: 0.0754 (0.0754) loss_mask: 1.6228 (1.6228) loss_objectness: 0.4175 (0.4175) loss_rpn_box_reg: 0.1404 (0.1404) time: 38.4439 data: 1.0955 +Epoch: [0] [10/60] eta: 0:29:44 lr: 0.000936 loss: 2.4268 (2.4919) loss_classifier: 0.4056 (0.4158) loss_box_reg: 0.1691 (0.3631) loss_mask: 1.1679 (1.1600) loss_objectness: 0.1162 (0.3120) loss_rpn_box_reg: 0.1257 (0.2410) time: 35.6972 data: 0.1034 +Epoch: [0] [20/60] eta: 0:23:14 lr: 0.001783 loss: 1.2172 (1.6717) loss_classifier: 0.0669 (0.2410) loss_box_reg: 0.1331 (0.2466) loss_mask: 0.5935 (0.8376) loss_objectness: 0.0565 (0.1873) loss_rpn_box_reg: 0.0574 (0.1593) time: 34.6860 data: 0.0042 +``` diff --git a/PyTorch/objectDetection/maskrcnn/coco_eval.py b/PyTorch/objectDetection/maskrcnn/coco_eval.py new file mode 100644 index 00000000..09648f29 --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/coco_eval.py @@ -0,0 +1,352 @@ +import json +import tempfile + +import numpy as np +import copy +import time +import torch +import torch._six + +from pycocotools.cocoeval import COCOeval +from pycocotools.coco import COCO +import pycocotools.mask as mask_util + +from collections import defaultdict + +import utils + + +class CocoEvaluator(object): + def __init__(self, coco_gt, iou_types): + assert isinstance(iou_types, (list, tuple)) + coco_gt = copy.deepcopy(coco_gt) + self.coco_gt = coco_gt + + self.iou_types = iou_types + self.coco_eval = {} + for iou_type in iou_types: + self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type) + + self.img_ids = [] + self.eval_imgs = {k: [] for k in iou_types} + + def update(self, predictions): + img_ids = list(np.unique(list(predictions.keys()))) + self.img_ids.extend(img_ids) + + for iou_type in self.iou_types: + results = self.prepare(predictions, iou_type) + coco_dt = loadRes(self.coco_gt, results) if results else COCO() + coco_eval = self.coco_eval[iou_type] + + coco_eval.cocoDt = coco_dt + coco_eval.params.imgIds = list(img_ids) + img_ids, eval_imgs = evaluate(coco_eval) + + self.eval_imgs[iou_type].append(eval_imgs) + + def synchronize_between_processes(self): + for iou_type in self.iou_types: + self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2) + create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type]) + + def accumulate(self): + for coco_eval in self.coco_eval.values(): + coco_eval.accumulate() + + def summarize(self): + for iou_type, coco_eval in self.coco_eval.items(): + print("IoU metric: {}".format(iou_type)) + coco_eval.summarize() + + def prepare(self, predictions, iou_type): + if iou_type == "bbox": + return self.prepare_for_coco_detection(predictions) + elif iou_type == "segm": + return self.prepare_for_coco_segmentation(predictions) + elif iou_type == "keypoints": + return self.prepare_for_coco_keypoint(predictions) + else: + raise ValueError("Unknown iou type 
{}".format(iou_type)) + + def prepare_for_coco_detection(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "bbox": box, + "score": scores[k], + } + for k, box in enumerate(boxes) + ] + ) + return coco_results + + def prepare_for_coco_segmentation(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + scores = prediction["scores"] + labels = prediction["labels"] + masks = prediction["masks"] + + masks = masks > 0.5 + + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + + rles = [ + mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] + for mask in masks + ] + for rle in rles: + rle["counts"] = rle["counts"].decode("utf-8") + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + "segmentation": rle, + "score": scores[k], + } + for k, rle in enumerate(rles) + ] + ) + return coco_results + + def prepare_for_coco_keypoint(self, predictions): + coco_results = [] + for original_id, prediction in predictions.items(): + if len(prediction) == 0: + continue + + boxes = prediction["boxes"] + boxes = convert_to_xywh(boxes).tolist() + scores = prediction["scores"].tolist() + labels = prediction["labels"].tolist() + keypoints = prediction["keypoints"] + keypoints = keypoints.flatten(start_dim=1).tolist() + + coco_results.extend( + [ + { + "image_id": original_id, + "category_id": labels[k], + 'keypoints': keypoint, + "score": scores[k], + } + for k, keypoint in enumerate(keypoints) + ] + ) + return coco_results + + +def convert_to_xywh(boxes): + xmin, ymin, xmax, ymax = boxes.unbind(1) + return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1) + + +def merge(img_ids, eval_imgs): + all_img_ids = utils.all_gather(img_ids) + all_eval_imgs = utils.all_gather(eval_imgs) + + merged_img_ids = [] + for p in all_img_ids: + merged_img_ids.extend(p) + + merged_eval_imgs = [] + for p in all_eval_imgs: + merged_eval_imgs.append(p) + + merged_img_ids = np.array(merged_img_ids) + merged_eval_imgs = np.concatenate(merged_eval_imgs, 2) + + # keep only unique (and in sorted order) images + merged_img_ids, idx = np.unique(merged_img_ids, return_index=True) + merged_eval_imgs = merged_eval_imgs[..., idx] + + return merged_img_ids, merged_eval_imgs + + +def create_common_coco_eval(coco_eval, img_ids, eval_imgs): + img_ids, eval_imgs = merge(img_ids, eval_imgs) + img_ids = list(img_ids) + eval_imgs = list(eval_imgs.flatten()) + + coco_eval.evalImgs = eval_imgs + coco_eval.params.imgIds = img_ids + coco_eval._paramsEval = copy.deepcopy(coco_eval.params) + + +################################################################# +# From pycocotools, just removed the prints and fixed +# a Python3 bug about unicode not defined +################################################################# + +# Ideally, pycocotools wouldn't have hard-coded prints +# so that we could avoid copy-pasting those two functions + +def createIndex(self): + # create index + # print('creating index...') + anns, cats, imgs = {}, {}, {} + imgToAnns, catToImgs = defaultdict(list), defaultdict(list) + if 'annotations' in self.dataset: + for ann in 
self.dataset['annotations']: + imgToAnns[ann['image_id']].append(ann) + anns[ann['id']] = ann + + if 'images' in self.dataset: + for img in self.dataset['images']: + imgs[img['id']] = img + + if 'categories' in self.dataset: + for cat in self.dataset['categories']: + cats[cat['id']] = cat + + if 'annotations' in self.dataset and 'categories' in self.dataset: + for ann in self.dataset['annotations']: + catToImgs[ann['category_id']].append(ann['image_id']) + + # print('index created!') + + # create class members + self.anns = anns + self.imgToAnns = imgToAnns + self.catToImgs = catToImgs + self.imgs = imgs + self.cats = cats + + +maskUtils = mask_util + + +def loadRes(self, resFile): + """ + Load result file and return a result api object. + Args: + self (obj): coco object with ground truth annotations + resFile (str): file name of result file + Returns: + res (obj): result api object + """ + res = COCO() + res.dataset['images'] = [img for img in self.dataset['images']] + + # print('Loading and preparing results...') + # tic = time.time() + if isinstance(resFile, torch._six.string_classes): + anns = json.load(open(resFile)) + elif type(resFile) == np.ndarray: + anns = self.loadNumpyAnnotations(resFile) + else: + anns = resFile + assert type(anns) == list, 'results in not an array of objects' + annsImgIds = [ann['image_id'] for ann in anns] + assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \ + 'Results do not correspond to current coco set' + if 'caption' in anns[0]: + imgIds = set([img['id'] for img in res.dataset['images']]) & set([ann['image_id'] for ann in anns]) + res.dataset['images'] = [img for img in res.dataset['images'] if img['id'] in imgIds] + for id, ann in enumerate(anns): + ann['id'] = id + 1 + elif 'bbox' in anns[0] and not anns[0]['bbox'] == []: + res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + for id, ann in enumerate(anns): + bb = ann['bbox'] + x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]] + if 'segmentation' not in ann: + ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]] + ann['area'] = bb[2] * bb[3] + ann['id'] = id + 1 + ann['iscrowd'] = 0 + elif 'segmentation' in anns[0]: + res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + for id, ann in enumerate(anns): + # now only support compressed RLE format as segmentation results + ann['area'] = maskUtils.area(ann['segmentation']) + if 'bbox' not in ann: + ann['bbox'] = maskUtils.toBbox(ann['segmentation']) + ann['id'] = id + 1 + ann['iscrowd'] = 0 + elif 'keypoints' in anns[0]: + res.dataset['categories'] = copy.deepcopy(self.dataset['categories']) + for id, ann in enumerate(anns): + s = ann['keypoints'] + x = s[0::3] + y = s[1::3] + x1, x2, y1, y2 = np.min(x), np.max(x), np.min(y), np.max(y) + ann['area'] = (x2 - x1) * (y2 - y1) + ann['id'] = id + 1 + ann['bbox'] = [x1, y1, x2 - x1, y2 - y1] + # print('DONE (t={:0.2f}s)'.format(time.time()- tic)) + + res.dataset['annotations'] = anns + createIndex(res) + return res + + +def evaluate(self): + ''' + Run per image evaluation on given images and store results (a list of dict) in self.evalImgs + :return: None + ''' + # tic = time.time() + # print('Running per image evaluation...') + p = self.params + # add backward compatibility if useSegm is specified in params + if p.useSegm is not None: + p.iouType = 'segm' if p.useSegm == 1 else 'bbox' + print('useSegm (deprecated) is not None. 
Running {} evaluation'.format(p.iouType)) + # print('Evaluate annotation type *{}*'.format(p.iouType)) + p.imgIds = list(np.unique(p.imgIds)) + if p.useCats: + p.catIds = list(np.unique(p.catIds)) + p.maxDets = sorted(p.maxDets) + self.params = p + + self._prepare() + # loop through images, area range, max detection number + catIds = p.catIds if p.useCats else [-1] + + if p.iouType == 'segm' or p.iouType == 'bbox': + computeIoU = self.computeIoU + elif p.iouType == 'keypoints': + computeIoU = self.computeOks + self.ious = { + (imgId, catId): computeIoU(imgId, catId) + for imgId in p.imgIds + for catId in catIds} + + evaluateImg = self.evaluateImg + maxDet = p.maxDets[-1] + evalImgs = [ + evaluateImg(imgId, catId, areaRng, maxDet) + for catId in catIds + for areaRng in p.areaRng + for imgId in p.imgIds + ] + # this is NOT in the pycocotools code, but could be done outside + evalImgs = np.asarray(evalImgs).reshape(len(catIds), len(p.areaRng), len(p.imgIds)) + self._paramsEval = copy.deepcopy(self.params) + # toc = time.time() + # print('DONE (t={:0.2f}s).'.format(toc-tic)) + return p.imgIds, evalImgs + +################################################################# +# end of straight copy from pycocotools, just removing the prints +################################################################# diff --git a/PyTorch/objectDetection/maskrcnn/coco_utils.py b/PyTorch/objectDetection/maskrcnn/coco_utils.py new file mode 100644 index 00000000..26701a2c --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/coco_utils.py @@ -0,0 +1,252 @@ +import copy +import os +from PIL import Image + +import torch +import torch.utils.data +import torchvision + +from pycocotools import mask as coco_mask +from pycocotools.coco import COCO + +import transforms as T + + +class FilterAndRemapCocoCategories(object): + def __init__(self, categories, remap=True): + self.categories = categories + self.remap = remap + + def __call__(self, image, target): + anno = target["annotations"] + anno = [obj for obj in anno if obj["category_id"] in self.categories] + if not self.remap: + target["annotations"] = anno + return image, target + anno = copy.deepcopy(anno) + for obj in anno: + obj["category_id"] = self.categories.index(obj["category_id"]) + target["annotations"] = anno + return image, target + + +def convert_coco_poly_to_mask(segmentations, height, width): + masks = [] + for polygons in segmentations: + rles = coco_mask.frPyObjects(polygons, height, width) + mask = coco_mask.decode(rles) + if len(mask.shape) < 3: + mask = mask[..., None] + mask = torch.as_tensor(mask, dtype=torch.uint8) + mask = mask.any(dim=2) + masks.append(mask) + if masks: + masks = torch.stack(masks, dim=0) + else: + masks = torch.zeros((0, height, width), dtype=torch.uint8) + return masks + + +class ConvertCocoPolysToMask(object): + def __call__(self, image, target): + w, h = image.size + + image_id = target["image_id"] + image_id = torch.tensor([image_id]) + + anno = target["annotations"] + + anno = [obj for obj in anno if obj['iscrowd'] == 0] + + boxes = [obj["bbox"] for obj in anno] + # guard against no boxes via resizing + boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4) + boxes[:, 2:] += boxes[:, :2] + boxes[:, 0::2].clamp_(min=0, max=w) + boxes[:, 1::2].clamp_(min=0, max=h) + + classes = [obj["category_id"] for obj in anno] + classes = torch.tensor(classes, dtype=torch.int64) + + segmentations = [obj["segmentation"] for obj in anno] + masks = convert_coco_poly_to_mask(segmentations, h, w) + + keypoints = None + if anno 
and "keypoints" in anno[0]: + keypoints = [obj["keypoints"] for obj in anno] + keypoints = torch.as_tensor(keypoints, dtype=torch.float32) + num_keypoints = keypoints.shape[0] + if num_keypoints: + keypoints = keypoints.view(num_keypoints, -1, 3) + + keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0]) + boxes = boxes[keep] + classes = classes[keep] + masks = masks[keep] + if keypoints is not None: + keypoints = keypoints[keep] + + target = {} + target["boxes"] = boxes + target["labels"] = classes + target["masks"] = masks + target["image_id"] = image_id + if keypoints is not None: + target["keypoints"] = keypoints + + # for conversion to coco api + area = torch.tensor([obj["area"] for obj in anno]) + iscrowd = torch.tensor([obj["iscrowd"] for obj in anno]) + target["area"] = area + target["iscrowd"] = iscrowd + + return image, target + + +def _coco_remove_images_without_annotations(dataset, cat_list=None): + def _has_only_empty_bbox(anno): + return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno) + + def _count_visible_keypoints(anno): + return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno) + + min_keypoints_per_image = 10 + + def _has_valid_annotation(anno): + # if it's empty, there is no annotation + if len(anno) == 0: + return False + # if all boxes have close to zero area, there is no annotation + if _has_only_empty_bbox(anno): + return False + # keypoints task have a slight different critera for considering + # if an annotation is valid + if "keypoints" not in anno[0]: + return True + # for keypoint detection tasks, only consider valid images those + # containing at least min_keypoints_per_image + if _count_visible_keypoints(anno) >= min_keypoints_per_image: + return True + return False + + assert isinstance(dataset, torchvision.datasets.CocoDetection) + ids = [] + for ds_idx, img_id in enumerate(dataset.ids): + ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None) + anno = dataset.coco.loadAnns(ann_ids) + if cat_list: + anno = [obj for obj in anno if obj["category_id"] in cat_list] + if _has_valid_annotation(anno): + ids.append(ds_idx) + + dataset = torch.utils.data.Subset(dataset, ids) + return dataset + + +def convert_to_coco_api(ds): + coco_ds = COCO() + # annotation IDs need to start at 1, not 0, see torchvision issue #1530 + ann_id = 1 + dataset = {'images': [], 'categories': [], 'annotations': []} + categories = set() + for img_idx in range(len(ds)): + # find better way to get target + # targets = ds.get_annotations(img_idx) + img, targets = ds[img_idx] + image_id = targets["image_id"].item() + img_dict = {} + img_dict['id'] = image_id + img_dict['height'] = img.shape[-2] + img_dict['width'] = img.shape[-1] + dataset['images'].append(img_dict) + bboxes = targets["boxes"] + bboxes[:, 2:] -= bboxes[:, :2] + bboxes = bboxes.tolist() + labels = targets['labels'].tolist() + areas = targets['area'].tolist() + iscrowd = targets['iscrowd'].tolist() + if 'masks' in targets: + masks = targets['masks'] + # make masks Fortran contiguous for coco_mask + masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1) + if 'keypoints' in targets: + keypoints = targets['keypoints'] + keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist() + num_objs = len(bboxes) + for i in range(num_objs): + ann = {} + ann['image_id'] = image_id + ann['bbox'] = bboxes[i] + ann['category_id'] = labels[i] + categories.add(labels[i]) + ann['area'] = areas[i] + ann['iscrowd'] = iscrowd[i] + ann['id'] = ann_id + if 'masks' in targets: + ann["segmentation"] = 
coco_mask.encode(masks[i].numpy()) + if 'keypoints' in targets: + ann['keypoints'] = keypoints[i] + ann['num_keypoints'] = sum(k != 0 for k in keypoints[i][2::3]) + dataset['annotations'].append(ann) + ann_id += 1 + dataset['categories'] = [{'id': i} for i in sorted(categories)] + coco_ds.dataset = dataset + coco_ds.createIndex() + return coco_ds + + +def get_coco_api_from_dataset(dataset): + for _ in range(10): + if isinstance(dataset, torchvision.datasets.CocoDetection): + break + if isinstance(dataset, torch.utils.data.Subset): + dataset = dataset.dataset + if isinstance(dataset, torchvision.datasets.CocoDetection): + return dataset.coco + return convert_to_coco_api(dataset) + + +class CocoDetection(torchvision.datasets.CocoDetection): + def __init__(self, img_folder, ann_file, transforms): + super(CocoDetection, self).__init__(img_folder, ann_file) + self._transforms = transforms + + def __getitem__(self, idx): + img, target = super(CocoDetection, self).__getitem__(idx) + image_id = self.ids[idx] + target = dict(image_id=image_id, annotations=target) + if self._transforms is not None: + img, target = self._transforms(img, target) + return img, target + + +def get_coco(root, image_set, transforms, mode='instances'): + anno_file_template = "{}_{}2017.json" + PATHS = { + "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))), + "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))), + # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))) + } + + t = [ConvertCocoPolysToMask()] + + if transforms is not None: + t.append(transforms) + transforms = T.Compose(t) + + img_folder, ann_file = PATHS[image_set] + img_folder = os.path.join(root, img_folder) + ann_file = os.path.join(root, ann_file) + + dataset = CocoDetection(img_folder, ann_file, transforms=transforms) + + if image_set == "train": + dataset = _coco_remove_images_without_annotations(dataset) + + # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)]) + + return dataset + + +def get_coco_kp(root, image_set, transforms): + return get_coco(root, image_set, transforms, mode="person_keypoints") diff --git a/PyTorch/objectDetection/maskrcnn/engine.py b/PyTorch/objectDetection/maskrcnn/engine.py new file mode 100644 index 00000000..86fb3e9b --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/engine.py @@ -0,0 +1,109 @@ +import math +import sys +import time +import torch + +import torchvision.models.detection.mask_rcnn + +from coco_utils import get_coco_api_from_dataset +from coco_eval import CocoEvaluator +import utils + + +def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq): + model.train() + metric_logger = utils.MetricLogger(delimiter=" ") + metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value:.6f}')) + header = 'Epoch: [{}]'.format(epoch) + + lr_scheduler = None + if epoch == 0: + warmup_factor = 1. 
/ 1000 + warmup_iters = min(1000, len(data_loader) - 1) + + lr_scheduler = utils.warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor) + + for images, targets in metric_logger.log_every(data_loader, print_freq, header): + images = list(image.to(device) for image in images) + targets = [{k: v.to(device) for k, v in t.items()} for t in targets] + + loss_dict = model(images, targets) + + losses = sum(loss for loss in loss_dict.values()) + + # reduce losses over all GPUs for logging purposes + loss_dict_reduced = utils.reduce_dict(loss_dict) + losses_reduced = sum(loss for loss in loss_dict_reduced.values()) + + loss_value = losses_reduced.item() + + if not math.isfinite(loss_value): + print("Loss is {}, stopping training".format(loss_value)) + print(loss_dict_reduced) + sys.exit(1) + + optimizer.zero_grad() + losses.backward() + optimizer.step() + + if lr_scheduler is not None: + lr_scheduler.step() + + metric_logger.update(loss=losses_reduced, **loss_dict_reduced) + metric_logger.update(lr=optimizer.param_groups[0]["lr"]) + + return metric_logger + + +def _get_iou_types(model): + model_without_ddp = model + if isinstance(model, torch.nn.parallel.DistributedDataParallel): + model_without_ddp = model.module + iou_types = ["bbox"] + if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN): + iou_types.append("segm") + if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN): + iou_types.append("keypoints") + return iou_types + + +@torch.no_grad() +def evaluate(model, data_loader, device): + n_threads = torch.get_num_threads() + # FIXME remove this and make paste_masks_in_image run on the GPU + torch.set_num_threads(1) + cpu_device = torch.device("cpu") + model.eval() + metric_logger = utils.MetricLogger(delimiter=" ") + header = 'Test:' + + coco = get_coco_api_from_dataset(data_loader.dataset) + iou_types = _get_iou_types(model) + coco_evaluator = CocoEvaluator(coco, iou_types) + + for images, targets in metric_logger.log_every(data_loader, 100, header): + images = list(img.to(device) for img in images) + + # torch.cuda.synchronize() + model_time = time.time() + outputs = model(images) + + outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs] + model_time = time.time() - model_time + + res = {target["image_id"].item(): output for target, output in zip(targets, outputs)} + evaluator_time = time.time() + coco_evaluator.update(res) + evaluator_time = time.time() - evaluator_time + metric_logger.update(model_time=model_time, evaluator_time=evaluator_time) + + # gather the stats from all processes + metric_logger.synchronize_between_processes() + print("Averaged stats:", metric_logger) + coco_evaluator.synchronize_between_processes() + + # accumulate predictions from all images + coco_evaluator.accumulate() + coco_evaluator.summarize() + torch.set_num_threads(n_threads) + return coco_evaluator diff --git a/PyTorch/objectDetection/maskrcnn/maskrcnn.py b/PyTorch/objectDetection/maskrcnn/maskrcnn.py new file mode 100644 index 00000000..45c9df9e --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/maskrcnn.py @@ -0,0 +1,172 @@ +import torch +import torch.utils.data + +import torchvision +from torchvision.models.detection.faster_rcnn import FastRCNNPredictor +from torchvision.models.detection.mask_rcnn import MaskRCNNPredictor + +from engine import train_one_epoch, evaluate +import utils +import transforms as T + +import os +import numpy as np +import torch +import torch.utils.data +from PIL import Image + +class 
PennFudanDataset(torch.utils.data.Dataset): + def __init__(self, root, transforms=None): + self.root = root + self.transforms = transforms + # load all image files, sorting them to + # ensure that they are aligned + self.imgs = list(sorted(os.listdir(os.path.join(root, "PNGImages")))) + self.masks = list(sorted(os.listdir(os.path.join(root, "PedMasks")))) + + def __getitem__(self, idx): + # load images ad masks + img_path = os.path.join(self.root, "PNGImages", self.imgs[idx]) + mask_path = os.path.join(self.root, "PedMasks", self.masks[idx]) + img = Image.open(img_path).convert("RGB") + # note that we haven't converted the mask to RGB, + # because each color corresponds to a different instance + # with 0 being background + mask = Image.open(mask_path) + + mask = np.array(mask) + # instances are encoded as different colors + obj_ids = np.unique(mask) + # first id is the background, so remove it + obj_ids = obj_ids[1:] + + # split the color-encoded mask into a set + # of binary masks + masks = mask == obj_ids[:, None, None] + + # get bounding box coordinates for each mask + num_objs = len(obj_ids) + boxes = [] + for i in range(num_objs): + pos = np.where(masks[i]) + xmin = np.min(pos[1]) + xmax = np.max(pos[1]) + ymin = np.min(pos[0]) + ymax = np.max(pos[0]) + boxes.append([xmin, ymin, xmax, ymax]) + + boxes = torch.as_tensor(boxes, dtype=torch.float32) + # there is only one class + labels = torch.ones((num_objs,), dtype=torch.int64) + masks = torch.as_tensor(masks, dtype=torch.uint8) + + image_id = torch.tensor([idx]) + area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0]) + # suppose all instances are not crowd + iscrowd = torch.zeros((num_objs,), dtype=torch.int64) + + target = {} + target["boxes"] = boxes + target["labels"] = labels + target["masks"] = masks + target["image_id"] = image_id + target["area"] = area + target["iscrowd"] = iscrowd + + if self.transforms is not None: + img, target = self.transforms(img, target) + + return img, target + + def __len__(self): + return len(self.imgs) + + +def get_instance_segmentation_model(num_classes): + # load an instance segmentation model pre-trained on COCO + model = torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) + + # get the number of input features for the classifier + in_features = model.roi_heads.box_predictor.cls_score.in_features + # replace the pre-trained head with a new one + model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) + + # now get the number of input features for the mask classifier + in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels + hidden_layer = 256 + # and replace the mask predictor with a new one + model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, + hidden_layer, + num_classes) + + return model + + +def get_transform(train): + transforms = [] + # converts the image, a PIL image, into a PyTorch Tensor + transforms.append(T.ToTensor()) + if train: + # during training, randomly flip the training images + # and ground-truth for data augmentation + transforms.append(T.RandomHorizontalFlip(0.5)) + return T.Compose(transforms) + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(__doc__) + parser.add_argument('--device', type=str, default='dml', help='The device to use for training.') + parser.add_argument('--batch_size', type=int, default=2, metavar='N', help='Batch size to train with.') + parser.add_argument('--epochs', type=int, default=10, metavar='N', help='The number of epochs to train 
for.') + + args = parser.parse_args() + + # model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True).to(args.device) + # use our dataset and defined transformations + dataset = PennFudanDataset('..\..\data\PennFudanPed', get_transform(train=True)) + dataset_test = PennFudanDataset('..\..\data\PennFudanPed', get_transform(train=False)) + + # split the dataset in train and test set + torch.manual_seed(1) + indices = torch.randperm(len(dataset)).tolist() + dataset = torch.utils.data.Subset(dataset, indices[:-50]) + dataset_test = torch.utils.data.Subset(dataset_test, indices[-50:]) + + # define training and validation data loaders + data_loader = torch.utils.data.DataLoader( + dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, + collate_fn=utils.collate_fn) + + data_loader_test = torch.utils.data.DataLoader( + dataset_test, batch_size=1, shuffle=False, num_workers=4, + collate_fn=utils.collate_fn) + + device = torch.device("dml") if args.device=='dml' else torch.device('cuda') if args.device=='cuda' else torch.device('cpu') + + # our dataset has two classes only - background and person + num_classes = 2 + + # get the model using our helper function + model = get_instance_segmentation_model(num_classes) + # move model to the right device + model.to(device) + + # construct an optimizer + params = [p for p in model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD(params, lr=0.005, + momentum=0.9, weight_decay=0.0005) + + # and a learning rate scheduler which decreases the learning rate by + # 10x every 3 epochs + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, + step_size=3, + gamma=0.1) + + for epoch in range(args.epochs): + # train for one epoch, printing every 10 iterations + train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq=10) + # update the learning rate + lr_scheduler.step() + # evaluate on the test dataset + evaluate(model, data_loader_test, device=device) + diff --git a/PyTorch/objectDetection/maskrcnn/requirements.txt b/PyTorch/objectDetection/maskrcnn/requirements.txt new file mode 100644 index 00000000..fddd6e01 --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/requirements.txt @@ -0,0 +1,3 @@ +pytorch-directml +torchvision=0.9.0 +pycocotools \ No newline at end of file diff --git a/PyTorch/objectDetection/maskrcnn/transforms.py b/PyTorch/objectDetection/maskrcnn/transforms.py new file mode 100644 index 00000000..937ae3c0 --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/transforms.py @@ -0,0 +1,49 @@ +import random + +from torchvision.transforms import functional as F + + +def _flip_coco_person_keypoints(kps, width): + flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15] + flipped_data = kps[:, flip_inds] + flipped_data[..., 0] = width - flipped_data[..., 0] + # Maintain COCO convention that if visibility == 0, then x, y = 0 + inds = flipped_data[..., 2] == 0 + flipped_data[inds] = 0 + return flipped_data + + +class Compose(object): + def __init__(self, transforms): + self.transforms = transforms + + def __call__(self, image, target): + for t in self.transforms: + image, target = t(image, target) + return image, target + + +class RandomHorizontalFlip(object): + def __init__(self, prob): + self.prob = prob + + def __call__(self, image, target): + if random.random() < self.prob: + height, width = image.shape[-2:] + image = image.flip(-1) + bbox = target["boxes"] + bbox[:, [0, 2]] = width - bbox[:, [2, 0]] + target["boxes"] = bbox + if "masks" in target: + target["masks"] = 
target["masks"].flip(-1) + if "keypoints" in target: + keypoints = target["keypoints"] + keypoints = _flip_coco_person_keypoints(keypoints, width) + target["keypoints"] = keypoints + return image, target + + +class ToTensor(object): + def __call__(self, image, target): + image = F.to_tensor(image) + return image, target diff --git a/PyTorch/objectDetection/maskrcnn/utils.py b/PyTorch/objectDetection/maskrcnn/utils.py new file mode 100644 index 00000000..82ae79bc --- /dev/null +++ b/PyTorch/objectDetection/maskrcnn/utils.py @@ -0,0 +1,324 @@ +from collections import defaultdict, deque +import datetime +import pickle +import time + +import torch +import torch.distributed as dist + +import errno +import os + + +class SmoothedValue(object): + """Track a series of values and provide access to smoothed values over a + window or the global series average. + """ + + def __init__(self, window_size=20, fmt=None): + if fmt is None: + fmt = "{median:.4f} ({global_avg:.4f})" + self.deque = deque(maxlen=window_size) + self.total = 0.0 + self.count = 0 + self.fmt = fmt + + def update(self, value, n=1): + self.deque.append(value) + self.count += n + self.total += value * n + + def synchronize_between_processes(self): + """ + Warning: does not synchronize the deque! + """ + if not is_dist_avail_and_initialized(): + return + t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda') + dist.barrier() + dist.all_reduce(t) + t = t.tolist() + self.count = int(t[0]) + self.total = t[1] + + @property + def median(self): + d = torch.tensor(list(self.deque)) + return d.median().item() + + @property + def avg(self): + d = torch.tensor(list(self.deque), dtype=torch.float32) + return d.mean().item() + + @property + def global_avg(self): + return self.total / self.count + + @property + def max(self): + return max(self.deque) + + @property + def value(self): + return self.deque[-1] + + def __str__(self): + return self.fmt.format( + median=self.median, + avg=self.avg, + global_avg=self.global_avg, + max=self.max, + value=self.value) + + +def all_gather(data): + """ + Run all_gather on arbitrary picklable data (not necessarily tensors) + Args: + data: any picklable object + Returns: + list[data]: list of data gathered from each rank + """ + world_size = get_world_size() + if world_size == 1: + return [data] + + # serialized to a Tensor + buffer = pickle.dumps(data) + storage = torch.ByteStorage.from_buffer(buffer) + tensor = torch.ByteTensor(storage).to("cuda") + + # obtain Tensor size of each rank + local_size = torch.tensor([tensor.numel()], device="cuda") + size_list = [torch.tensor([0], device="cuda") for _ in range(world_size)] + dist.all_gather(size_list, local_size) + size_list = [int(size.item()) for size in size_list] + max_size = max(size_list) + + # receiving Tensor from all ranks + # we pad the tensor because torch all_gather does not support + # gathering tensors of different shapes + tensor_list = [] + for _ in size_list: + tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda")) + if local_size != max_size: + padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda") + tensor = torch.cat((tensor, padding), dim=0) + dist.all_gather(tensor_list, tensor) + + data_list = [] + for size, tensor in zip(size_list, tensor_list): + buffer = tensor.cpu().numpy().tobytes()[:size] + data_list.append(pickle.loads(buffer)) + + return data_list + + +def reduce_dict(input_dict, average=True): + """ + Args: + input_dict (dict): all the values will be 
reduced + average (bool): whether to do average or sum + Reduce the values in the dictionary from all processes so that all processes + have the averaged results. Returns a dict with the same fields as + input_dict, after reduction. + """ + world_size = get_world_size() + if world_size < 2: + return input_dict + with torch.no_grad(): + names = [] + values = [] + # sort the keys so that they are consistent across processes + for k in sorted(input_dict.keys()): + names.append(k) + values.append(input_dict[k]) + values = torch.stack(values, dim=0) + dist.all_reduce(values) + if average: + values /= world_size + reduced_dict = {k: v for k, v in zip(names, values)} + return reduced_dict + + +class MetricLogger(object): + def __init__(self, delimiter="\t"): + self.meters = defaultdict(SmoothedValue) + self.delimiter = delimiter + + def update(self, **kwargs): + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + v = v.item() + assert isinstance(v, (float, int)) + self.meters[k].update(v) + + def __getattr__(self, attr): + if attr in self.meters: + return self.meters[attr] + if attr in self.__dict__: + return self.__dict__[attr] + raise AttributeError("'{}' object has no attribute '{}'".format( + type(self).__name__, attr)) + + def __str__(self): + loss_str = [] + for name, meter in self.meters.items(): + loss_str.append( + "{}: {}".format(name, str(meter)) + ) + return self.delimiter.join(loss_str) + + def synchronize_between_processes(self): + for meter in self.meters.values(): + meter.synchronize_between_processes() + + def add_meter(self, name, meter): + self.meters[name] = meter + + def log_every(self, iterable, print_freq, header=None): + i = 0 + if not header: + header = '' + start_time = time.time() + end = time.time() + iter_time = SmoothedValue(fmt='{avg:.4f}') + data_time = SmoothedValue(fmt='{avg:.4f}') + space_fmt = ':' + str(len(str(len(iterable)))) + 'd' + if torch.cuda.is_available(): + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}', + 'max mem: {memory:.0f}' + ]) + else: + log_msg = self.delimiter.join([ + header, + '[{0' + space_fmt + '}/{1}]', + 'eta: {eta}', + '{meters}', + 'time: {time}', + 'data: {data}' + ]) + MB = 1024.0 * 1024.0 + for obj in iterable: + data_time.update(time.time() - end) + yield obj + iter_time.update(time.time() - end) + if i % print_freq == 0 or i == len(iterable) - 1: + eta_seconds = iter_time.global_avg * (len(iterable) - i) + eta_string = str(datetime.timedelta(seconds=int(eta_seconds))) + if torch.cuda.is_available(): + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time), + memory=torch.cuda.max_memory_allocated() / MB)) + else: + print(log_msg.format( + i, len(iterable), eta=eta_string, + meters=str(self), + time=str(iter_time), data=str(data_time))) + i += 1 + end = time.time() + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + print('{} Total time: {} ({:.4f} s / it)'.format( + header, total_time_str, total_time / len(iterable))) + + +def collate_fn(batch): + return tuple(zip(*batch)) + + +def warmup_lr_scheduler(optimizer, warmup_iters, warmup_factor): + + def f(x): + if x >= warmup_iters: + return 1 + alpha = float(x) / warmup_iters + return warmup_factor * (1 - alpha) + alpha + + return torch.optim.lr_scheduler.LambdaLR(optimizer, f) + + +def mkdir(path): + try: + os.makedirs(path) + except OSError as e: + if 
e.errno != errno.EEXIST: + raise + + +def setup_for_distributed(is_master): + """ + This function disables printing when not in master process + """ + import builtins as __builtin__ + builtin_print = __builtin__.print + + def print(*args, **kwargs): + force = kwargs.pop('force', False) + if is_master or force: + builtin_print(*args, **kwargs) + + __builtin__.print = print + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def is_main_process(): + return get_rank() == 0 + + +def save_on_master(*args, **kwargs): + if is_main_process(): + torch.save(*args, **kwargs) + + +def init_distributed_mode(args): + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ['WORLD_SIZE']) + args.gpu = int(os.environ['LOCAL_RANK']) + elif 'SLURM_PROCID' in os.environ: + args.rank = int(os.environ['SLURM_PROCID']) + args.gpu = args.rank % torch.cuda.device_count() + else: + print('Not using distributed mode') + args.distributed = False + return + + args.distributed = True + + torch.cuda.set_device(args.gpu) + args.dist_backend = 'nccl' + print('| distributed init (rank {}): {}'.format( + args.rank, args.dist_url), flush=True) + torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, + world_size=args.world_size, rank=args.rank) + torch.distributed.barrier() + setup_for_distributed(args.rank == 0) diff --git a/PyTorch/objectDetection/objectDetection.py b/PyTorch/objectDetection/objectDetection.py new file mode 100644 index 00000000..831b1aa5 --- /dev/null +++ b/PyTorch/objectDetection/objectDetection.py @@ -0,0 +1,72 @@ +import torch +import torch.autograd.profiler as profiler +import torchvision +import argparse + +device = torch.device('dml') + +object_detection_model_list = [ + 'fasterrcnn_resnet50_fpn', + 'fasterrcnn_mobilenet_v3_large_fpn', + 'fasterrcnn_mobilenet_v3_large_320_fpn', + 'retinanet_resnet50_fpn', + 'maskrcnn_resnet50_fpn', +] + +def get_model(model_str): + if model_str.lower() == 'fasterrcnn_resnet50_fpn': + return torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True) + elif model_str.lower() == 'fasterrcnn_mobilenet_v3_large_fpn': + return torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True) + elif model_str.lower() == 'fasterrcnn_mobilenet_v3_large_320_fpn': + return torchvision.models.detection.fasterrcnn_mobilenet_v3_large_320_fpn(pretrained=True) + elif model_str.lower() == 'retinanet_resnet50_fpn': + return torchvision.models.detection.retinanet_resnet50_fpn(pretrained=True) + elif model_str.lower() == 'maskrcnn_resnet50_fpn': + return torchvision.models.detection.maskrcnn_resnet50_fpn(pretrained=True) + else: + raise Exception(f"Model {model_str} is not supported yet!") + +if __name__ == '__main__': + parser = argparse.ArgumentParser(__doc__) + parser.add_argument('--model', type=str, default='fasterrcnn_resnet50_fpn', help='The model to use.') + parser.add_argument('--device', type=str, default='dml', help='The device to use for training.') + parser.add_argument('--save_trace', type=bool, default=False, help='Trace performance.') + + args = parser.parse_args() + + model = get_model(args.model).to(args.device) + + # construct an 
optimizer + params = [p for p in model.parameters() if p.requires_grad] + optimizer = torch.optim.SGD(params, lr=0.005, + momentum=0.9, weight_decay=0.0005) + + model.train() + + # generate garbage data for training one iteration + images, boxes = torch.rand(1, 3, 600, 1200).to(device), torch.rand(1, 11, 4).sort().values.to(device) + masks = torch.randint(0, 2, (1, 3, 600, 1200)).bool().to(device) + labels = torch.randint(1, 91, (1, 11)).to(device) + images = list(image for image in images) + targets = [] + + for i in range(len(images)): + d = {} + d['boxes'] = boxes[i] + d['labels'] = labels[i] + d['masks'] = masks[i] + targets.append(d) + + with profiler.profile(record_shapes=True, with_stack=True, profile_memory=True) as prof: + with profiler.record_function("model_inference"): + loss_dict = model(images, targets) + losses = sum(loss for loss in loss_dict.values()) + optimizer.zero_grad() + losses.backward() + optimizer.step() + + print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000)) + if args.save_trace: + trace_path = '_'.join([args.model, device, "trace.json"]) + prof.export_chrome_trace(trace_path) diff --git a/PyTorch/resnet50/README.md b/PyTorch/resnet50/README.md index eadc8e4a..c2a9c1f3 100644 --- a/PyTorch/resnet50/README.md +++ b/PyTorch/resnet50/README.md @@ -25,13 +25,13 @@ pip install -r pytorch\resnet50\requirements.txt After installing the PyTorch on DirectML package (see [GPU accelerated ML training](..\readme.md)), open a console to the `root` directory and run the setup script to download and convert data: ``` -python pytorch\data\cifar.py +python pytorch\data\dataset.py ``` Running `setup.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset. The output of running it should look similar to the following: ``` ->python pytorch\data\cifar.py +>python pytorch\data\dataset.py Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz 170499072it [00:32, 5250164.09it/s] diff --git a/PyTorch/squeezenet/README.md b/PyTorch/squeezenet/README.md index 0f325bc9..adea3aa7 100644 --- a/PyTorch/squeezenet/README.md +++ b/PyTorch/squeezenet/README.md @@ -25,13 +25,13 @@ pip install -r pytorch\squeezenet\requirements.txt After installing the PyTorch on DirectML package (see [GPU accelerated ML training](..\readme.md)), open a console to the `root` directory and run the setup script to download and convert data: ``` -python pytorch\data\cifar.py +python pytorch\data\dataset.py ``` Running `setup.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset. The output of running it should look similar to the following: ``` ->python pytorch\data\cifar.py +>python pytorch\data\dataset.py Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz Failed download. Trying https -> http instead. 
Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz 170499072it [00:32, 5250164.09it/s] diff --git a/PyTorch/torchvision_classification/README.md b/PyTorch/torchvision_classification/README.md index 63707f87..b4f37bfe 100644 --- a/PyTorch/torchvision_classification/README.md +++ b/PyTorch/torchvision_classification/README.md @@ -20,17 +20,21 @@ pip install -r pytorch\torchvision_classification\requirements.txt After installing the PyTorch on DirectML package (see [GPU accelerated ML training](..\readme.md)), open a console to the `root` directory and run the setup script to download and convert data: ``` -python pytorch\data\cifar.py +python pytorch\data\dataset.py ``` -Running `setup.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset. The output of running it should look similar to the following: +Running `setup.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset and PennFudanPed dataset. The output of running it should look similar to the following: ``` ->python pytorch\data\cifar.py -Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz -Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz -170499072it [00:32, 5250164.09it/s] -Extracting E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz to E:\work\dml\PyTorch\data\cifar-10-python +>python pytorch\data\dataset.py +Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\DirectML\PyTorch\data\cifar-10-python.tar.gz +Failed download. Trying https -> http instead. Downloading http://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to E:\work\DirectML\PyTorch\data\cifar-10-python.tar.gz +170499072it [00:17, 9709154.90it/s] +Extracting E:\work\DirectML\PyTorch\data\cifar-10-python.tar.gz to E:\work\DirectML\PyTorch\data +Downloading PennFundaPed dataset + +100% [........................................................................] 
53723336 / 53723336 +Extracted PennFundaPed dataset ``` ## Training @@ -38,24 +42,9 @@ Extracting E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz to E: A helper script exists to train classification models with default data, batch size, and so on: ``` -python pytorch\torchvision_classification\train.py --model resnet18 -``` - -model names from list below can be used to train: -- resnet18 -- alexnet -- vgg16 -- squeezenet1_0 -- densenet161 -- inception_v3 -- googlenet -- shufflenet_v2_x1_0 -- mobilenet_v2 -- mobilenet_v3_large -- mobilenet_v3_small -- resnext50_32x4d -- wide_resnet50_2 -- mnasnet1_0 +cd pytorch\objectDetection\maskrcnn +python maskrcnn.py +``` The first few lines of output should look similar to the following (exact numbers may change): ``` From 726c700ed01fd20ffdf14ac6f6dd753b49f63970 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Thu, 7 Apr 2022 10:02:25 -0700 Subject: [PATCH 2/6] update requirements --- PyTorch/objectDetection/maskrcnn/requirements.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/PyTorch/objectDetection/maskrcnn/requirements.txt b/PyTorch/objectDetection/maskrcnn/requirements.txt index fddd6e01..aafa83b2 100644 --- a/PyTorch/objectDetection/maskrcnn/requirements.txt +++ b/PyTorch/objectDetection/maskrcnn/requirements.txt @@ -1,3 +1,3 @@ pytorch-directml -torchvision=0.9.0 +torchvision==0.9.0 pycocotools \ No newline at end of file From 684442ed4af2ad79d8e078398b40c9dec3cbb4e6 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Thu, 7 Apr 2022 10:08:00 -0700 Subject: [PATCH 3/6] revert miss-edited changes --- PyTorch/torchvision_classification/README.md | 21 +++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/PyTorch/torchvision_classification/README.md b/PyTorch/torchvision_classification/README.md index b4f37bfe..939e4866 100644 --- a/PyTorch/torchvision_classification/README.md +++ b/PyTorch/torchvision_classification/README.md @@ -42,9 +42,24 @@ Extracted PennFundaPed dataset A helper script exists to train classification models with default data, batch size, and so on: ``` -cd pytorch\objectDetection\maskrcnn -python maskrcnn.py -``` +python pytorch\torchvision_classification\train.py --model resnet18 +``` + +model names from list below can be used to train: +- resnet18 +- alexnet +- vgg16 +- squeezenet1_0 +- densenet161 +- inception_v3 +- googlenet +- shufflenet_v2_x1_0 +- mobilenet_v2 +- mobilenet_v3_large +- mobilenet_v3_small +- resnext50_32x4d +- wide_resnet50_2 +- mnasnet1_0 The first few lines of output should look similar to the following (exact numbers may change): ``` From c8fa0f34d0cc51f29ea302f43f8a9b9afc80330f Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Mon, 25 Apr 2022 11:25:16 -0700 Subject: [PATCH 4/6] update samples --- PyTorch/objectDetection/maskrcnn/maskrcnn.py | 2 +- PyTorch/objectDetection/objectDetection.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/PyTorch/objectDetection/maskrcnn/maskrcnn.py b/PyTorch/objectDetection/maskrcnn/maskrcnn.py index 45c9df9e..326e8f16 100644 --- a/PyTorch/objectDetection/maskrcnn/maskrcnn.py +++ b/PyTorch/objectDetection/maskrcnn/maskrcnn.py @@ -141,7 +141,7 @@ def get_transform(train): dataset_test, batch_size=1, shuffle=False, num_workers=4, collate_fn=utils.collate_fn) - device = torch.device("dml") if args.device=='dml' else torch.device('cuda') if args.device=='cuda' else torch.device('cpu') + device = torch.device(args.device) # our dataset has two classes only - background and person num_classes 
= 2 diff --git a/PyTorch/objectDetection/objectDetection.py b/PyTorch/objectDetection/objectDetection.py index 831b1aa5..8c0539b3 100644 --- a/PyTorch/objectDetection/objectDetection.py +++ b/PyTorch/objectDetection/objectDetection.py @@ -68,5 +68,5 @@ def get_model(model_str): print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=1000)) if args.save_trace: - trace_path = '_'.join([args.model, device, "trace.json"]) + trace_path = '_'.join(["train", args.model, "dml", "trace.json"]) prof.export_chrome_trace(trace_path) From 3d65490324beccd74ad26397d2a7aac564f7ffe6 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Mon, 25 Apr 2022 15:58:54 -0700 Subject: [PATCH 5/6] update samples --- PyTorch/README.md | 1 + PyTorch/objectDetection/maskrcnn/README.md | 4 ++-- PyTorch/objectDetection/maskrcnn/requirements.txt | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/PyTorch/README.md b/PyTorch/README.md index da25641f..ba45dd0b 100644 --- a/PyTorch/README.md +++ b/PyTorch/README.md @@ -35,6 +35,7 @@ The following sample models are included in this repo to help you get started. T * [squeezenet - a small image classification model](./squeezenet) * [resnet50 - an image classification model](./resnet50) +* [maskrcnn - an object detection model](./objectDetection/maskrcnn/) * *more coming soon* ## External Links diff --git a/PyTorch/objectDetection/maskrcnn/README.md b/PyTorch/objectDetection/maskrcnn/README.md index f92031da..b0ac495a 100644 --- a/PyTorch/objectDetection/maskrcnn/README.md +++ b/PyTorch/objectDetection/maskrcnn/README.md @@ -22,7 +22,7 @@ After installing the PyTorch on DirectML package (see [GPU accelerated ML traini python pytorch\data\dataset.py ``` -Running `setup.py` should take at least a minute or so, since it downloads the CIFAR-10 dataset. +Running `dataset.py` should take at least a minute or so, since it downloads the CIFAR-10 and PennFudanPed datasets.
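+
+If you only want the PennFudanPed data used by this sample, you can pass the `--dataset` flag to skip the CIFAR-10 download:
+
+```
+python pytorch\data\dataset.py --dataset pennfudanped
+```
+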
The output of running it should look similar to the following: ``` >python pytorch\data\dataset.py @@ -37,7 +37,7 @@ Extracting E:\work\dml\PyTorch\data\cifar-10-python\cifar-10-python.tar.gz to E: A helper script exists to train Mask R-CNN with PennFudanPed data: ``` - +cd pytorch\objectdetection\maskrcnn python .\maskrcnn.py ``` diff --git a/PyTorch/objectDetection/maskrcnn/requirements.txt b/PyTorch/objectDetection/maskrcnn/requirements.txt index aafa83b2..615fe472 100644 --- a/PyTorch/objectDetection/maskrcnn/requirements.txt +++ b/PyTorch/objectDetection/maskrcnn/requirements.txt @@ -1,3 +1,5 @@ pytorch-directml torchvision==0.9.0 -pycocotools \ No newline at end of file +pycocotools +wget +requests \ No newline at end of file From 316dfd58c3b5f9107f611df18db6e26d4d3362b0 Mon Sep 17 00:00:00 2001 From: Xiang Zhang Date: Mon, 25 Apr 2022 16:22:53 -0700 Subject: [PATCH 6/6] resolve comments --- PyTorch/.gitignore | 3 +-- PyTorch/objectDetection/maskrcnn/requirements.txt | 2 -- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/PyTorch/.gitignore b/PyTorch/.gitignore index c3c84194..21ffec9d 100644 --- a/PyTorch/.gitignore +++ b/PyTorch/.gitignore @@ -9,5 +9,4 @@ checkpoints coco128 data/PennFudanPed data/PennFudanPed.zip -objectDetection/PennFudanPed -checkpoints \ No newline at end of file +objectDetection/PennFudanPed \ No newline at end of file diff --git a/PyTorch/objectDetection/maskrcnn/requirements.txt b/PyTorch/objectDetection/maskrcnn/requirements.txt index 615fe472..de3640a0 100644 --- a/PyTorch/objectDetection/maskrcnn/requirements.txt +++ b/PyTorch/objectDetection/maskrcnn/requirements.txt @@ -1,5 +1,3 @@ -pytorch-directml -torchvision==0.9.0 pycocotools wget requests \ No newline at end of file