diff --git a/mmpose/apis/train.py b/mmpose/apis/train.py index c3180e97ab..b539da7906 100644 --- a/mmpose/apis/train.py +++ b/mmpose/apis/train.py @@ -1,12 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings +import mmcv import numpy as np import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook, get_dist_info) +from mmcv.utils import digit_version from mmpose.core import DistEvalHook, EvalHook, build_optimizers from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper @@ -130,15 +132,15 @@ def train_model(model, broadcast_buffers=False, find_unused_parameters=find_unused_parameters) else: - if digit_version(mmcv.__version__) >= digit_version('1.4.4'): - model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) - elif torch.cuda.is_available(): - model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) + if digit_version(mmcv.__version__) >= digit_version( + '1.4.4') or torch.cuda.is_available(): + model = MMDataParallel( + model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids) else: - print('Recommand to use MMCV >= 1.4.4 for CPU training!') - print('Now we are using an earlier version for CPU training!') - model = model.cpu() - + warnings.warn( + 'We recommend to use MMCV >= 1.4.4 for CPU training. ' + 'See https://github.com/open-mmlab/mmpose/pull/1157 for ' + 'details.') # build runner optimizer = build_optimizers(model, cfg.optimizer)