From 0f61b00483fadb476f03957b1c33158e4fc42ec7 Mon Sep 17 00:00:00 2001
From: Henry Wang
Date: Thu, 6 Dec 2018 02:31:40 +0800
Subject: [PATCH] Support more datasets (#232)

* add force json option
* fix the same issue as #185
* bug fix
* cityscapes config
* update paths catalog
* discard config change
* organize code for more-datasets
* use better representation for coco-style datasets
* rename coco-style config
* remove import
* chmod 644
* make the config more verbose
* update readme
* rename
* chmod
---
 ...e2e_faster_rcnn_R_50_FPN_1x_cocostyle.yaml |  32 +++
 .../e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml  |  41 ++++
 .../e2e_faster_rcnn_R_50_C4_1x_1_gpu_voc.yaml |   2 +-
 .../e2e_faster_rcnn_R_50_C4_1x_4_gpu_voc.yaml |   2 +-
 .../e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml  |  41 ++++
 maskrcnn_benchmark/config/paths_catalog.py    | 102 ++++++--
 maskrcnn_benchmark/data/README.md             |  88 +++++++
 .../cityscapes/convert_cityscapes_to_coco.py  | 221 ++++++++++++++++++
 .../instances2dict_with_polygons.py           |  79 +++++++
 9 files changed, 581 insertions(+), 27 deletions(-)
 create mode 100755 configs/cityscapes/e2e_faster_rcnn_R_50_FPN_1x_cocostyle.yaml
 create mode 100755 configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml
 create mode 100755 configs/pascal_voc/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml
 create mode 100644 maskrcnn_benchmark/data/README.md
 create mode 100644 tools/cityscapes/convert_cityscapes_to_coco.py
 create mode 100644 tools/cityscapes/instances2dict_with_polygons.py

diff --git a/configs/cityscapes/e2e_faster_rcnn_R_50_FPN_1x_cocostyle.yaml b/configs/cityscapes/e2e_faster_rcnn_R_50_FPN_1x_cocostyle.yaml
new file mode 100755
index 000000000..1df953877
--- /dev/null
+++ b/configs/cityscapes/e2e_faster_rcnn_R_50_FPN_1x_cocostyle.yaml
@@ -0,0 +1,32 @@
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+    OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 2000
+    PRE_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 9
+DATASETS:
+  TRAIN: ("cityscapes_fine_instanceonly_seg_train_cocostyle",)
+  TEST: ("cityscapes_fine_instanceonly_seg_val_cocostyle",)
+DATALOADER:
+  SIZE_DIVISIBILITY: 32
+SOLVER:
+  BASE_LR: 0.01
+  WEIGHT_DECAY: 0.0001
+  STEPS: (18000,)
+  MAX_ITER: 24000
diff --git a/configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml b/configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml
new file mode 100755
index 000000000..fcaccaa94
--- /dev/null
+++ b/configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml
@@ -0,0 +1,41 @@
+MODEL:
+  META_ARCHITECTURE: "GeneralizedRCNN"
+  WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
+  BACKBONE:
+    CONV_BODY: "R-50-FPN"
+    OUT_CHANNELS: 256
+  RPN:
+    USE_FPN: True
+    ANCHOR_STRIDE: (4, 8, 16, 32, 64)
+    PRE_NMS_TOP_N_TRAIN: 2000
+    PRE_NMS_TOP_N_TEST: 1000
+    POST_NMS_TOP_N_TEST: 1000
+    FPN_POST_NMS_TOP_N_TEST: 1000
+  ROI_HEADS:
+    USE_FPN: True
+  ROI_BOX_HEAD:
+    POOLER_RESOLUTION: 7
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    POOLER_SAMPLING_RATIO: 2
+    FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor"
+    PREDICTOR: "FPNPredictor"
+    NUM_CLASSES: 9
+  ROI_MASK_HEAD:
+    POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125)
+    FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor"
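+    # POOLER_RESOLUTION (14) is the RoIAlign output size fed to the mask head,
+    # while RESOLUTION (28) below is the spatial size of the predicted masks.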
PREDICTOR: "MaskRCNNC4Predictor" + POOLER_RESOLUTION: 14 + POOLER_SAMPLING_RATIO: 2 + RESOLUTION: 28 + SHARE_BOX_FEATURE_EXTRACTOR: False + MASK_ON: True +DATASETS: + TRAIN: ("cityscapes_fine_instanceonly_seg_train_cocostyle",) + TEST: ("cityscapes_fine_instanceonly_seg_val_cocostyle",) +DATALOADER: + SIZE_DIVISIBILITY: 32 +SOLVER: + BASE_LR: 0.01 + WEIGHT_DECAY: 0.0001 + STEPS: (18000,) + MAX_ITER: 24000 diff --git a/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_1_gpu_voc.yaml b/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_1_gpu_voc.yaml index 1a5ad6a24..77dea5b45 100644 --- a/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_1_gpu_voc.yaml +++ b/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_1_gpu_voc.yaml @@ -8,7 +8,7 @@ MODEL: ROI_BOX_HEAD: NUM_CLASSES: 21 DATASETS: - TRAIN: ("voc_2007_trainval",) + TRAIN: ("voc_2007_train", "voc_2007_val") TEST: ("voc_2007_test",) SOLVER: BASE_LR: 0.001 diff --git a/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_4_gpu_voc.yaml b/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_4_gpu_voc.yaml index ea8dfe7d4..a8fb663a4 100644 --- a/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_4_gpu_voc.yaml +++ b/configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_4_gpu_voc.yaml @@ -8,7 +8,7 @@ MODEL: ROI_BOX_HEAD: NUM_CLASSES: 21 DATASETS: - TRAIN: ("voc_2007_trainval",) + TRAIN: ("voc_2007_train", "voc_2007_val") TEST: ("voc_2007_test",) SOLVER: BASE_LR: 0.004 diff --git a/configs/pascal_voc/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml b/configs/pascal_voc/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml new file mode 100755 index 000000000..e6f29e88d --- /dev/null +++ b/configs/pascal_voc/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml @@ -0,0 +1,41 @@ +MODEL: + META_ARCHITECTURE: "GeneralizedRCNN" + WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50" + BACKBONE: + CONV_BODY: "R-50-FPN" + OUT_CHANNELS: 256 + RPN: + USE_FPN: True + ANCHOR_STRIDE: (4, 8, 16, 32, 64) + PRE_NMS_TOP_N_TRAIN: 2000 + PRE_NMS_TOP_N_TEST: 1000 + POST_NMS_TOP_N_TEST: 1000 + FPN_POST_NMS_TOP_N_TEST: 1000 + ROI_HEADS: + USE_FPN: True + ROI_BOX_HEAD: + POOLER_RESOLUTION: 7 + POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) + POOLER_SAMPLING_RATIO: 2 + FEATURE_EXTRACTOR: "FPN2MLPFeatureExtractor" + PREDICTOR: "FPNPredictor" + NUM_CLASSES: 21 + ROI_MASK_HEAD: + POOLER_SCALES: (0.25, 0.125, 0.0625, 0.03125) + FEATURE_EXTRACTOR: "MaskRCNNFPNFeatureExtractor" + PREDICTOR: "MaskRCNNC4Predictor" + POOLER_RESOLUTION: 14 + POOLER_SAMPLING_RATIO: 2 + RESOLUTION: 28 + SHARE_BOX_FEATURE_EXTRACTOR: False + MASK_ON: True +DATASETS: + TRAIN: ("voc_2012_train_cocostyle",) + TEST: ("voc_2012_val_cocostyle",) +DATALOADER: + SIZE_DIVISIBILITY: 32 +SOLVER: + BASE_LR: 0.01 + WEIGHT_DECAY: 0.0001 + STEPS: (18000,) + MAX_ITER: 24000 diff --git a/maskrcnn_benchmark/config/paths_catalog.py b/maskrcnn_benchmark/config/paths_catalog.py index c165105a7..53a425f74 100644 --- a/maskrcnn_benchmark/config/paths_catalog.py +++ b/maskrcnn_benchmark/config/paths_catalog.py @@ -6,28 +6,80 @@ class DatasetCatalog(object): DATA_DIR = "datasets" - DATASETS = { - "coco_2014_train": ( - "coco/train2014", - "coco/annotations/instances_train2014.json", - ), - "coco_2014_val": ("coco/val2014", "coco/annotations/instances_val2014.json"), - "coco_2014_minival": ( - "coco/val2014", - "coco/annotations/instances_minival2014.json", - ), - "coco_2014_valminusminival": ( - "coco/val2014", - "coco/annotations/instances_valminusminival2014.json", - ), - "voc_2007_trainval": ("voc/VOC2007", 'trainval'), - "voc_2007_test": ("voc/VOC2007", 'test'), - "voc_2012_train": ("voc/VOC2012", 
-        "voc_2012_trainval": ("voc/VOC2012", 'trainval'),
-        "voc_2012_val": ("voc/VOC2012", 'val'),
-        "voc_2012_test": ("voc/VOC2012", 'test'),
-
+    DATASETS = {
+        "coco_2014_train": {
+            "img_dir": "coco/train2014",
+            "ann_file": "coco/annotations/instances_train2014.json"
+        },
+        "coco_2014_val": {
+            "img_dir": "coco/val2014",
+            "ann_file": "coco/annotations/instances_val2014.json"
+        },
+        "coco_2014_minival": {
+            "img_dir": "coco/val2014",
+            "ann_file": "coco/annotations/instances_minival2014.json"
+        },
+        "coco_2014_valminusminival": {
+            "img_dir": "coco/val2014",
+            "ann_file": "coco/annotations/instances_valminusminival2014.json"
+        },
+        "voc_2007_train": {
+            "data_dir": "voc/VOC2007",
+            "split": "train"
+        },
+        "voc_2007_train_cocostyle": {
+            "img_dir": "voc/VOC2007/JPEGImages",
+            "ann_file": "voc/VOC2007/Annotations/pascal_train2007.json"
+        },
+        "voc_2007_val": {
+            "data_dir": "voc/VOC2007",
+            "split": "val"
+        },
+        "voc_2007_val_cocostyle": {
+            "img_dir": "voc/VOC2007/JPEGImages",
+            "ann_file": "voc/VOC2007/Annotations/pascal_val2007.json"
+        },
+        "voc_2007_test": {
+            "data_dir": "voc/VOC2007",
+            "split": "test"
+        },
+        "voc_2007_test_cocostyle": {
+            "img_dir": "voc/VOC2007/JPEGImages",
+            "ann_file": "voc/VOC2007/Annotations/pascal_test2007.json"
+        },
+        "voc_2012_train": {
+            "data_dir": "voc/VOC2012",
+            "split": "train"
+        },
+        "voc_2012_train_cocostyle": {
+            "img_dir": "voc/VOC2012/JPEGImages",
+            "ann_file": "voc/VOC2012/Annotations/pascal_train2012.json"
+        },
+        "voc_2012_val": {
+            "data_dir": "voc/VOC2012",
+            "split": "val"
+        },
+        "voc_2012_val_cocostyle": {
+            "img_dir": "voc/VOC2012/JPEGImages",
+            "ann_file": "voc/VOC2012/Annotations/pascal_val2012.json"
+        },
+        "voc_2012_test": {
+            "data_dir": "voc/VOC2012",
+            "split": "test"
+            # PASCAL VOC 2012 has not released its test annotations, so there is no COCO-style json for this split
+        },
+        "cityscapes_fine_instanceonly_seg_train_cocostyle": {
+            "img_dir": "cityscapes/images",
+            "ann_file": "cityscapes/annotations/instancesonly_filtered_gtFine_train.json"
+        },
+        "cityscapes_fine_instanceonly_seg_val_cocostyle": {
+            "img_dir": "cityscapes/images",
+            "ann_file": "cityscapes/annotations/instancesonly_filtered_gtFine_val.json"
+        },
+        "cityscapes_fine_instanceonly_seg_test_cocostyle": {
+            "img_dir": "cityscapes/images",
+            "ann_file": "cityscapes/annotations/instancesonly_filtered_gtFine_test.json"
+        }
     }

     @staticmethod
@@ -36,8 +88,8 @@ def get(name):
         data_dir = DatasetCatalog.DATA_DIR
         attrs = DatasetCatalog.DATASETS[name]
         args = dict(
-            root=os.path.join(data_dir, attrs[0]),
-            ann_file=os.path.join(data_dir, attrs[1]),
+            root=os.path.join(data_dir, attrs["img_dir"]),
+            ann_file=os.path.join(data_dir, attrs["ann_file"]),
         )
         return dict(
             factory="COCODataset",
@@ -47,8 +99,8 @@
         data_dir = DatasetCatalog.DATA_DIR
         attrs = DatasetCatalog.DATASETS[name]
         args = dict(
-            data_dir=os.path.join(data_dir, attrs[0]),
-            split=attrs[1],
+            data_dir=os.path.join(data_dir, attrs["data_dir"]),
+            split=attrs["split"],
         )
         return dict(
             factory="PascalVOCDataset",
diff --git a/maskrcnn_benchmark/data/README.md b/maskrcnn_benchmark/data/README.md
new file mode 100644
index 000000000..7d6e127ad
--- /dev/null
+++ b/maskrcnn_benchmark/data/README.md
@@ -0,0 +1,88 @@
+# Setting Up Datasets
+This file describes how to perform training on other datasets.
+
+Currently, only the Pascal VOC dataset can be loaded from its original format and evaluated with Pascal-style results.
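+
+Dataset names used in the config files are resolved through
+`maskrcnn_benchmark/config/paths_catalog.py`. As a minimal sketch (assuming
+the default `datasets` root), a `*_cocostyle` name contains the substring
+"coco" and therefore resolves to the `COCODataset` factory:
+
+```python
+from maskrcnn_benchmark.config.paths_catalog import DatasetCatalog
+
+spec = DatasetCatalog.get("voc_2007_train_cocostyle")
+# spec["factory"]          -> "COCODataset"
+# spec["args"]["root"]     -> "datasets/voc/VOC2007/JPEGImages"
+# spec["args"]["ann_file"] -> "datasets/voc/VOC2007/Annotations/pascal_train2007.json"
+```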
+
+We expect annotations from other datasets to be converted to the COCO json format, and
+the evaluation output will then be COCO-style (i.e. AP, AP50, AP75, APs, APm, APl for both bbox and segm).
+
+## Creating Symlinks for PASCAL VOC
+
+We assume that your symlinked `datasets/voc/VOC<year>` directory (e.g. `VOC2007`) has the following structure,
+where `<...>` tokens are placeholders:
+
+```
+VOC<year>
+|_ JPEGImages
+|  |_ <im-1>.jpg
+|  |_ ...
+|  |_ <im-N>.jpg
+|_ Annotations
+|  |_ pascal_train<year>.json (optional)
+|  |_ pascal_val<year>.json (optional)
+|  |_ pascal_test<year>.json (optional)
+|  |_ <ann-1>.xml
+|  |_ ...
+|  |_ <ann-N>.xml
+|_ VOCdevkit<year>
+```
+
+Create symlinks for `voc/VOC<year>`:
+
+```
+cd ~/github/maskrcnn-benchmark
+mkdir -p datasets/voc/VOC<year>
+ln -s /path/to/VOC<year> datasets/voc/VOC<year>
+```
+Example configuration files for PASCAL VOC can be found [here](https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/configs/pascal_voc/).
+
+### PASCAL VOC Annotations in COCO Format
+To output COCO-style evaluation results, PASCAL VOC annotations in COCO json format are required; they can be downloaded from [here](https://storage.googleapis.com/coco-dataset/external/PASCAL_VOC.zip)
+via http://cocodataset.org/#external.
+
+## Creating Symlinks for Cityscapes
+
+We assume that your symlinked `datasets/cityscapes` directory has the following structure:
+
+```
+cityscapes
+|_ images
+|  |_ <im-1>.jpg
+|  |_ ...
+|  |_ <im-N>.jpg
+|_ annotations
+|  |_ instancesonly_filtered_gtFine_train.json
+|  |_ ...
+|_ raw
+   |_ gtFine
+   |_ ...
+   |_ README.md
+```
+
+Create symlinks for `cityscapes`:
+
+```
+cd ~/github/maskrcnn-benchmark
+mkdir -p datasets/cityscapes
+ln -s /path/to/cityscapes datasets/cityscapes
+```
+
+### Steps to Convert Cityscapes Annotations to COCO Format
+1. Download gtFine_trainvaltest.zip from https://www.cityscapes-dataset.com/downloads/ (login required)
+2. Extract it to /path/to/gtFine_trainvaltest
+```
+gtFine_trainvaltest
+|_ gtFine
+```
+3. Run the commands below to convert the annotations (note the required `--dataset` flag)
+
+```
+cd ~/github
+git clone https://github.com/mcordts/cityscapesScripts.git
+cd cityscapesScripts
+cp ~/github/maskrcnn-benchmark/tools/cityscapes/instances2dict_with_polygons.py cityscapesscripts/evaluation
+python setup.py install
+cd ~/github/maskrcnn-benchmark
+python tools/cityscapes/convert_cityscapes_to_coco.py --dataset cityscapes_instance_only --datadir /path/to/gtFine_trainvaltest --outdir /path/to/cityscapes/annotations
+```
+
+Example configuration files for Cityscapes can be found [here](https://github.com/facebookresearch/maskrcnn-benchmark/blob/master/configs/cityscapes/).
diff --git a/tools/cityscapes/convert_cityscapes_to_coco.py b/tools/cityscapes/convert_cityscapes_to_coco.py
new file mode 100644
index 000000000..b4ddd9edf
--- /dev/null
+++ b/tools/cityscapes/convert_cityscapes_to_coco.py
@@ -0,0 +1,221 @@
+#!/usr/bin/env python
+
+# Copyright (c) 2017-present, Facebook, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+##############################################################################
+
+# This file is copied from https://github.com/facebookresearch/Detectron/tree/master/tools
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from __future__ import unicode_literals
+
+import argparse
+import h5py
+import json
+import os
+import scipy.misc
+import sys
+
+import cityscapesscripts.evaluation.instances2dict_with_polygons as cs
+
+import detectron.utils.segms as segms_util
+import detectron.utils.boxes as bboxs_util
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(description='Convert dataset')
+    parser.add_argument(
+        '--dataset', help="cocostuff or cityscapes_instance_only",
+        default=None, type=str)
+    parser.add_argument(
+        '--outdir', help="output dir for json files", default=None, type=str)
+    parser.add_argument(
+        '--datadir', help="data dir for annotations to be converted",
+        default=None, type=str)
+    if len(sys.argv) == 1:
+        parser.print_help()
+        sys.exit(1)
+    return parser.parse_args()
+
+
+def convert_coco_stuff_mat(data_dir, out_dir):
+    """Convert to png and save json with path. This currently only contains
+    the segmentation labels for objects+stuff in cocostuff - if we need to
+    combine with other labels from original COCO that will be a TODO."""
+    sets = ['train', 'val']
+    categories = []
+    json_name = 'coco_stuff_%s.json'
+    ann_dict = {}
+    for data_set in sets:
+        file_list = os.path.join(data_dir, '%s.txt')
+        images = []
+        with open(file_list % data_set) as f:
+            for img_id, img_name in enumerate(f):
+                img_name = img_name.replace('coco', 'COCO').strip('\n')
+                image = {}
+                mat_file = os.path.join(
+                    data_dir, 'annotations/%s.mat' % img_name)
+                data = h5py.File(mat_file, 'r')
+                labelMap = data.get('S')
+                if len(categories) == 0:
+                    labelNames = data.get('names')
+                    for idx, n in enumerate(labelNames):
+                        categories.append(
+                            {"id": idx, "name": ''.join(chr(i) for i in data[
+                                n[0]])})
+                    ann_dict['categories'] = categories
+                scipy.misc.imsave(
+                    os.path.join(data_dir, img_name + '.png'), labelMap)
+                image['width'] = labelMap.shape[0]
+                image['height'] = labelMap.shape[1]
+                image['file_name'] = img_name
+                image['seg_file_name'] = img_name
+                image['id'] = img_id
+                images.append(image)
+        ann_dict['images'] = images
+        print("Num images: %s" % len(images))
+        # open in text mode so json.dumps' str output can be written on Python 3
+        with open(os.path.join(out_dir, json_name % data_set), 'w') as outfile:
+            outfile.write(json.dumps(ann_dict))
+
+
+# for Cityscapes
+def getLabelID(instID):
+    if (instID < 1000):
+        return instID
+    else:
+        return int(instID / 1000)
+
+
+def convert_cityscapes_instance_only(
+        data_dir, out_dir):
+    """Convert from cityscapes format to COCO instance seg format - polygons"""
+    sets = [
+        'gtFine_val',
+        'gtFine_train',
+        'gtFine_test',
+
+        # 'gtCoarse_train',
+        # 'gtCoarse_val',
+        # 'gtCoarse_train_extra'
+    ]
+    ann_dirs = [
+        'gtFine_trainvaltest/gtFine/val',
+        'gtFine_trainvaltest/gtFine/train',
+        'gtFine_trainvaltest/gtFine/test',
+
+        # 'gtCoarse/train',
+        # 'gtCoarse/train_extra',
+        # 'gtCoarse/val'
+    ]
+    json_name = 'instancesonly_filtered_%s.json'
+    ends_in = '%s_polygons.json'
+    img_id = 0
+    ann_id = 0
+    cat_id = 1
+    category_dict = {}
+
+    category_instancesonly = [
+        'person',
+        'rider',
+        'car',
+        'truck',
+        'bus',
+        'train',
+        'motorcycle',
+        'bicycle',
+    ]
+
+    for data_set, ann_dir in zip(sets, ann_dirs):
+        print('Starting %s' % data_set)
+        ann_dict = {}
+        images = []
+        annotations = []
+        ann_dir = os.path.join(data_dir, ann_dir)
+        for root, _, files in os.walk(ann_dir):
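+            # Only the polygon ground-truth files for the current split are
+            # converted: for e.g. 'gtFine_train', `ends_in % 'gtFine'` matches
+            # files named like <city>_<seq>_<frame>_gtFine_polygons.json.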
+            for filename in files:
+                if filename.endswith(ends_in % data_set.split('_')[0]):
+                    if len(images) % 50 == 0:
+                        print("Processed %s images, %s annotations" % (
+                            len(images), len(annotations)))
+                    json_ann = json.load(open(os.path.join(root, filename)))
+                    image = {}
+                    image['id'] = img_id
+                    img_id += 1
+
+                    image['width'] = json_ann['imgWidth']
+                    image['height'] = json_ann['imgHeight']
+                    image['file_name'] = filename[:-len(
+                        ends_in % data_set.split('_')[0])] + 'leftImg8bit.png'
+                    image['seg_file_name'] = filename[:-len(
+                        ends_in % data_set.split('_')[0])] + \
+                        '%s_instanceIds.png' % data_set.split('_')[0]
+                    images.append(image)
+
+                    fullname = os.path.join(root, image['seg_file_name'])
+                    objects = cs.instances2dict_with_polygons(
+                        [fullname], verbose=False)[fullname]
+
+                    for object_cls in objects:
+                        if object_cls not in category_instancesonly:
+                            continue  # skip non-instance categories
+
+                        for obj in objects[object_cls]:
+                            if obj['contours'] == []:
+                                print('Warning: empty contours.')
+                                continue  # skip objects with no extracted contours
+
+                            len_p = [len(p) for p in obj['contours']]
+                            if min(len_p) <= 4:
+                                print('Warning: invalid contours.')
+                                continue  # skip degenerate polygons (fewer than 3 points)
+
+                            ann = {}
+                            ann['id'] = ann_id
+                            ann_id += 1
+                            ann['image_id'] = image['id']
+                            ann['segmentation'] = obj['contours']
+
+                            if object_cls not in category_dict:
+                                category_dict[object_cls] = cat_id
+                                cat_id += 1
+                            ann['category_id'] = category_dict[object_cls]
+                            ann['iscrowd'] = 0
+                            ann['area'] = obj['pixelCount']
+                            ann['bbox'] = bboxs_util.xyxy_to_xywh(
+                                segms_util.polys_to_boxes(
+                                    [ann['segmentation']])).tolist()[0]
+
+                            annotations.append(ann)
+
+        ann_dict['images'] = images
+        categories = [{"id": category_dict[name], "name": name} for name in
+                      category_dict]
+        ann_dict['categories'] = categories
+        ann_dict['annotations'] = annotations
+        print("Num categories: %s" % len(categories))
+        print("Num images: %s" % len(images))
+        print("Num annotations: %s" % len(annotations))
+        with open(os.path.join(out_dir, json_name % data_set), 'w') as outfile:
+            outfile.write(json.dumps(ann_dict))
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    if args.dataset == "cityscapes_instance_only":
+        convert_cityscapes_instance_only(args.datadir, args.outdir)
+    elif args.dataset == "cocostuff":
+        convert_coco_stuff_mat(args.datadir, args.outdir)
+    else:
+        print("Dataset not supported: %s" % args.dataset)
diff --git a/tools/cityscapes/instances2dict_with_polygons.py b/tools/cityscapes/instances2dict_with_polygons.py
new file mode 100644
index 000000000..fbfc8d13b
--- /dev/null
+++ b/tools/cityscapes/instances2dict_with_polygons.py
@@ -0,0 +1,79 @@
+#!/usr/bin/python
+#
+# Convert instances from png files to a dictionary
+# This file was created according to https://github.com/facebookresearch/Detectron/issues/111
+
+from __future__ import print_function, absolute_import, division
+import os, sys
+
+sys.path.append(os.path.normpath(os.path.join(os.path.dirname(__file__), '..', 'helpers')))
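+# The relative 'helpers' path only resolves after this file is copied into
+# cityscapesscripts/evaluation (see the data README); the package imports
+# below cover the installed cityscapesscripts case.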
+from csHelpers import *
+
+# Cityscapes imports
+from cityscapesscripts.evaluation.instance import *
+from cityscapesscripts.helpers.csHelpers import *
+import cv2
+
+def instances2dict_with_polygons(imageFileList, verbose=False):
+    imgCount = 0
+    instanceDict = {}
+
+    if not isinstance(imageFileList, list):
+        imageFileList = [imageFileList]
+
+    if verbose:
+        print("Processing {} images...".format(len(imageFileList)))
+
+    for imageFileName in imageFileList:
+        # Load image
+        img = Image.open(imageFileName)
+
+        # Image as numpy array
+        imgNp = np.array(img)
+
+        # Initialize label categories
+        instances = {}
+        for label in labels:
+            instances[label.name] = []
+
+        # Loop through all instance ids in instance image
+        for instanceId in np.unique(imgNp):
+            if instanceId < 1000:
+                continue
+            instanceObj = Instance(imgNp, instanceId)
+            instanceObj_dict = instanceObj.toDict()
+
+            #instances[id2label[instanceObj.labelID].name].append(instanceObj.toDict())
+            if id2label[instanceObj.labelID].hasInstances:
+                mask = (imgNp == instanceId).astype(np.uint8)
+                # Note: the 3-value return matches OpenCV 3.x; OpenCV 2.x and
+                # 4.x return only (contours, hierarchy)
+                im2, contour, hier = cv2.findContours(
+                    mask.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
+
+                polygons = [c.reshape(-1).tolist() for c in contour]
+                instanceObj_dict['contours'] = polygons
+
+            instances[id2label[instanceObj.labelID].name].append(instanceObj_dict)
+
+        imgKey = os.path.abspath(imageFileName)
+        instanceDict[imgKey] = instances
+        imgCount += 1
+
+        if verbose:
+            print("\rImages Processed: {}".format(imgCount), end=' ')
+            sys.stdout.flush()
+
+    if verbose:
+        print("")
+
+    return instanceDict
+
+def main(argv):
+    fileList = []
+    if (len(argv) > 2):
+        for arg in argv:
+            if ("png" in arg):
+                fileList.append(arg)
+    instances2dict_with_polygons(fileList, True)
+
+if __name__ == "__main__":
+    main(sys.argv[1:])
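
For reference, a rough end-to-end sketch of exercising the pieces this patch
adds (paths are placeholders; `tools/train_net.py` is the repository's
standard training entry point):

```
# Convert the Cityscapes fine annotations to COCO json
python tools/cityscapes/convert_cityscapes_to_coco.py \
    --dataset cityscapes_instance_only \
    --datadir /path/to/gtFine_trainvaltest \
    --outdir datasets/cityscapes/annotations

# Train with one of the new configs; the dataset names inside it
# resolve through DatasetCatalog
python tools/train_net.py \
    --config-file configs/cityscapes/e2e_mask_rcnn_R_50_FPN_1x_cocostyle.yaml
```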