From 9079b9135925d35984fac5c9f00a70ba8f30bd80 Mon Sep 17 00:00:00 2001 From: Kai Chen Date: Sat, 21 Dec 2019 22:15:26 +0800 Subject: [PATCH] use mmcv.init_dist --- .isort.cfg | 2 +- .pre-commit-config.yaml | 4 +-- mmdet/apis/__init__.py | 9 +++--- mmdet/apis/env.py | 69 ---------------------------------------- mmdet/apis/train.py | 27 ++++++++++++++-- requirements.txt | 10 +++--- tools/test.py | 3 +- tools/test_robustness.py | 4 +-- tools/train.py | 4 +-- 9 files changed, 41 insertions(+), 91 deletions(-) delete mode 100644 mmdet/apis/env.py diff --git a/.isort.cfg b/.isort.cfg index 2186a18b54d..e790e3ee0ce 100644 --- a/.isort.cfg +++ b/.isort.cfg @@ -3,6 +3,6 @@ line_length = 79 multi_line_output = 0 known_standard_library = setuptools known_first_party = mmdet -known_third_party = Cython,albumentations,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision +known_third_party = Cython,albumentations,asynctest,cv2,imagecorruptions,matplotlib,mmcv,numpy,pycocotools,robustness_eval,roi_align,roi_pool,seaborn,six,terminaltables,torch,torchvision no_lines_before = STDLIB,LOCALFOLDER default_section = THIRDPARTY diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2fae06cd174..901104c2cc1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,11 +8,11 @@ repos: hooks: - id: isort - repo: https://github.com/pre-commit/mirrors-yapf - rev: 80b9cd2f0f3b1f3456a77eff3ddbaf08f18c08ae + rev: v0.29.0 hooks: - id: yapf - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v2.3.0 + rev: v2.4.0 hooks: - id: flake8 - id: trailing-whitespace diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py index 4cdf847b25d..914307a7588 100644 --- a/mmdet/apis/__init__.py +++ b/mmdet/apis/__init__.py @@ -1,10 +1,9 @@ -from .env import get_root_logger, init_dist, set_random_seed from .inference import (async_inference_detector, inference_detector, init_detector, show_result, show_result_pyplot) -from .train import train_detector +from .train import get_root_logger, set_random_seed, train_detector __all__ = [ - 'async_inference_detector', 'init_dist', 'get_root_logger', - 'set_random_seed', 'train_detector', 'init_detector', 'inference_detector', - 'show_result', 'show_result_pyplot' + 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector', + 'async_inference_detector', 'inference_detector', 'show_result', + 'show_result_pyplot' ] diff --git a/mmdet/apis/env.py b/mmdet/apis/env.py deleted file mode 100644 index 19b0f86db13..00000000000 --- a/mmdet/apis/env.py +++ /dev/null @@ -1,69 +0,0 @@ -import logging -import os -import random -import subprocess - -import numpy as np -import torch -import torch.distributed as dist -import torch.multiprocessing as mp -from mmcv.runner import get_dist_info - - -def init_dist(launcher, backend='nccl', **kwargs): - if mp.get_start_method(allow_none=True) is None: - mp.set_start_method('spawn') - if launcher == 'pytorch': - _init_dist_pytorch(backend, **kwargs) - elif launcher == 'mpi': - _init_dist_mpi(backend, **kwargs) - elif launcher == 'slurm': - _init_dist_slurm(backend, **kwargs) - else: - raise ValueError('Invalid launcher type: {}'.format(launcher)) - - -def _init_dist_pytorch(backend, **kwargs): - # TODO: use local_rank instead of rank % num_gpus - rank = int(os.environ['RANK']) - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(rank % num_gpus) - dist.init_process_group(backend=backend, **kwargs) - - -def _init_dist_mpi(backend, **kwargs): - raise NotImplementedError - - -def _init_dist_slurm(backend, port=29500, **kwargs): - proc_id = int(os.environ['SLURM_PROCID']) - ntasks = int(os.environ['SLURM_NTASKS']) - node_list = os.environ['SLURM_NODELIST'] - num_gpus = torch.cuda.device_count() - torch.cuda.set_device(proc_id % num_gpus) - addr = subprocess.getoutput( - 'scontrol show hostname {} | head -n1'.format(node_list)) - os.environ['MASTER_PORT'] = str(port) - os.environ['MASTER_ADDR'] = addr - os.environ['WORLD_SIZE'] = str(ntasks) - os.environ['RANK'] = str(proc_id) - dist.init_process_group(backend=backend) - - -def set_random_seed(seed): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - - -def get_root_logger(log_level=logging.INFO): - logger = logging.getLogger() - if not logger.hasHandlers(): - logging.basicConfig( - format='%(asctime)s - %(levelname)s - %(message)s', - level=log_level) - rank, _ = get_dist_info() - if rank != 0: - logger.setLevel('ERROR') - return logger diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py index cdac16d9989..47320c69eaa 100644 --- a/mmdet/apis/train.py +++ b/mmdet/apis/train.py @@ -1,18 +1,39 @@ -from __future__ import division +import logging +import random import re from collections import OrderedDict +import numpy as np import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import DistSamplerSeedHook, Runner, obj_from_dict +from mmcv.runner import (DistSamplerSeedHook, Runner, get_dist_info, + obj_from_dict) from mmdet import datasets from mmdet.core import (CocoDistEvalmAPHook, CocoDistEvalRecallHook, DistEvalmAPHook, DistOptimizerHook, Fp16OptimizerHook) from mmdet.datasets import DATASETS, build_dataloader from mmdet.models import RPN -from .env import get_root_logger + + +def set_random_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_root_logger(log_level=logging.INFO): + logger = logging.getLogger() + if not logger.hasHandlers(): + logging.basicConfig( + format='%(asctime)s - %(levelname)s - %(message)s', + level=log_level) + rank, _ = get_dist_info() + if rank != 0: + logger.setLevel('ERROR') + return logger def parse_losses(losses): diff --git a/requirements.txt b/requirements.txt index 8a68f410d48..5cacde1c99f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,10 @@ -mmcv>=0.2.10 -numpy +albumentations>=0.3.2 +imagecorruptions matplotlib +mmcv>=0.2.15 +numpy +pycocotools six terminaltables -pycocotools torch>=1.1 torchvision -imagecorruptions -albumentations>=0.3.2 \ No newline at end of file diff --git a/tools/test.py b/tools/test.py index 64dd73377a2..b39cf13abde 100644 --- a/tools/test.py +++ b/tools/test.py @@ -9,9 +9,8 @@ import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import get_dist_info, load_checkpoint +from mmcv.runner import get_dist_info, init_dist, load_checkpoint -from mmdet.apis import init_dist from mmdet.core import coco_eval, results2json, wrap_fp16_model from mmdet.datasets import build_dataloader, build_dataset from mmdet.models import build_detector diff --git a/tools/test_robustness.py b/tools/test_robustness.py index c0489f3ebaa..fb58deb952f 100644 --- a/tools/test_robustness.py +++ b/tools/test_robustness.py @@ -10,13 +10,13 @@ import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel -from mmcv.runner import get_dist_info, load_checkpoint +from mmcv.runner import get_dist_info, init_dist, load_checkpoint from pycocotools.coco import COCO from pycocotools.cocoeval import COCOeval from robustness_eval import get_results from mmdet import datasets -from mmdet.apis import init_dist, set_random_seed +from mmdet.apis import set_random_seed from mmdet.core import (eval_map, fast_eval_recall, results2json, wrap_fp16_model) from mmdet.datasets import build_dataloader, build_dataset diff --git a/tools/train.py b/tools/train.py index c939343552d..e3bbbde6a0b 100644 --- a/tools/train.py +++ b/tools/train.py @@ -4,10 +4,10 @@ import torch from mmcv import Config +from mmcv.runner import init_dist from mmdet import __version__ -from mmdet.apis import (get_root_logger, init_dist, set_random_seed, - train_detector) +from mmdet.apis import get_root_logger, set_random_seed, train_detector from mmdet.datasets import build_dataset from mmdet.models import build_detector