Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Save labels in output directory #842

Merged
merged 8 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.utils.data
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.miscellaneous import save_labels

from . import datasets as D
from . import samplers
Expand Down Expand Up @@ -154,6 +155,10 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)

if is_train:
# save category_id to label name mapping
save_labels(datasets, cfg.OUTPUT_DIR)

data_loaders = []
for dataset in datasets:
sampler = make_data_sampler(dataset, shuffle, is_distributed)
Expand Down
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(
ids.append(img_id)
self.ids = ids

self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}

self.json_category_id_to_contiguous_id = {
v: i + 1 for i, v in enumerate(self.coco.getCatIds())
}
Expand Down
1 change: 1 addition & 0 deletions maskrcnn_benchmark/data/datasets/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, data_dir, split, use_difficult=False, transforms=None):

cls = PascalVOCDataset.CLASSES
self.class_to_ind = dict(zip(cls, range(len(cls))))
self.categories = dict(zip(range(len(cls)), cls))

def __getitem__(self, index):
img_id = self.ids[index]
Expand Down
22 changes: 22 additions & 0 deletions maskrcnn_benchmark/utils/miscellaneous.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import json
import logging
import os
from .comm import is_main_process


def mkdir(path):
Expand All @@ -9,3 +12,22 @@ def mkdir(path):
except OSError as e:
if e.errno != errno.EEXIST:
raise


def save_labels(dataset_list, output_dir):
if is_main_process():
logger = logging.getLogger(__name__)

ids_to_labels = {}
for dataset in dataset_list:
if hasattr(dataset, 'categories'):
ids_to_labels.update(dataset.categories)
else:
logger.warning("Dataset [{}] has no categories attribute, labels.json file won't be created".format(
dataset.__class__.__name__))

if ids_to_labels:
labels_file = os.path.join(output_dir, 'labels.json')
logger.info("Saving labels mapping into {}".format(labels_file))
with open(labels_file, 'w') as f:
json.dump(ids_to_labels, f, indent=2)