diff --git a/mmedit/apis/train.py b/mmedit/apis/train.py
index d577e6357d..d00ddd8bf0 100644
--- a/mmedit/apis/train.py
+++ b/mmedit/apis/train.py
@@ -8,6 +8,7 @@
 import torch
 from mmcv.parallel import MMDataParallel
 from mmcv.runner import HOOKS, IterBasedRunner
+from mmcv.utils import build_from_cfg
 
 from mmedit.core import DistEvalIterHook, EvalIterHook, build_optimizers
 from mmedit.core.distributed_wrapper import DistributedDataParallelWrapper
@@ -182,6 +183,20 @@ def _dist_train(model,
             DistEvalIterHook(
                 data_loader, save_path=save_path, **cfg.evaluation))
 
+    # user-defined hooks
+    if cfg.get('custom_hooks', None):
+        custom_hooks = cfg.custom_hooks
+        assert isinstance(custom_hooks, list), \
+            f'custom_hooks expect list type, but got {type(custom_hooks)}'
+        for hook_cfg in cfg.custom_hooks:
+            assert isinstance(hook_cfg, dict), \
+                'Each item in custom_hooks expects dict type, but got ' \
+                f'{type(hook_cfg)}'
+            hook_cfg = hook_cfg.copy()
+            priority = hook_cfg.pop('priority', 'NORMAL')
+            hook = build_from_cfg(hook_cfg, HOOKS)
+            runner.register_hook(hook, priority=priority)
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
@@ -291,6 +306,20 @@ def _non_dist_train(model,
         runner.register_hook(
             EvalIterHook(data_loader, save_path=save_path, **cfg.evaluation))
 
+    # user-defined hooks
+    if cfg.get('custom_hooks', None):
+        custom_hooks = cfg.custom_hooks
+        assert isinstance(custom_hooks, list), \
+            f'custom_hooks expect list type, but got {type(custom_hooks)}'
+        for hook_cfg in cfg.custom_hooks:
+            assert isinstance(hook_cfg, dict), \
+                'Each item in custom_hooks expects dict type, but got ' \
+                f'{type(hook_cfg)}'
+            hook_cfg = hook_cfg.copy()
+            priority = hook_cfg.pop('priority', 'NORMAL')
+            hook = build_from_cfg(hook_cfg, HOOKS)
+            runner.register_hook(hook, priority=priority)
+
     if cfg.resume_from:
         runner.resume(cfg.resume_from)
     elif cfg.load_from:
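Note: with this change, `custom_hooks` can be declared in a config as a list of hook dicts; each dict's optional `priority` key is popped (default `'NORMAL'`) and the remainder is built through `build_from_cfg(hook_cfg, HOOKS)` and registered on the runner. A minimal illustrative config sketch follows; the hook name is a hypothetical placeholder and is assumed to already be registered in mmcv's HOOKS registry (e.g. via `@HOOKS.register_module()`), not something added by this diff.

# Illustrative config sketch. 'MyCustomHook' is a hypothetical, user-registered hook.
custom_hooks = [
    dict(type='MyCustomHook', priority='HIGHEST'),  # 'priority' is popped before build_from_cfg
    dict(type='MyCustomHook'),                      # omitted priority defaults to 'NORMAL'
]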