Skip to content

Commit

Permalink
allow evaluation on custom images
Browse files Browse the repository at this point in the history
  • Loading branch information
KaihuaTang committed Jun 23, 2020
1 parent d05be9f commit fca9860
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 13 deletions.
19 changes: 15 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ Our paper [Unbiased Scene Graph Generation from Biased Training](https://arxiv.o

## Recent Updates

- [x] 2020.06.23 [No Graph Constraint Mean Recall@K (ng-mR@K) and No Graph Constraint Zero-Shot Recall@K (ng-zR@K)](METRICS.md#explanation-of-our-metrics)
- [x] 2020.06.23 Add No Graph Constraint Mean Recall@K (ng-mR@K) and No Graph Constraint Zero-Shot Recall@K (ng-zR@K)[link](METRICS.md#explanation-of-our-metrics)
- [x] 2020.06.23 Allow Scene Graph Detection (SGDet) on Custom Images[link](#run-SGDet-on-custom-images)

## Contents

Expand All @@ -23,9 +24,10 @@ Our paper [Unbiased Scene Graph Generation from Biased Training](https://arxiv.o
6. [Scene Graph Generation as RoI_Head](#scene-graph-generation-as-RoI_Head)
7. [Training on Scene Graph Generation](#perform-training-on-scene-graph-generation)
8. [Evaluation on Scene Graph Generation](#Evaluation)
9. [Other Options that May Improve the SGG](#other-options-that-may-improve-the-SGG)
10. [Tips and Tricks for TDE on any Unbiased Task](#tips-and-Tricks-for-any-unbiased-taskX-from-biased-training)
11. [Citations](#Citations)
9. [SGDet on Custum Images](#run-SGDet-on-custom-images)
10. [Other Options that May Improve the SGG](#other-options-that-may-improve-the-SGG)
11. [Tips and Tricks for TDE on any Unbiased Task](#tips-and-Tricks-for-any-unbiased-taskX-from-biased-training)
12. [Citations](#Citations)

## Overview

Expand Down Expand Up @@ -168,6 +170,15 @@ MOTIFS-SGCls-TDE | 20.47 | 26.31 | 28.79 | 9.80 | 13.21 | 15.06 | 1.91 | 2.95
MOTIFS-PredCls-none | 59.64 | 66.11 | 67.96 | 11.46 | 14.60 | 15.84 | 5.79 | 11.02 | 14.74
MOTIFS-PredCls-TDE | 33.38 | 45.88 | 51.25 | 17.85 | 24.75 | 28.70 | 8.28 | 14.31 | 18.04

## Run SGDet on Custom Images
Note that evaluation on custum images is only valid for SGDet model, because PredCls and SGCls model requires additional ground-truth bounding boxes information. You only need to turn on the switch TEST.CUSTUM_EVAL and give a folder path that contains the custom images to TEST.CUSTUM_PATH. Only JPG files are allowed. The output will be custom_prediction.pytorch saved in OUTPUT_DIR, which can be read by torch.load().

Test Example 1 : (SGDet, Motif Model)
```bash
CUDA_VISIBLE_DEVICES=0 python -m torch.distributed.launch --master_port 10027 --nproc_per_node=1 tools/relation_test_net.py --config-file "configs/e2e_relation_X_101_32_8_FPN_1x.yaml" MODEL.ROI_RELATION_HEAD.USE_GT_BOX False MODEL.ROI_RELATION_HEAD.USE_GT_OBJECT_LABEL False MODEL.ROI_RELATION_HEAD.PREDICTOR MotifPredictor TEST.IMS_PER_BATCH 1 DTYPE "float16" GLOVE_DIR /home/kaihua/glove MODEL.PRETRAINED_DETECTOR_CKPT /home/kaihua/checkpoints/motif-precls-exmp OUTPUT_DIR /home/kaihua/checkpoints/motif-precls-exmp TEST.CUSTUM_EVAL True TEST.CUSTUM_PATH /home/kaihua/checkpoints/custom_images
```


## Other Options that May Improve the SGG

- For some models (not all), turning on or turning off ```MODEL.ROI_RELATION_HEAD.POOLING_ALL_LEVELS``` will affect the performance of predicate prediction, e.g., turning it off will improve VCTree PredCls but not the corresponding SGCls and SGGen. For the reported results of VCTree, we simply turn it on for all three protocols like other models.
Expand Down
2 changes: 2 additions & 0 deletions configs/e2e_relation_X_101_32_8_FPN_1x.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,5 @@ TEST:
SYNC_GATHER: True # turn on will slow down the evaluation to solve the sgdet test out of memory problem
REQUIRE_OVERLAP: False
LATER_NMS_PREDICTION_THRES: 0.5
CUSTUM_EVAL: False # eval SGDet model on custum images, output a json
CUSTUM_PATH: '.' # the folder that contains the custum images, only jpg files are allowed
5 changes: 5 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,6 +574,11 @@
_C.TEST.RELATION.SYNC_GATHER = False

_C.TEST.ALLOW_LOAD_FROM_CACHE = True


_C.TEST.CUSTUM_EVAL = False
_C.TEST.CUSTUM_PATH = '.'

# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
Expand Down
4 changes: 3 additions & 1 deletion maskrcnn_benchmark/config/paths_catalog.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,9 @@ def get(name, cfg):
# else set filter to False, because we need all images for pretraining detector
args['filter_non_overlap'] = (not cfg.MODEL.ROI_RELATION_HEAD.USE_GT_BOX) and cfg.MODEL.RELATION_ON and cfg.MODEL.ROI_RELATION_HEAD.REQUIRE_BOX_OVERLAP
args['filter_empty_rels'] = cfg.MODEL.RELATION_ON
args['flip_aug'] = cfg.MODEL.FLIP_AUG
args['flip_aug'] = cfg.MODEL.FLIP_AUG
args['custom_eval'] = cfg.TEST.CUSTUM_EVAL
args['custom_path'] = cfg.TEST.CUSTUM_PATH
return dict(
factory="VGDataset",
args=args,
Expand Down
4 changes: 4 additions & 0 deletions maskrcnn_benchmark/data/datasets/evaluation/vg/sgg_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,9 +314,11 @@ def generate_print_string(self, mode):
result_str += ' for mode=%s, type=Mean Recall.' % mode
result_str += '\n'
if self.print_detail:
result_str += '----------------------- Details ------------------------\n'
for n, r in zip(self.rel_name_list, self.result_dict[mode + '_mean_recall_list'][100]):
result_str += '({}:{:.4f}) '.format(str(n), r)
result_str += '\n'
result_str += '--------------------------------------------------------\n'

return result_str

Expand Down Expand Up @@ -384,9 +386,11 @@ def generate_print_string(self, mode):
result_str += ' for mode=%s, type=No Graph Constraint Mean Recall.' % mode
result_str += '\n'
if self.print_detail:
result_str += '----------------------- Details ------------------------\n'
for n, r in zip(self.rel_name_list, self.result_dict[mode + '_ng_mean_recall_list'][100]):
result_str += '({}:{:.4f}) '.format(str(n), r)
result_str += '\n'
result_str += '--------------------------------------------------------\n'

return result_str

Expand Down
15 changes: 14 additions & 1 deletion maskrcnn_benchmark/data/datasets/visual_genome.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class VGDataset(torch.utils.data.Dataset):

def __init__(self, split, img_dir, roidb_file, dict_file, image_file, transforms=None,
filter_empty_rels=True, num_im=-1, num_val_im=5000,
filter_duplicate_rels=True, filter_non_overlap=True, flip_aug=False):
filter_duplicate_rels=True, filter_non_overlap=True, flip_aug=False, custom_eval=False, custom_path=''):
"""
Torch dataset for VisualGenome
Parameters:
Expand Down Expand Up @@ -63,11 +63,18 @@ def __init__(self, split, img_dir, roidb_file, dict_file, image_file, transforms
self.filenames = [self.filenames[i] for i in np.where(self.split_mask)[0]]
self.img_info = [self.img_info[i] for i in np.where(self.split_mask)[0]]

self.custom_eval = custom_eval
if self.custom_eval:
self.get_custom_imgs(custom_path)


def __getitem__(self, index):
#if self.split == 'train':
# while(random.random() > self.img_info[index]['anti_prop']):
# index = int(random.random() * len(self.filenames))
if self.custom_eval:
img = Image.open(self.custom_files[index]).convert("RGB")
return img, 0, index

img = Image.open(self.filenames[index]).convert("RGB")
if img.size[0] != self.img_info[index]['width'] or img.size[1] != self.img_info[index]['height']:
Expand Down Expand Up @@ -103,6 +110,10 @@ def get_statistics(self):
}
return result

def get_custom_imgs(self, path):
self.custom_files = []
for file_name in os.listdir(path):
self.custom_files.append(os.path.join(path, file_name))

def get_img_info(self, index):
# WARNING: original image_file.json has several pictures with false image size
Expand Down Expand Up @@ -159,6 +170,8 @@ def get_groundtruth(self, index, evaluation=False, flip_img=False):
return target

def __len__(self):
if self.custom_eval:
return len(self.custom_files)
return len(self.filenames)


Expand Down
4 changes: 4 additions & 0 deletions maskrcnn_benchmark/engine/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def inference(
expected_results_sigma_tol=expected_results_sigma_tol,
)

if cfg.TEST.CUSTUM_EVAL:
torch.save(predictions, os.path.join(cfg.OUTPUT_DIR, 'custom_prediction.pytorch'))
return -1.0

return evaluate(cfg=cfg,
dataset=dataset,
predictions=predictions,
Expand Down
3 changes: 2 additions & 1 deletion maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ def forward(self, features, proposals, targets=None):
return x, proposals, {}
else:
# mode==sgdet
proposals = self.samp_processor.assign_label_to_proposals(proposals, targets)
if self.training or not self.cfg.TEST.CUSTUM_EVAL:
proposals = self.samp_processor.assign_label_to_proposals(proposals, targets)
x = self.feature_extractor(features, proposals)
class_logits, box_regression = self.predictor(x)
proposals = add_predict_logits(proposals, class_logits)
Expand Down
17 changes: 11 additions & 6 deletions maskrcnn_benchmark/modeling/roi_heads/box_head/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ def __init__(
box_coder=None,
cls_agnostic_bbox_reg=False,
bbox_aug_enabled=False,
save_proposals=False
save_proposals=False,
custum_eval=False
):
"""
Arguments:
Expand All @@ -47,6 +48,7 @@ def __init__(
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg
self.bbox_aug_enabled = bbox_aug_enabled
self.save_proposals = save_proposals
self.custum_eval = custum_eval

def forward(self, x, boxes, relation_mode=False):
"""
Expand Down Expand Up @@ -104,11 +106,12 @@ def forward(self, x, boxes, relation_mode=False):

def add_important_fields(self, i, boxes, orig_inds, boxlist, boxes_per_cls, relation_mode=False):
if relation_mode:
gt_labels = boxes[i].get_field('labels')[orig_inds]
gt_attributes = boxes[i].get_field('attributes')[orig_inds]
if not self.custum_eval:
gt_labels = boxes[i].get_field('labels')[orig_inds]
gt_attributes = boxes[i].get_field('attributes')[orig_inds]

boxlist.add_field('labels', gt_labels)
boxlist.add_field('attributes', gt_attributes)
boxlist.add_field('labels', gt_labels)
boxlist.add_field('attributes', gt_attributes)

predict_logits = boxes[i].get_field('predict_logits')[orig_inds]
boxlist.add_field('boxes_per_cls', boxes_per_cls)
Expand Down Expand Up @@ -238,6 +241,7 @@ def make_roi_box_post_processor(cfg):
post_nms_per_cls_topn = cfg.MODEL.ROI_HEADS.POST_NMS_PER_CLS_TOPN
nms_filter_duplicates = cfg.MODEL.ROI_HEADS.NMS_FILTER_DUPLICATES
save_proposals = cfg.TEST.SAVE_PROPOSALS
custum_eval = cfg.TEST.CUSTUM_EVAL

postprocessor = PostProcessor(
score_thresh,
Expand All @@ -248,6 +252,7 @@ def make_roi_box_post_processor(cfg):
box_coder,
cls_agnostic_bbox_reg,
bbox_aug_enabled,
save_proposals
save_proposals,
custum_eval
)
return postprocessor

0 comments on commit fca9860

Please sign in to comment.