Skip to content

Commit

Permalink
[Enhance] Continue to speed up training. (#6974)
Browse files Browse the repository at this point in the history
* [Enhance] Speed up training time.

* set in cfg
  • Loading branch information
RangiLyu authored Jan 17, 2022
1 parent 67f249e commit 549a556
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 2 deletions.
5 changes: 5 additions & 0 deletions configs/_base_/default_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,8 @@
load_from = None
resume_from = None
workflow = [('train', 1)]

# disable OpenCV multithreading to avoid overloading the system
opencv_num_threads = 0
# set the multi-process start method to `fork` to speed up training
mp_start_method = 'fork'
40 changes: 38 additions & 2 deletions tools/train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import copy
import multiprocessing as mp
import os
import os.path as osp
import platform
import time
import warnings

Expand All @@ -19,8 +21,6 @@
from mmdet.models import build_detector
from mmdet.utils import collect_env, get_root_logger

cv2.setNumThreads(0)


def parse_args():
parser = argparse.ArgumentParser(description='Train a detector')
Expand Down Expand Up @@ -91,12 +91,48 @@ def parse_args():
return args


def setup_multi_processes(cfg):
    """Configure process- and thread-level parallelism before training.

    Three settings are applied, in this order:

    1. The ``multiprocessing`` start method (skipped on Windows), read from
       ``cfg.get('mp_start_method')`` and defaulting to ``'fork'``.
    2. The OpenCV thread count, read from ``cfg.get('opencv_num_threads')``
       and defaulting to ``0``.
    3. ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS``: each is set to ``1`` when
       it is not already present in the environment and
       ``cfg.data.workers_per_gpu`` is greater than 1. A value supplied by
       the user through the environment is never overwritten.

    Args:
        cfg: Config object exposing ``get(key, default)`` and
            ``data.workers_per_gpu``.
    """
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != 'Windows':
        mp_start_method = cfg.get('mp_start_method', 'fork')
        # A plain `set_start_method` raises RuntimeError when the start
        # method has already been fixed (e.g. by an imported library or an
        # earlier call), so inspect the current value and force the change.
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f'Multi-processing start method `{mp_start_method}` is '
                f'different from the previous setting `{current_method}`. '
                f'It will be force set to `{mp_start_method}`. You can '
                f'change this behavior by changing `mp_start_method` in '
                f'your config.')
        mp.set_start_method(mp_start_method, force=True)

    # disable opencv multithreading to avoid system being overloaded
    opencv_num_threads = cfg.get('opencv_num_threads', 0)
    cv2.setNumThreads(opencv_num_threads)

    # setup OMP and MKL threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
    num_threads = 1
    for env_name in ('OMP_NUM_THREADS', 'MKL_NUM_THREADS'):
        # Keep the environment check first so that `cfg.data` is only
        # touched when the variable actually needs a default.
        if env_name not in os.environ and cfg.data.workers_per_gpu > 1:
            warnings.warn(
                f'Setting {env_name} environment variable for each process '
                f'to be {num_threads} in default, to avoid your system being '
                f'overloaded, please further tune the variable for optimal '
                f'performance in your application as needed.')
            os.environ[env_name] = str(num_threads)


def main():
args = parse_args()

cfg = Config.fromfile(args.config)
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

# set multi-process settings
setup_multi_processes(cfg)

# set cudnn_benchmark
if cfg.get('cudnn_benchmark', False):
torch.backends.cudnn.benchmark = True
Expand Down

0 comments on commit 549a556

Please sign in to comment.