Add DistSamplerSeedHook for when runner is EpochBasedRunner (open-mmlab#1449)

* Add DistSamplerSeedHook for when runner is EpochBasedRunner

* add comment
MeowZheng authored Apr 6, 2022
1 parent 5496168 commit 4bc2a30
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion mmseg/apis/train.py
@@ -7,7 +7,8 @@
 import torch
 import torch.distributed as dist
 from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
-from mmcv.runner import HOOKS, build_optimizer, build_runner, get_dist_info
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+                         build_optimizer, build_runner, get_dist_info)
 from mmcv.utils import build_from_cfg
 
 from mmseg import digit_version
@@ -128,6 +129,12 @@ def train_segmentor(model,
     runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                    cfg.checkpoint_config, cfg.log_config,
                                    cfg.get('momentum_config', None))
+    if distributed:
+        # When training by epoch in distributed mode, use `DistSamplerSeedHook`
+        # to set a different seed on the distributed sampler each epoch, so the
+        # dataset is re-shuffled every epoch and overfitting is mitigated.
+        if isinstance(runner, EpochBasedRunner):
+            runner.register_hook(DistSamplerSeedHook())
 
     # an ugly workaround to make the .log and .log.json filenames the same
     runner.timestamp = timestamp
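For background, the hook registered above works by reseeding the DataLoader's distributed sampler at the start of every epoch (in plain PyTorch terms, calling the sampler's set_epoch()). Below is a minimal standalone sketch of why that call matters; it uses plain PyTorch rather than the mmcv implementation, and passes num_replicas/rank explicitly only so it runs without initializing torch.distributed:

import torch
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

dataset = TensorDataset(torch.arange(16).float())
# Explicit num_replicas/rank so the sketch runs without dist.init_process_group.
sampler = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

for epoch in range(3):
    # Roughly what DistSamplerSeedHook does before each epoch: without
    # set_epoch(), the sampler reuses the same seed and yields the identical
    # per-rank sample order every epoch, defeating shuffle=True.
    sampler.set_epoch(epoch)
    for (batch,) in loader:
        pass  # training step would go here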
