Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Save labels in output directory #842

Merged
merged 8 commits into from
Jun 3, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch.utils.data
from maskrcnn_benchmark.utils.comm import get_world_size
from maskrcnn_benchmark.utils.imports import import_file
from maskrcnn_benchmark.utils.miscellaneous import save_labels

from . import datasets as D
from . import samplers
Expand Down Expand Up @@ -154,6 +155,10 @@ def make_data_loader(cfg, is_train=True, is_distributed=False, start_iter=0):
transforms = None if not is_train and cfg.TEST.BBOX_AUG.ENABLED else build_transforms(cfg, is_train)
datasets = build_dataset(dataset_list, transforms, DatasetCatalog, is_train)

if is_train:
# save category_id to label name mapping
save_labels(datasets, cfg.OUTPUT_DIR)

data_loaders = []
for dataset in datasets:
sampler = make_data_sampler(dataset, shuffle, is_distributed)
Expand Down
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def __init__(
ids.append(img_id)
self.ids = ids

self.categories = {cat['id']: cat['name'] for cat in self.coco.cats.values()}

self.json_category_id_to_contiguous_id = {
v: i + 1 for i, v in enumerate(self.coco.getCatIds())
}
Expand Down
1 change: 1 addition & 0 deletions maskrcnn_benchmark/data/datasets/voc.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ def __init__(self, data_dir, split, use_difficult=False, transforms=None):

cls = PascalVOCDataset.CLASSES
self.class_to_ind = dict(zip(cls, range(len(cls))))
self.categories = dict(zip(range(len(cls)), cls))

def __getitem__(self, index):
img_id = self.ids[index]
Expand Down
21 changes: 21 additions & 0 deletions maskrcnn_benchmark/utils/miscellaneous.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import json
import logging
import os
from .comm import is_main_process

Expand All @@ -12,6 +14,25 @@ def mkdir(path):
raise


def save_labels(dataset_list, output_dir):
if is_main_process():
logger = logging.getLogger(__name__)

ids_to_labels = {}
for dataset in dataset_list:
if hasattr(dataset, 'categories'):
ids_to_labels.update(dataset.categories)
else:
logger.warning("Dataset [{}] has no categories attribute, labels.json file won't be created".format(
dataset.__class__.__name__))

if ids_to_labels:
labels_file = os.path.join(output_dir, 'labels.json')
logger.info("Saving labels mapping into {}".format(labels_file))
with open(labels_file, 'w') as f:
json.dump(ids_to_labels, f, indent=2)


def save_config(cfg, path):
if is_main_process():
with open(path, 'w') as f:
Expand Down