From fca98604916f9fb2fbeca4fbf430c5b515b42a91 Mon Sep 17 00:00:00 2001 From: "DESKTOP-PS56VT3\\Admin" Date: Tue, 23 Jun 2020 23:01:18 +0800 Subject: [PATCH] allow evaluation on custom images --- README.md | 19 +++++++++++++++---- configs/e2e_relation_X_101_32_8_FPN_1x.yaml | 2 ++ maskrcnn_benchmark/config/defaults.py | 5 +++++ maskrcnn_benchmark/config/paths_catalog.py | 4 +++- .../data/datasets/evaluation/vg/sgg_eval.py | 4 ++++ .../data/datasets/visual_genome.py | 15 ++++++++++++++- maskrcnn_benchmark/engine/inference.py | 4 ++++ .../modeling/roi_heads/box_head/box_head.py | 3 ++- .../modeling/roi_heads/box_head/inference.py | 17 +++++++++++------ 9 files changed, 60 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 23b0311..545940b 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,8 @@ Our paper [Unbiased Scene Graph Generation from Biased Training](https://arxiv.o ## Recent Updates -- [x] 2020.06.23 [No Graph Constraint Mean Recall@K (ng-mR@K) and No Graph Constraint Zero-Shot Recall@K (ng-zR@K)](METRICS.md#explanation-of-our-metrics) +- [x] 2020.06.23 Add No Graph Constraint Mean Recall@K (ng-mR@K) and No Graph Constraint Zero-Shot Recall@K (ng-zR@K)[link](METRICS.md#explanation-of-our-metrics) +- [x] 2020.06.23 Allow Scene Graph Detection (SGDet) on Custom Images[link](#run-SGDet-on-custom-images) ## Contents @@ -23,9 +24,10 @@ Our paper [Unbiased Scene Graph Generation from Biased Training](https://arxiv.o 6. [Scene Graph Generation as RoI_Head](#scene-graph-generation-as-RoI_Head) 7. [Training on Scene Graph Generation](#perform-training-on-scene-graph-generation) 8. [Evaluation on Scene Graph Generation](#Evaluation) -9. [Other Options that May Improve the SGG](#other-options-that-may-improve-the-SGG) -10. [Tips and Tricks for TDE on any Unbiased Task](#tips-and-Tricks-for-any-unbiased-taskX-from-biased-training) -11. [Citations](#Citations) +9. [SGDet on Custum Images](#run-SGDet-on-custom-images) +10. [Other Options that May Improve the SGG](#other-options-that-may-improve-the-SGG) +11. [Tips and Tricks for TDE on any Unbiased Task](#tips-and-Tricks-for-any-unbiased-taskX-from-biased-training) +12. [Citations](#Citations) ## Overview @@ -168,6 +170,15 @@ MOTIFS-SGCls-TDE | 20.47 | 26.31 | 28.79 | 9.80 | 13.21 | 15.06 | 1.91 | 2.95 MOTIFS-PredCls-none | 59.64 | 66.11 | 67.96 | 11.46 | 14.60 | 15.84 | 5.79 | 11.02 | 14.74 MOTIFS-PredCls-TDE | 33.38 | 45.88 | 51.25 | 17.85 | 24.75 | 28.70 | 8.28 | 14.31 | 18.04 +## Run SGDet on Custom Images +Note that evaluation on custum images is only valid for SGDet model, because PredCls and SGCls model requires additional ground-truth bounding boxes information. You only need to turn on the switch TEST.CUSTUM_EVAL and give a folder path that contains the custom images to TEST.CUSTUM_PATH. Only JPG files are allowed. The output will be custom_prediction.pytorch saved in OUTPUT_DIR, which can be read by torch.load(). + +Test Example 1 : (SGDet, Motif Model) +```bash +CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/kaihua/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/kaihua/checkpoints/motif-precls-exmp OUTPUT_DIR /home/kaihua/checkpoints/motif-precls-exmp TEST.CUSTUM_EVAL True TEST.CUSTUM_PATH /home/kaihua/checkpoints/custom_images +``` + + ## Other Options that May Improve the SGG - For some models (not all), turning on or turning off ```MODEL.ROI_RELATION_HEAD.POOLING_ALL_LEVELS``` will affect the performance of predicate prediction, e.g., turning it off will improve VCTree PredCls but not the corresponding SGCls and SGGen. For the reported results of VCTree, we simply turn it on for all three protocols like other models. diff --git a/configs/e2e_relation_X_101_32_8_FPN_1x.yaml b/configs/e2e_relation_X_101_32_8_FPN_1x.yaml index 5faf72a..66d6d82 100644 --- a/configs/e2e_relation_X_101_32_8_FPN_1x.yaml +++ b/configs/e2e_relation_X_101_32_8_FPN_1x.yaml @@ -126,3 +126,5 @@ TEST: SYNC_GATHER: True # turn on will slow down the evaluation to solve the sgdet test out of memory problem REQUIRE_OVERLAP: False LATER_NMS_PREDICTION_THRES: 0.5 + CUSTUM_EVAL: False # eval SGDet model on custum images, output a json + CUSTUM_PATH: '.' # the folder that contains the custum images, only jpg files are allowed diff --git a/maskrcnn_benchmark/config/defaults.py b/maskrcnn_benchmark/config/defaults.py index 9cab043..0aeffda 100644 --- a/maskrcnn_benchmark/config/defaults.py +++ b/maskrcnn_benchmark/config/defaults.py @@ -574,6 +574,11 @@ _C.TEST.RELATION.SYNC_GATHER = False _C.TEST.ALLOW_LOAD_FROM_CACHE = True + + +_C.TEST.CUSTUM_EVAL = False +_C.TEST.CUSTUM_PATH = '.' + # ---------------------------------------------------------------------------- # # Misc options # ---------------------------------------------------------------------------- # diff --git a/maskrcnn_benchmark/config/paths_catalog.py b/maskrcnn_benchmark/config/paths_catalog.py index bc417a5..7dd5370 100644 --- a/maskrcnn_benchmark/config/paths_catalog.py +++ b/maskrcnn_benchmark/config/paths_catalog.py @@ -157,7 +157,9 @@ def get(name, cfg): # else set filter to False, because we need all images for pretraining detector args['filter_non_overlap'] = (not cfg.MODEL.ROI_RELATION_HEAD.USE_GT_BOX) and cfg.MODEL.RELATION_ON and cfg.MODEL.ROI_RELATION_HEAD.REQUIRE_BOX_OVERLAP args['filter_empty_rels'] = cfg.MODEL.RELATION_ON - args['flip_aug'] = cfg.MODEL.FLIP_AUG + args['flip_aug'] = cfg.MODEL.FLIP_AUG + args['custom_eval'] = cfg.TEST.CUSTUM_EVAL + args['custom_path'] = cfg.TEST.CUSTUM_PATH return dict( factory="VGDataset", args=args, diff --git a/maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py b/maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py index 8568253..0b56aa0 100644 --- a/maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py +++ b/maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py @@ -314,9 +314,11 @@ def generate_print_string(self, mode): result_str += ' for mode=%s, type=Mean Recall.' % mode result_str += '\n' if self.print_detail: + result_str += '----------------------- Details ------------------------\n' for n, r in zip(self.rel_name_list, self.result_dict[mode + '_mean_recall_list'][100]): result_str += '({}:{:.4f}) '.format(str(n), r) result_str += '\n' + result_str += '--------------------------------------------------------\n' return result_str @@ -384,9 +386,11 @@ def generate_print_string(self, mode): result_str += ' for mode=%s, type=No Graph Constraint Mean Recall.' % mode result_str += '\n' if self.print_detail: + result_str += '----------------------- Details ------------------------\n' for n, r in zip(self.rel_name_list, self.result_dict[mode + '_ng_mean_recall_list'][100]): result_str += '({}:{:.4f}) '.format(str(n), r) result_str += '\n' + result_str += '--------------------------------------------------------\n' return result_str diff --git a/maskrcnn_benchmark/data/datasets/visual_genome.py b/maskrcnn_benchmark/data/datasets/visual_genome.py index ea9a4da..9f4c8b1 100644 --- a/maskrcnn_benchmark/data/datasets/visual_genome.py +++ b/maskrcnn_benchmark/data/datasets/visual_genome.py @@ -18,7 +18,7 @@ class VGDataset(torch.utils.data.Dataset): def __init__(self, split, img_dir, roidb_file, dict_file, image_file, transforms=None, filter_empty_rels=True, num_im=-1, num_val_im=5000, - filter_duplicate_rels=True, filter_non_overlap=True, flip_aug=False): + filter_duplicate_rels=True, filter_non_overlap=True, flip_aug=False, custom_eval=False, custom_path=''): """ Torch dataset for VisualGenome Parameters: @@ -63,11 +63,18 @@ def __init__(self, split, img_dir, roidb_file, dict_file, image_file, transforms self.filenames = [self.filenames[i] for i in np.where(self.split_mask)[0]] self.img_info = [self.img_info[i] for i in np.where(self.split_mask)[0]] + self.custom_eval = custom_eval + if self.custom_eval: + self.get_custom_imgs(custom_path) + def __getitem__(self, index): #if self.split == 'train': # while(random.random() > self.img_info[index]['anti_prop']): # index = int(random.random() * len(self.filenames)) + if self.custom_eval: + img = Image.open(self.custom_files[index]).convert("RGB") + return img, 0, index img = Image.open(self.filenames[index]).convert("RGB") if img.size[0] != self.img_info[index]['width'] or img.size[1] != self.img_info[index]['height']: @@ -103,6 +110,10 @@ def get_statistics(self): } return result + def get_custom_imgs(self, path): + self.custom_files = [] + for file_name in os.listdir(path): + self.custom_files.append(os.path.join(path, file_name)) def get_img_info(self, index): # WARNING: original image_file.json has several pictures with false image size @@ -159,6 +170,8 @@ def get_groundtruth(self, index, evaluation=False, flip_img=False): return target def __len__(self): + if self.custom_eval: + return len(self.custom_files) return len(self.filenames) diff --git a/maskrcnn_benchmark/engine/inference.py b/maskrcnn_benchmark/engine/inference.py index 47c6724..ee1db55 100644 --- a/maskrcnn_benchmark/engine/inference.py +++ b/maskrcnn_benchmark/engine/inference.py @@ -141,6 +141,10 @@ def inference( expected_results_sigma_tol=expected_results_sigma_tol, ) + if cfg.TEST.CUSTUM_EVAL: + torch.save(predictions, os.path.join(cfg.OUTPUT_DIR, 'custom_prediction.pytorch')) + return -1.0 + return evaluate(cfg=cfg, dataset=dataset, predictions=predictions, diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py index f9bcff6..057038c 100644 --- a/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py @@ -64,7 +64,8 @@ def forward(self, features, proposals, targets=None): return x, proposals, {} else: # mode==sgdet - proposals = self.samp_processor.assign_label_to_proposals(proposals, targets) + if self.training or not self.cfg.TEST.CUSTUM_EVAL: + proposals = self.samp_processor.assign_label_to_proposals(proposals, targets) x = self.feature_extractor(features, proposals) class_logits, box_regression = self.predictor(x) proposals = add_predict_logits(proposals, class_logits) diff --git a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py index 3c54d6e..b9eb3d2 100644 --- a/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py +++ b/maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py @@ -26,7 +26,8 @@ def __init__( box_coder=None, cls_agnostic_bbox_reg=False, bbox_aug_enabled=False, - save_proposals=False + save_proposals=False, + custum_eval=False ): """ Arguments: @@ -47,6 +48,7 @@ def __init__( self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg self.bbox_aug_enabled = bbox_aug_enabled self.save_proposals = save_proposals + self.custum_eval = custum_eval def forward(self, x, boxes, relation_mode=False): """ @@ -104,11 +106,12 @@ def forward(self, x, boxes, relation_mode=False): def add_important_fields(self, i, boxes, orig_inds, boxlist, boxes_per_cls, relation_mode=False): if relation_mode: - gt_labels = boxes[i].get_field('labels')[orig_inds] - gt_attributes = boxes[i].get_field('attributes')[orig_inds] + if not self.custum_eval: + gt_labels = boxes[i].get_field('labels')[orig_inds] + gt_attributes = boxes[i].get_field('attributes')[orig_inds] - boxlist.add_field('labels', gt_labels) - boxlist.add_field('attributes', gt_attributes) + boxlist.add_field('labels', gt_labels) + boxlist.add_field('attributes', gt_attributes) predict_logits = boxes[i].get_field('predict_logits')[orig_inds] boxlist.add_field('boxes_per_cls', boxes_per_cls) @@ -238,6 +241,7 @@ def make_roi_box_post_processor(cfg): post_nms_per_cls_topn = cfg.MODEL.ROI_HEADS.POST_NMS_PER_CLS_TOPN nms_filter_duplicates = cfg.MODEL.ROI_HEADS.NMS_FILTER_DUPLICATES save_proposals = cfg.TEST.SAVE_PROPOSALS + custum_eval = cfg.TEST.CUSTUM_EVAL postprocessor = PostProcessor( score_thresh, @@ -248,6 +252,7 @@ def make_roi_box_post_processor(cfg): box_coder, cls_agnostic_bbox_reg, bbox_aug_enabled, - save_proposals + save_proposals, + custum_eval ) return postprocessor