Skip to content

Commit

Permalink
[Feature] Support to set data root through commands (open-mmlab#7386)
Browse files Browse the repository at this point in the history
* Fix open-mmlab#6915: Support to set data root through commands

* Support open-mmlab#6915: seperate function in tools/utils.py, support test.py and browse_dataset.py

* update open-mmlab#6915: refactor the code ref @hhaAndroid advice

* support open-mmlab#6915: fix format problem

* supoort corresponding scripts and update doc @hhaAndroid

* updata misc.py as @ZwwWayne and @hhaAndroid

* Update mmdet/utils/misc.py

Co-authored-by: Haian Huang(深度眸) <[email protected]>

* fix mmdet/utils/misc.py format problem

Co-authored-by: Haian Huang(深度眸) <[email protected]>
  • Loading branch information
2 people authored and ZwwWayne committed Jul 19, 2022
1 parent f9b31c9 commit 14b1b4e
Show file tree
Hide file tree
Showing 13 changed files with 100 additions and 5 deletions.
6 changes: 6 additions & 0 deletions docs/en/3_exist_data_new_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ mmdetection
```

Or you can set your dataset root through
```bash
export MMDET_DATASETS=$data_root
```
We will replace dataset root with `$MMDET_DATASETS`, so you don't have to modify the corresponding path in config files.

The cityscapes annotations have to be converted into the coco format using `tools/dataset_converters/cityscapes.py`:

```shell
Expand Down
7 changes: 7 additions & 0 deletions docs/zh_cn/3_exist_data_new_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ mmdetection
│ │ ├── VOC2012
```

你也可以通过如下方式设定数据集根路径
```bash
export MMDET_DATASETS=$data_root
```
我们将会使用环境便变量 `$MMDET_DATASETS` 作为数据集的根目录,因此你无需再修改相应配置文件的路径信息。


你需要使用脚本 `tools/dataset_converters/cityscapes.py` 将 cityscapes 标注转化为 coco 标注格式。

```shell
Expand Down
5 changes: 3 additions & 2 deletions mmdet/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_caller_name, get_root_logger, log_img_scale
from .misc import find_latest_checkpoint
from .misc import find_latest_checkpoint, update_data_root
from .setup_env import setup_multi_processes

__all__ = [
'get_root_logger', 'collect_env', 'find_latest_checkpoint',
'setup_multi_processes', 'get_caller_name', 'log_img_scale'
'update_data_root', 'setup_multi_processes', 'get_caller_name',
'log_img_scale'
]
38 changes: 38 additions & 0 deletions mmdet/utils/misc.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
import glob
import os
import os.path as osp
import warnings

import mmcv
from mmcv.utils import print_log


def find_latest_checkpoint(path, suffix='pth'):
"""Find the latest checkpoint from the working directory.
Expand Down Expand Up @@ -36,3 +40,37 @@ def find_latest_checkpoint(path, suffix='pth'):
latest = count
latest_path = checkpoint
return latest_path


def update_data_root(cfg, logger=None):
"""Update data root according to env MMDET_DATASETS.
If set env MMDET_DATASETS, update cfg.data_root according to
MMDET_DATASETS. Otherwise, using cfg.data_root as default.
Args:
cfg (mmcv.Config): The model config need to modify
logger (logging.Logger | str | None): the way to print msg
"""
assert isinstance(cfg, mmcv.Config), \
f'cfg got wrong type: {type(cfg)}, expected mmcv.Config'

if 'MMDET_DATASETS' in os.environ:
dst_root = os.environ['MMDET_DATASETS']
print_log(f'MMDET_DATASETS has been set to be {dst_root}.'
f'Using {dst_root} as data root.')
else:
return

assert isinstance(cfg, mmcv.Config), \
f'cfg got wrong type: {type(cfg)}, expected mmcv.Config'

def update(cfg, src_str, dst_str):
for k, v in cfg.items():
if isinstance(v, mmcv.ConfigDict):
update(cfg[k], src_str, dst_str)
if isinstance(v, str) and src_str in v:
cfg[k] = v.replace(src_str, dst_str)

update(cfg.data, cfg.data_root, dst_root)
cfg.data_root = dst_root
5 changes: 5 additions & 0 deletions tools/analysis_tools/analyze_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mmdet.core.evaluation import eval_map
from mmdet.core.visualization import imshow_gt_det_bboxes
from mmdet.datasets import build_dataset, get_loading_pipeline
from mmdet.utils import update_data_root


def bbox_map_eval(det_result, annotation):
Expand Down Expand Up @@ -186,6 +187,10 @@ def main():
mmcv.check_file_exist(args.prediction_path)

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
cfg.data.test.test_mode = True
Expand Down
5 changes: 5 additions & 0 deletions tools/analysis_tools/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from mmdet.datasets import (build_dataloader, build_dataset,
replace_ImageToTensor)
from mmdet.models import build_detector
from mmdet.utils import update_data_root


def parse_args():
Expand Down Expand Up @@ -170,6 +171,10 @@ def main():
args = parse_args()

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

Expand Down
5 changes: 5 additions & 0 deletions tools/analysis_tools/confusion_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
from mmdet.datasets import build_dataset
from mmdet.utils import update_data_root


def parse_args():
Expand Down Expand Up @@ -230,6 +231,10 @@ def main():
args = parse_args()

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

Expand Down
5 changes: 5 additions & 0 deletions tools/analysis_tools/eval_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from mmcv import Config, DictAction

from mmdet.datasets import build_dataset
from mmdet.utils import update_data_root


def parse_args():
Expand Down Expand Up @@ -48,6 +49,10 @@ def main():
args = parse_args()

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

assert args.eval or args.format_only, (
'Please specify at least one operation (eval/format the results) with '
'the argument "--eval", "--format-only"')
Expand Down
5 changes: 4 additions & 1 deletion tools/analysis_tools/optimize_anchors.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from mmdet.core import bbox_cxcywh_to_xyxy, bbox_overlaps, bbox_xyxy_to_cxcywh
from mmdet.datasets import build_dataset
from mmdet.utils import get_root_logger
from mmdet.utils import get_root_logger, update_data_root


def parse_args():
Expand Down Expand Up @@ -325,6 +325,9 @@ def main():
cfg = args.config
cfg = Config.fromfile(cfg)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

input_shape = args.input_shape
assert len(input_shape) == 2

Expand Down
5 changes: 5 additions & 0 deletions tools/misc/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from mmdet.core.utils import mask2ndarray
from mmdet.core.visualization import imshow_det_bboxes
from mmdet.datasets.builder import build_dataset
from mmdet.utils import update_data_root


def parse_args():
Expand Down Expand Up @@ -55,6 +56,10 @@ def skip_pipeline_steps(config):
]

cfg = Config.fromfile(config_path)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if cfg_options is not None:
cfg.merge_from_dict(cfg_options)
train_data_cfg = cfg.data.train
Expand Down
6 changes: 6 additions & 0 deletions tools/misc/print_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from mmcv import Config, DictAction

from mmdet.utils import update_data_root


def parse_args():
parser = argparse.ArgumentParser(description='Print the whole config')
Expand Down Expand Up @@ -42,6 +44,10 @@ def main():
args = parse_args()

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)
print(f'Config:\n{cfg.pretty_text}')
Expand Down
6 changes: 5 additions & 1 deletion tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from mmdet.apis import multi_gpu_test, single_gpu_test
from mmdet.datasets import build_dataloader, build_dataset
from mmdet.models import build_detector
from mmdet.utils import setup_multi_processes
from mmdet.utils import setup_multi_processes, update_data_root


def parse_args():
Expand Down Expand Up @@ -131,6 +131,10 @@ def main():
raise ValueError('The output file must be a pkl file.')

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

Expand Down
7 changes: 6 additions & 1 deletion tools/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from mmdet.apis import init_random_seed, set_random_seed, train_detector
from mmdet.datasets import build_dataset
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger, setup_multi_processes
from mmdet.utils import (collect_env, get_root_logger, setup_multi_processes,
update_data_root)


def parse_args():
Expand Down Expand Up @@ -103,6 +104,10 @@ def main():
args = parse_args()

cfg = Config.fromfile(args.config)

# update data root according to MMDET_DATASETS
update_data_root(cfg)

if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

Expand Down

0 comments on commit 14b1b4e

Please sign in to comment.