From 614e42759652e66103c21808e82265b17cc97214 Mon Sep 17 00:00:00 2001 From: Miguel Varela Ramos Date: Mon, 3 Jun 2019 11:49:17 +0200 Subject: [PATCH] Save labels in output directory (#842) * Merge branch 'master' of /home/braincreator/projects/maskrcnn-benchmark with conflicts. * update Dockerfile * save labels to output dir * save labels on main process only --- maskrcnn_benchmark/data/build.py | 5 +++++ maskrcnn_benchmark/data/datasets/coco.py | 2 ++ maskrcnn_benchmark/data/datasets/voc.py | 1 + maskrcnn_benchmark/utils/miscellaneous.py | 21 +++++++++++++++++++++ 4 files changed, 29 insertions(+) diff --git a/maskrcnn_benchmark/data/build.py b/maskrcnn_benchmark/data/build.py index b0ce3c348..8cf610ade 100644 --- a/maskrcnn_benchmark/data/build.py +++ b/maskrcnn_benchmark/data/build.py @@ -6,6 +6,7 @@ import torch.utils.data from maskrcnn_benchmark.utils.comm import get_world_size from maskrcnn_benchmark.utils.imports import import_file +from maskrcnn_benchmark.utils.miscellaneous import save_labels from . import datasets as D from . import samplers @@ -154,6 +155,10 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0): transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train) datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train) + if is_train: + # save category_id to label name mapping + save_labels(datasets, cfg.OUTPUT_DIR) + data_loaders = [] for dataset in datasets: sampler = make_data_sampler(dataset, shuffle, is_distributed) diff --git a/maskrcnn_benchmark/data/datasets/coco.py b/maskrcnn_benchmark/data/datasets/coco.py index cd9fc835e..cc10f29d1 100644 --- a/maskrcnn_benchmark/data/datasets/coco.py +++ b/maskrcnn_benchmark/data/datasets/coco.py @@ -54,6 +54,8 @@ def __init__( ids.append(img_id) self.ids = ids + self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()} + self.json_category_id_to_contiguous_id = { v: i + 1 for i, v in enumerate(self.coco.getCatIds()) } diff --git a/maskrcnn_benchmark/data/datasets/voc.py b/maskrcnn_benchmark/data/datasets/voc.py index ad20a8721..ab4075ec5 100644 --- a/maskrcnn_benchmark/data/datasets/voc.py +++ b/maskrcnn_benchmark/data/datasets/voc.py @@ -57,6 +57,7 @@ def __init__(self, data_dir, split, use_difficult=False, transforms=None): cls = PascalVOCDataset.CLASSES self.class_to_ind = dict(zip(cls, range(len(cls)))) + self.categories = dict(zip(range(len(cls)), cls)) def __getitem__(self, index): img_id = self.ids[index] diff --git a/maskrcnn_benchmark/utils/miscellaneous.py b/maskrcnn_benchmark/utils/miscellaneous.py index ecd3ef6a2..ce1c279bf 100644 --- a/maskrcnn_benchmark/utils/miscellaneous.py +++ b/maskrcnn_benchmark/utils/miscellaneous.py @@ -1,5 +1,7 @@ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. import errno +import json +import logging import os from .comm import is_main_process @@ -12,6 +14,25 @@ def mkdir(path): raise +def save_labels(dataset_list, output_dir): + if is_main_process(): + logger = logging.getLogger(__name__) + + ids_to_labels = {} + for dataset in dataset_list: + if hasattr(dataset, 'categories'): + ids_to_labels.update(dataset.categories) + else: + logger.warning("Dataset [{}] has no categories attribute, labels.json file won't be created".format( + dataset.__class__.__name__)) + + if ids_to_labels: + labels_file = os.path.join(output_dir, 'labels.json') + logger.info("Saving labels mapping into {}".format(labels_file)) + with open(labels_file, 'w') as f: + json.dump(ids_to_labels, f, indent=2) + + def save_config(cfg, path): if is_main_process(): with open(path, 'w') as f: