Commit

Merge branch 'dev' into mask2former
chhluo authored Mar 19, 2022
2 parents 6f329b5 + 8603765 commit a873d1b
Showing 14 changed files with 174 additions and 27 deletions.
19 changes: 16 additions & 3 deletions docs/en/1_exist_data_model.md
@@ -584,10 +584,23 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
```
### Training on multiple nodes
### Train with multiple machines
MMDetection relies on the `torch.distributed` package for distributed training.
Thus, as a basic usage, you can launch distributed training via PyTorch's [launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility).
If you launch with multiple machines simply connected via Ethernet, you can run the following commands:
On the first machine:
```shell
NNODES=2 NODE_RANK=0 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR sh tools/dist_train.sh $CONFIG $GPUS
```
On the second machine:
```shell
NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR sh tools/dist_train.sh $CONFIG $GPUS
```
Training is usually slow if you do not have high-speed networking such as InfiniBand.
### Manage jobs with Slurm
20 changes: 17 additions & 3 deletions docs/zh_cn/1_exist_data_model.md
@@ -566,11 +566,25 @@ CUDA_VISIBLE_DEVICES=0,1,2,3 PORT=29500 ./tools/dist_train.sh ${CONFIG_FILE} 4
CUDA_VISIBLE_DEVICES=4,5,6,7 PORT=29501 ./tools/dist_train.sh ${CONFIG_FILE} 4
```

#### Training on multiple nodes
### Train with multiple machines

MMDetection relies on the `torch.distributed` package for distributed training. Thus, as a basic usage, distributed training can be launched via PyTorch's [launch utility](https://pytorch.org/docs/stable/distributed.html#launch-utility).
If you want to use multiple machines simply connected via Ethernet, you can run the following commands:

#### Manage jobs with Slurm
On the first machine:

```shell
NNODES=2 NODE_RANK=0 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR sh tools/dist_train.sh $CONFIG $GPUS
```

On the second machine:

```shell
NNODES=2 NODE_RANK=1 PORT=$MASTER_PORT MASTER_ADDR=$MASTER_ADDR sh tools/dist_train.sh $CONFIG $GPUS
```

However, if you do not connect these machines with high-speed networking, training will be very slow.

### Manage jobs with Slurm

Slurm is a common job scheduling system for compute clusters. On a cluster managed by Slurm, you can use `slurm.sh` to start training jobs. It supports both single-node and multi-node training.

4 changes: 2 additions & 2 deletions mmdet/core/utils/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .dist_utils import (DistOptimizerHook, all_reduce_dict, allreduce_grads,
reduce_mean)
reduce_mean, sync_random_seed)
from .misc import (center_of_mass, filter_scores_and_topk, flip_tensor,
generate_coordinate, mask2ndarray, multi_apply,
select_single_mlvl, unmap)
@@ -9,5 +9,5 @@
'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
'unmap', 'mask2ndarray', 'flip_tensor', 'all_reduce_dict',
'center_of_mass', 'generate_coordinate', 'select_single_mlvl',
'filter_scores_and_topk'
'filter_scores_and_topk', 'sync_random_seed'
]
40 changes: 40 additions & 0 deletions mmdet/core/utils/dist_utils.py
@@ -4,6 +4,7 @@
import warnings
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import OptimizerHook, get_dist_info
@@ -151,3 +152,42 @@ def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
if isinstance(py_dict, OrderedDict):
out_dict = OrderedDict(out_dict)
return out_dict


def sync_random_seed(seed=None, device='cuda'):
"""Make sure different ranks share the same seed.
All workers must call this function, otherwise it will deadlock.
This method is generally used in `DistributedSampler`,
because the seed should be identical across all processes
in the distributed group.
In distributed sampling, different ranks should sample non-overlapped
data in the dataset. Therefore, this function is used to make sure that
each rank shuffles the data indices in the same order based
on the same seed. Then different ranks could use different indices
to select non-overlapped data from the same data list.
Args:
seed (int, Optional): The seed. Default to None.
device (str): The device where the seed will be put on.
Default to 'cuda'.
Returns:
int: Seed to be used.
"""
if seed is None:
seed = np.random.randint(2**31)
assert isinstance(seed, int)

rank, world_size = get_dist_info()

if world_size == 1:
return seed

if rank == 0:
random_num = torch.tensor(seed, dtype=torch.int32, device=device)
else:
random_num = torch.tensor(0, dtype=torch.int32, device=device)
dist.broadcast(random_num, src=0)
return random_num.item()
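
As a quick illustration of the broadcast pattern `sync_random_seed` relies on, here is a minimal sketch that runs two CPU processes with the `gloo` backend; the two-process launch, the port and the backend are assumptions for the example, not part of this patch.

```python
# Minimal sketch (assumption: 2 CPU processes, "gloo" backend) of the broadcast
# pattern that sync_random_seed relies on: every rank ends up with the seed
# drawn on rank 0, so all ranks later shuffle indices in the same order.
import os

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank, world_size):
    os.environ.setdefault('MASTER_ADDR', '127.0.0.1')
    os.environ.setdefault('MASTER_PORT', '29600')
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Same logic as sync_random_seed above, but on CPU tensors for the sketch.
    if rank == 0:
        seed = torch.tensor(np.random.randint(2**31), dtype=torch.int32)
    else:
        seed = torch.tensor(0, dtype=torch.int32)
    dist.broadcast(seed, src=0)
    print(f'rank {rank} uses seed {seed.item()}')  # identical on all ranks

    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, args=(2,), nprocs=2)
```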
2 changes: 2 additions & 0 deletions mmdet/datasets/builder.py
@@ -6,6 +6,7 @@
from functools import partial

import numpy as np
import torch
from mmcv.parallel import collate
from mmcv.runner import get_dist_info
from mmcv.utils import TORCH_VERSION, Registry, build_from_cfg, digit_version
@@ -197,3 +198,4 @@ def worker_init_fn(worker_id, num_workers, rank, seed):
worker_seed = num_workers * rank + worker_id + seed
np.random.seed(worker_seed)
random.seed(worker_seed)
torch.manual_seed(worker_seed)
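
The extra `torch.manual_seed(worker_seed)` matters because each DataLoader worker is a separate process with its own torch RNG state. Below is a hedged sketch of how such a `worker_init_fn` is wired into a DataLoader; the toy dataset and the chosen numbers are assumptions for illustration, not the builder's actual code.

```python
# Hedged sketch of wiring a worker_init_fn into a DataLoader so that numpy,
# random and torch all get a per-worker, per-rank seed. The toy dataset and
# the chosen numbers (rank=0, seed=42) are assumptions for illustration.
import random
from functools import partial

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class ToyDataset(Dataset):

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        # Each worker now draws from its own deterministic random streams.
        return torch.rand(1), np.random.rand(), random.random()


def worker_init_fn(worker_id, num_workers, rank, seed):
    # Mirrors the patched helper: one distinct seed per (rank, worker) pair.
    worker_seed = num_workers * rank + worker_id + seed
    np.random.seed(worker_seed)
    random.seed(worker_seed)
    torch.manual_seed(worker_seed)


if __name__ == '__main__':
    num_workers, rank, seed = 2, 0, 42
    init_fn = partial(
        worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)
    loader = DataLoader(
        ToyDataset(), batch_size=4, num_workers=num_workers,
        worker_init_fn=init_fn)
    for batch in loader:
        print(batch)
```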
4 changes: 2 additions & 2 deletions mmdet/datasets/pipelines/instaboost.py
@@ -98,7 +98,7 @@ def _parse_anns(self, results, anns, img):

def __call__(self, results):
img = results['img']
orig_type = img.dtype
ori_type = img.dtype
anns = self._load_anns(results)
if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
try:
@@ -109,7 +109,7 @@ def __call__(self, results):
anns, img = instaboost.get_new_data(
anns, img.astype(np.uint8), self.cfg, background=None)

results = self._parse_anns(results, anns, img.astype(orig_type))
results = self._parse_anns(results, anns, img.astype(ori_type))
return results

def __repr__(self):
6 changes: 3 additions & 3 deletions mmdet/datasets/pipelines/transforms.py
@@ -520,9 +520,9 @@ def __call__(self, results):
random_shift_y = random.randint(-self.max_shift_px,
self.max_shift_px)
new_x = max(0, random_shift_x)
orig_x = max(0, -random_shift_x)
ori_x = max(0, -random_shift_x)
new_y = max(0, random_shift_y)
orig_y = max(0, -random_shift_y)
ori_y = max(0, -random_shift_y)

# TODO: support mask and semantic segmentation maps.
for key in results.get('bbox_fields', []):
@@ -558,7 +558,7 @@ def __call__(self, results):
new_h = img_h - np.abs(random_shift_y)
new_w = img_w - np.abs(random_shift_x)
new_img[new_y:new_y + new_h, new_x:new_x + new_w] \
= img[orig_y:orig_y + new_h, orig_x:orig_x + new_w]
= img[ori_y:ori_y + new_h, ori_x:ori_x + new_w]
results[key] = new_img

return results
16 changes: 14 additions & 2 deletions mmdet/datasets/samplers/distributed_sampler.py
@@ -4,6 +4,8 @@
import torch
from torch.utils.data import DistributedSampler as _DistributedSampler

from mmdet.core.utils import sync_random_seed


class DistributedSampler(_DistributedSampler):

@@ -15,13 +17,23 @@ def __init__(self,
seed=0):
super().__init__(
dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
# for the compatibility from PyTorch 1.3+
self.seed = seed if seed is not None else 0

# In distributed sampling, different ranks should sample
# non-overlapped data in the dataset. Therefore, this function
# is used to make sure that each rank shuffles the data indices
# in the same order based on the same seed. Then different ranks
# could use different indices to select non-overlapped data from the
# same data list.
self.seed = sync_random_seed(seed)

def __iter__(self):
# deterministically shuffle based on epoch
if self.shuffle:
g = torch.Generator()
# When :attr:`shuffle=True`, this ensures all replicas
# use a different random ordering for each epoch.
# Otherwise, the next iteration of this sampler will
# yield the same ordering.
g.manual_seed(self.epoch + self.seed)
indices = torch.randperm(len(self.dataset), generator=g).tolist()
else:
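
The `g.manual_seed(self.epoch + self.seed)` line is what makes the ordering reproducible yet different per epoch. A minimal sketch of that behaviour, using a plain `torch.Generator` outside any sampler (the dataset length and seed are arbitrary example values):

```python
# Minimal sketch (no sampler involved) of the epoch-dependent shuffle above:
# the same (epoch, seed) pair always yields the same permutation, while
# different epochs yield different ones.
import torch

seed, dataset_len = 42, 10
for epoch in range(2):
    g = torch.Generator()
    g.manual_seed(epoch + seed)
    indices = torch.randperm(dataset_len, generator=g).tolist()
    print(f'epoch {epoch}: {indices}')
# Re-running this script prints exactly the same two permutations.
```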
18 changes: 16 additions & 2 deletions mmdet/datasets/samplers/infinite_sampler.py
@@ -6,6 +6,8 @@
from mmcv.runner import get_dist_info
from torch.utils.data.sampler import Sampler

from mmdet.core.utils import sync_random_seed


class InfiniteGroupBatchSampler(Sampler):
"""Similar to `BatchSampler` warping a `GroupSampler. It is designed for
@@ -48,7 +50,13 @@ def __init__(self,
self.world_size = world_size
self.dataset = dataset
self.batch_size = batch_size
self.seed = seed if seed is not None else 0
# In distributed sampling, different ranks should sample
# non-overlapped data in the dataset. Therefore, this function
# is used to make sure that each rank shuffles the data indices
# in the same order based on the same seed. Then different ranks
# could use different indices to select non-overlapped data from the
# same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle

assert hasattr(self.dataset, 'flag')
@@ -133,7 +141,13 @@ def __init__(self,
self.world_size = world_size
self.dataset = dataset
self.batch_size = batch_size
self.seed = seed if seed is not None else 0
# In distributed sampling, different ranks should sample
# non-overlapped data in the dataset. Therefore, this function
# is used to make sure that each rank shuffles the data indices
# in the same order based on the same seed. Then different ranks
# could use different indices to select non-overlapped data from the
# same data list.
self.seed = sync_random_seed(seed)
self.shuffle = shuffle
self.size = len(dataset)
self.indices = self._indices_of_rank()
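
The truncated hunks above only show the seeding; as a rough sketch of how an infinite, rank-sliced index stream of this kind can be built, consider the generator below. It is an illustrative assumption, not the class's actual implementation.

```python
# Rough sketch, not the actual infinite-sampler code: an endless stream of
# shuffled indices sliced per rank, so rank i consumes every world_size-th
# index and ranks never pick overlapping data.
from itertools import count, islice

import torch


def infinite_indices(size, seed, shuffle=True):
    g = torch.Generator()
    g.manual_seed(seed)
    for _ in count():
        if shuffle:
            yield from torch.randperm(size, generator=g).tolist()
        else:
            yield from range(size)


def indices_of_rank(size, seed, rank, world_size):
    # Each rank starts at its own offset and strides by world_size.
    return islice(infinite_indices(size, seed), rank, None, world_size)


if __name__ == '__main__':
    # Two hypothetical ranks sharing the synced seed 123 over a 6-item dataset.
    for rank in range(2):
        stream = indices_of_rank(6, seed=123, rank=rank, world_size=2)
        print(f'rank {rank}:', list(islice(stream, 6)))
```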
29 changes: 26 additions & 3 deletions tools/analysis_tools/analyze_logs.py
100644 → 100755
@@ -17,6 +17,10 @@ def cal_train_time(log_dicts, args):
all_times.append(log_dict[epoch]['time'])
else:
all_times.append(log_dict[epoch]['time'][1:])
if not all_times:
raise KeyError(
'Please reduce the log interval in the config so that '
'interval is less than iterations of one epoch.')
all_times = np.array(all_times)
epoch_ave_time = all_times.mean(-1)
slowest_epoch = epoch_ave_time.argmax()
@@ -50,12 +54,21 @@ def plot_curve(log_dicts, args):
epochs = list(log_dict.keys())
for j, metric in enumerate(metrics):
print(f'plot curve of {args.json_logs[i]}, metric is {metric}')
if metric not in log_dict[epochs[0]]:
if metric not in log_dict[epochs[int(args.start_epoch) - 1]]:
if 'mAP' in metric:
raise KeyError(
f'{args.json_logs[i]} does not contain metric '
f'{metric}. Please check if "--no-validate" is '
'specified when you trained the model.')
raise KeyError(
f'{args.json_logs[i]} does not contain metric {metric}')
f'{args.json_logs[i]} does not contain metric {metric}. '
'Please reduce the log interval in the config so that '
'interval is less than iterations of one epoch.')

if 'mAP' in metric:
xs = np.arange(1, max(epochs) + 1)
xs = np.arange(
int(args.start_epoch),
max(epochs) + 1, int(args.eval_interval))
ys = []
for epoch in epochs:
ys += log_dict[epoch][metric]
@@ -104,6 +117,16 @@ def add_plot_parser(subparsers):
nargs='+',
default=['bbox_mAP'],
help='the metric that you want to plot')
parser_plt.add_argument(
'--start-epoch',
type=str,
default='1',
help='the epoch that you want to start')
parser_plt.add_argument(
'--eval-interval',
type=str,
default='1',
help='the eval interval when training')
parser_plt.add_argument('--title', type=str, help='title of figure')
parser_plt.add_argument(
'--legend',
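
With the new flags, the x-axis for mAP curves is built from `--start-epoch` and `--eval-interval` instead of assuming evaluation every epoch. A small sketch of the resulting alignment; the epoch numbers below are made up for illustration.

```python
# Small sketch of the new x-axis logic for mAP curves: if evaluation ran every
# 2 epochs starting at epoch 2, xs must be built with the same interval so it
# lines up with the logged mAP values. The numbers are illustrative only.
import numpy as np

start_epoch, eval_interval, max_epoch = 2, 2, 12
xs = np.arange(start_epoch, max_epoch + 1, eval_interval)
print(xs)  # [ 2  4  6  8 10 12] -- one tick per logged evaluation
```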
6 changes: 3 additions & 3 deletions tools/analysis_tools/get_flops.py
@@ -52,7 +52,7 @@ def main():
h, w = args.shape
else:
raise ValueError('invalid input shape')
orig_shape = (3, h, w)
ori_shape = (3, h, w)
divisor = args.size_divisor
if divisor > 0:
h = int(np.ceil(h / divisor)) * divisor
@@ -83,9 +83,9 @@ def main():
split_line = '=' * 30

if divisor > 0 and \
input_shape != orig_shape:
input_shape != ori_shape:
print(f'{split_line}\nUse size divisor set input shape '
f'from {orig_shape} to {input_shape}\n')
f'from {ori_shape} to {input_shape}\n')
print(f'{split_line}\nInput shape: {input_shape}\n'
f'Flops: {flops}\nParams: {params}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
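
The size-divisor branch rounds each spatial dimension up to the nearest multiple of `--size-divisor` before counting FLOPs. A quick worked example of that rounding; the shape and divisor are arbitrary example values.

```python
# Worked example of the size-divisor rounding in get_flops.py: each spatial
# dimension is rounded up to the nearest multiple of the divisor. The input
# shape (800, 1333) and divisor 32 are arbitrary example values.
import numpy as np

h, w, divisor = 800, 1333, 32
ori_shape = (3, h, w)
h = int(np.ceil(h / divisor)) * divisor  # 800 is already a multiple of 32
w = int(np.ceil(w / divisor)) * divisor  # 1333 -> 1344
print(ori_shape, '->', (3, h, w))  # (3, 800, 1333) -> (3, 800, 1344)
```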
16 changes: 14 additions & 2 deletions tools/dist_test.sh
@@ -3,8 +3,20 @@
CONFIG=$1
CHECKPOINT=$2
GPUS=$3
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
$(dirname "$0")/test.py $CONFIG $CHECKPOINT --launcher pytorch ${@:4}
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
$(dirname "$0")/test.py \
$CONFIG \
$CHECKPOINT \
--launcher pytorch \
${@:4}
15 changes: 13 additions & 2 deletions tools/dist_train.sh
@@ -2,8 +2,19 @@

CONFIG=$1
GPUS=$2
NNODES=${NNODES:-1}
NODE_RANK=${NODE_RANK:-0}
PORT=${PORT:-29500}
MASTER_ADDR=${MASTER_ADDR:-"127.0.0.1"}

PYTHONPATH="$(dirname $0)/..":$PYTHONPATH \
python -m torch.distributed.launch --nproc_per_node=$GPUS --master_port=$PORT \
$(dirname "$0")/train.py $CONFIG --launcher pytorch ${@:3}
python -m torch.distributed.launch \
--nnodes=$NNODES \
--node_rank=$NODE_RANK \
--master_addr=$MASTER_ADDR \
--nproc_per_node=$GPUS \
--master_port=$PORT \
$(dirname "$0")/train.py \
$CONFIG \
--seed 0 \
--launcher pytorch ${@:3}
6 changes: 6 additions & 0 deletions tools/train.py
@@ -8,6 +8,7 @@

import mmcv
import torch
import torch.distributed as dist
from mmcv import Config, DictAction
from mmcv.runner import get_dist_info, init_dist
from mmcv.utils import get_git_hash
@@ -52,6 +53,10 @@ def parse_args():
help='id of gpu to use '
'(only applicable to non-distributed training)')
parser.add_argument('--seed', type=int, default=None, help='random seed')
parser.add_argument(
'--diff-seed',
action='store_true',
help='Whether or not set different seeds for different ranks')
parser.add_argument(
'--deterministic',
action='store_true',
@@ -169,6 +174,7 @@ def main():

# set random seeds
seed = init_random_seed(args.seed)
seed = seed + dist.get_rank() if args.diff_seed else seed
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
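
The `--diff-seed` flag keeps the single synced base seed but shifts it by the process rank, so augmentation randomness differs across GPUs while staying reproducible. A minimal sketch of that offset, assuming a base seed of 2022 and four ranks; both numbers are made up for the example.

```python
# Minimal sketch of the --diff-seed behaviour: starting from one base seed
# shared by all ranks (as returned by init_random_seed), each rank optionally
# offsets it by its own rank id. Base seed 2022 and world size 4 are
# assumptions for the example.
base_seed, world_size = 2022, 4

for diff_seed in (False, True):
    seeds = [base_seed + rank if diff_seed else base_seed
             for rank in range(world_size)]
    print(f'diff_seed={diff_seed}: {seeds}')
# diff_seed=False: [2022, 2022, 2022, 2022]
# diff_seed=True:  [2022, 2023, 2024, 2025]
```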
