Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
Save labels in output directory (#842)
Browse files Browse the repository at this point in the history
* Merge branch 'master' of /home/braincreator/projects/maskrcnn-benchmark with conflicts.

* update Dockerfile

* save labels to output dir

* save labels on main process only
  • Loading branch information
Miguel Varela Ramos authored and fmassa committed Jun 3, 2019
1 parent d802413 commit 614e427
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 0 deletions.
5 changes: 5 additions & 0 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.utils.data
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.miscellaneous import save_labels

from . import datasets as D
from . import samplers
Expand Down Expand Up @@ -154,6 +155,10 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)

if is_train:
# save category_id to label name mapping
save_labels(datasets, cfg.OUTPUT_DIR)

data_loaders = []
for dataset in datasets:
sampler = make_data_sampler(dataset, shuffle, is_distributed)
Expand Down
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(
ids.append(img_id)
self.ids = ids

self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}

self.json_category_id_to_contiguous_id = {
v: i + 1 for i, v in enumerate(self.coco.getCatIds())
}
Expand Down
1 change: 1 addition & 0 deletions maskrcnn_benchmark/data/datasets/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, data_dir, split, use_difficult=False, transforms=None):

cls = PascalVOCDataset.CLASSES
self.class_to_ind = dict(zip(cls, range(len(cls))))
self.categories = dict(zip(range(len(cls)), cls))

def __getitem__(self, index):
img_id = self.ids[index]
Expand Down
21 changes: 21 additions & 0 deletions maskrcnn_benchmark/utils/miscellaneous.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import json
import logging
import os
from .comm import is_main_process

Expand All @@ -12,6 +14,25 @@ def mkdir(path):
raise


def save_labels(dataset_list, output_dir):
if is_main_process():
logger = logging.getLogger(__name__)

ids_to_labels = {}
for dataset in dataset_list:
if hasattr(dataset, 'categories'):
ids_to_labels.update(dataset.categories)
else:
logger.warning("Dataset [{}] has no categories attribute, labels.json file won't be created".format(
dataset.__class__.__name__))

if ids_to_labels:
labels_file = os.path.join(output_dir, 'labels.json')
logger.info("Saving labels mapping into {}".format(labels_file))
with open(labels_file, 'w') as f:
json.dump(ids_to_labels, f, indent=2)


def save_config(cfg, path):
if is_main_process():
with open(path, 'w') as f:
Expand Down

0 comments on commit 614e427

Please sign in to comment.