diff --git a/tools/test.py b/tools/test.py index 3b3364dacb..9f15ca8cb9 100644 --- a/tools/test.py +++ b/tools/test.py @@ -7,6 +7,7 @@ import torch from mmcv import Config, DictAction from mmcv.cnn import fuse_conv_bn +from mmcv.fileio.io import file_handlers from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import get_dist_info, init_dist, load_checkpoint from mmcv.runner.fp16_utils import wrap_fp16_model @@ -92,17 +93,6 @@ def parse_args(): return args -def merge_configs(cfg1, cfg2): - # Merge cfg2 into cfg1 - # Overwrite cfg1 if repeated, ignore if value is None. - cfg1 = {} if cfg1 is None else cfg1.copy() - cfg2 = {} if cfg2 is None else cfg2 - for k, v in cfg2.items(): - if v: - cfg1[k] = v - return cfg1 - - def main(): args = parse_args() @@ -113,19 +103,27 @@ def main(): # Load output_config from cfg output_config = cfg.get('output_config', {}) # Overwrite output_config from args.out - output_config = merge_configs(output_config, dict(out=args.out)) + output_config = Config._merge_a_into_b(dict(out=args.out), output_config) # Load eval_config from cfg eval_config = cfg.get('eval_config', {}) # Overwrite eval_config from args.eval - eval_config = merge_configs(eval_config, dict(metrics=args.eval)) + eval_config = Config._merge_a_into_b(dict(metrics=args.eval), eval_config) # Add options from args.eval_options - eval_config = merge_configs(eval_config, args.eval_options) + eval_config = Config._merge_a_into_b(args.eval_options, eval_config) assert output_config or eval_config, \ ('Please specify at least one operation (save or eval the ' 'results) with the argument "--out" or "--eval"') + if output_config: + out = output_config['out'] + # make sure the dirname of the output path exists + mmcv.mkdir_or_exist(osp.dirname(out)) + _, suffix = osp.splitext(out) + assert suffix in file_handlers, \ + 'The format of the output file should be json, pickle or yaml' + # set cudnn benchmark if cfg.get('cudnn_benchmark', False): torch.backends.cudnn.benchmark = True @@ -146,8 +144,6 @@ def main(): distributed = True init_dist(args.launcher, **cfg.dist_params) - # create work_dir - mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir)) # build the dataloader dataset = build_dataset(cfg.data.test, dict(test_mode=True)) dataloader_setting = dict(