Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

Commit

Permalink
add support for pascal voc dataset and evaluate (#207)
Browse files Browse the repository at this point in the history
* add support for pascal voc dataset and evaluate

* optimization for adding voc dataset

* make inference.py dataset-agnostic; add use_difficult option to voc dataset

* handle voc difficult objects correctly

* Remove dependency on lxml plus minor improvements

* More cleanups

* More comments and improvements

* Lint fix

* Move configs to their own folder
  • Loading branch information
fmassa authored Nov 23, 2018
1 parent 7bc8708 commit 9a1ba14
Show file tree
Hide file tree
Showing 17 changed files with 860 additions and 361 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ ln -s /path_to_coco_dataset/annotations datasets/coco/annotations
ln -s /path_to_coco_dataset/train2014 datasets/coco/train2014
ln -s /path_to_coco_dataset/test2014 datasets/coco/test2014
ln -s /path_to_coco_dataset/val2014 datasets/coco/val2014
# for pascal voc dataset:
ln -s /path_to_VOCdevkit_dir datasets/voc
```

You can also configure your own paths to the datasets.
Expand Down
20 changes: 20 additions & 0 deletions configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_1_gpu_voc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
RPN:
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TEST: 300
ANCHOR_SIZES: (128, 256, 512)
ROI_BOX_HEAD:
NUM_CLASSES: 21
DATASETS:
TRAIN: ("voc_2007_trainval",)
TEST: ("voc_2007_test",)
SOLVER:
BASE_LR: 0.001
WEIGHT_DECAY: 0.0001
STEPS: (50000, )
MAX_ITER: 70000
IMS_PER_BATCH: 1
TEST:
IMS_PER_BATCH: 1
20 changes: 20 additions & 0 deletions configs/pascal_voc/e2e_faster_rcnn_R_50_C4_1x_4_gpu_voc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
MODEL:
META_ARCHITECTURE: "GeneralizedRCNN"
WEIGHT: "catalog://ImageNetPretrained/MSRA/R-50"
RPN:
PRE_NMS_TOP_N_TEST: 6000
POST_NMS_TOP_N_TEST: 300
ANCHOR_SIZES: (128, 256, 512)
ROI_BOX_HEAD:
NUM_CLASSES: 21
DATASETS:
TRAIN: ("voc_2007_trainval",)
TEST: ("voc_2007_test",)
SOLVER:
BASE_LR: 0.004
WEIGHT_DECAY: 0.0001
STEPS: (12500, )
MAX_ITER: 17500
IMS_PER_BATCH: 4
TEST:
IMS_PER_BATCH: 4
18 changes: 18 additions & 0 deletions maskrcnn_benchmark/config/paths_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@ class DatasetCatalog(object):
"coco/val2014",
"coco/annotations/instances_valminusminival2014.json",
),
"voc_2007_trainval": ("voc/VOC2007", 'trainval'),
"voc_2007_test": ("voc/VOC2007", 'test'),
"voc_2012_train": ("voc/VOC2012", 'train'),
"voc_2012_trainval": ("voc/VOC2012", 'trainval'),
"voc_2012_val": ("voc/VOC2012", 'val'),
"voc_2012_test": ("voc/VOC2012", 'test'),

}

@staticmethod
Expand All @@ -36,6 +43,17 @@ def get(name):
factory="COCODataset",
args=args,
)
elif "voc" in name:
data_dir = DatasetCatalog.DATA_DIR
attrs = DatasetCatalog.DATASETS[name]
args = dict(
data_dir=os.path.join(data_dir, attrs[0]),
split=attrs[1],
)
return dict(
factory="PascalVOCDataset",
args=args,
)
raise RuntimeError("Dataset not available: {}".format(name))


Expand Down
9 changes: 7 additions & 2 deletions maskrcnn_benchmark/data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
"""
if not isinstance(dataset_list, (list, tuple)):
raise RuntimeError(
"dataset_list should be a list of strings, got {}".format(dataset_list))
"dataset_list should be a list of strings, got {}".format(dataset_list)
)
datasets = []
for dataset_name in dataset_list:
data = dataset_catalog.get(dataset_name)
Expand All @@ -36,6 +37,8 @@ def build_dataset(dataset_list, transforms, dataset_catalog, is_train=True):
# during training
if data["factory"] == "COCODataset":
args["remove_images_without_annotations"] = is_train
if data["factory"] == "PascalVOCDataset":
args["use_difficult"] = not is_train
args["transforms"] = transforms
# make dataset from factory
dataset = factory(**args)
Expand Down Expand Up @@ -95,7 +98,9 @@ def make_batch_data_sampler(
sampler, images_per_batch, drop_last=False
)
if num_iters is not None:
batch_sampler = samplers.IterationBasedBatchSampler(batch_sampler, num_iters, start_iter)
batch_sampler = samplers.IterationBasedBatchSampler(
batch_sampler, num_iters, start_iter
)
return batch_sampler


Expand Down
3 changes: 2 additions & 1 deletion maskrcnn_benchmark/data/datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
from .coco import COCODataset
from .voc import PascalVOCDataset
from .concat_dataset import ConcatDataset

__all__ = ["COCODataset", "ConcatDataset"]
__all__ = ["COCODataset", "ConcatDataset", "PascalVOCDataset"]
1 change: 0 additions & 1 deletion maskrcnn_benchmark/data/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ def __init__(
self, ann_file, root, remove_images_without_annotations, transforms=None
):
super(COCODataset, self).__init__(root, ann_file)

# sort indices for reproducible results
self.ids = sorted(self.ids)

Expand Down
27 changes: 27 additions & 0 deletions maskrcnn_benchmark/data/datasets/evaluation/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from maskrcnn_benchmark.data import datasets

from .coco import coco_evaluation
from .voc import voc_evaluation


def evaluate(dataset, predictions, output_folder, **kwargs):
"""evaluate dataset using different methods based on dataset type.
Args:
dataset: Dataset object
predictions(list[BoxList]): each item in the list represents the
prediction results for one image.
output_folder: output folder, to save evaluation files or results.
**kwargs: other args.
Returns:
evaluation result
"""
args = dict(
dataset=dataset, predictions=predictions, output_folder=output_folder, **kwargs
)
if isinstance(dataset, datasets.COCODataset):
return coco_evaluation(**args)
elif isinstance(dataset, datasets.PascalVOCDataset):
return voc_evaluation(**args)
else:
dataset_name = dataset.__class__.__name__
raise NotImplementedError("Unsupported dataset type {}.".format(dataset_name))
21 changes: 21 additions & 0 deletions maskrcnn_benchmark/data/datasets/evaluation/coco/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from .coco_eval import do_coco_evaluation


def coco_evaluation(
dataset,
predictions,
output_folder,
box_only,
iou_types,
expected_results,
expected_results_sigma_tol,
):
return do_coco_evaluation(
dataset=dataset,
predictions=predictions,
box_only=box_only,
output_folder=output_folder,
iou_types=iou_types,
expected_results=expected_results,
expected_results_sigma_tol=expected_results_sigma_tol,
)
Loading

2 comments on commit 9a1ba14

@zimenglan-sysu-512
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi @fmassa
i have a question: why do u add dataset_name as a parameter in inference function in tools/test_net.py file?

@fmassa
Copy link
Contributor Author

@fmassa fmassa commented on 9a1ba14 Nov 25, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi,
It is now only needed for nicer logging / error messages I believe, but those checks could be pushed outside I think

Please sign in to comment.