Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix] fix merge_args in tools/test.py #2431

Merged
merged 3 commits into from
Jun 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
To utilize ViTPose, you'll need to have [MMClassification](https://github.com/open-mmlab/mmclassification). To install the required version, run the following command:

```shell
mim install 'mmcls>=1.0.0rc5'
mim install 'mmcls>=1.0.0rc6'
```

<!-- [BACKBONE] -->
Expand Down
34 changes: 19 additions & 15 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,20 @@ def parse_args():

def merge_args(cfg, args):
"""Merge CLI arguments to config."""

cfg.launcher = args.launcher
cfg.load_from = args.checkpoint

# -------------------- work directory --------------------
# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])

# -------------------- visualization --------------------
if args.show or (args.show_dir is not None):
assert 'visualization' in cfg.default_hooks, \
Expand All @@ -80,10 +94,14 @@ def merge_args(cfg, args):
'The dump file must be a pkl file.'
dump_metric = dict(type='DumpResults', out_file_path=args.dump)
if isinstance(cfg.test_evaluator, (list, tuple)):
cfg.test_evaluator = list(cfg.test_evaluator).append(dump_metric)
cfg.test_evaluator = [*cfg.test_evaluator, dump_metric]
else:
cfg.test_evaluator = [cfg.test_evaluator, dump_metric]

# -------------------- Other arguments --------------------
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

return cfg


Expand All @@ -93,20 +111,6 @@ def main():
# load config
cfg = Config.fromfile(args.config)
cfg = merge_args(cfg, args)
cfg.launcher = args.launcher
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# work_dir is determined in this priority: CLI > segment in file > filename
if args.work_dir is not None:
# update configs according to CLI args if args.work_dir is not None
cfg.work_dir = args.work_dir
elif cfg.get('work_dir', None) is None:
# use config filename as default work_dir if cfg.work_dir is None
cfg.work_dir = osp.join('./work_dirs',
osp.splitext(osp.basename(args.config))[0])

cfg.load_from = args.checkpoint

# build the runner from config
runner = Runner.from_cfg(cfg)
Expand Down
2 changes: 1 addition & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ def merge_args(cfg, args):
if args.auto_scale_lr:
cfg.auto_scale_lr.enable = True

# visualization-
# visualization
if args.show or (args.show_dir is not None):
assert 'visualization' in cfg.default_hooks, \
'PoseVisualizationHook is not set in the ' \
Expand Down