diff --git a/mmpose/apis/train.py b/mmpose/apis/train.py index 9ed087d0ce..2fd17713b2 100644 --- a/mmpose/apis/train.py +++ b/mmpose/apis/train.py @@ -1,12 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings +import mmcv import numpy as np import torch import torch.distributed as dist from mmcv.parallel import MMDataParallel, MMDistributedDataParallel from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook, get_dist_info) +from mmcv.utils import digit_version from mmpose.core import DistEvalHook, EvalHook, build_optimizers from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper @@ -130,7 +132,14 @@ def train_model(model, broadcast_buffers=False, find_unused_parameters=find_unused_parameters) else: - model = MMDataParallel(model, device_ids=cfg.gpu_ids) + if digit_version(mmcv.__version__) >= digit_version( + '1.4.4') or torch.cuda.is_available(): + model = MMDataParallel(model, device_ids=cfg.gpu_ids) + else: + warnings.warn( + 'We recommend to use MMCV >= 1.4.4 for CPU training. ' + 'See https://github.com/open-mmlab/mmpose/pull/1157 for ' + 'details.') # build runner optimizer = build_optimizers(model, cfg.optimizer)