Skip to content

Commit

Permalink
[Enhancement] api train support CPU training for mmcv<1.4.4 (#1161)
Browse files Browse the repository at this point in the history
co-authored-by: ly015 <[email protected]>
  • Loading branch information
EasonQYS and ly015 committed Mar 2, 2022
1 parent 23d671e commit 33434b4
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion mmpose/apis/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
get_dist_info)
from mmcv.utils import digit_version

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
Expand Down Expand Up @@ -130,7 +132,14 @@ def train_model(model,
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
if digit_version(mmcv.__version__) >= digit_version(
'1.4.4') or torch.cuda.is_available():
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
else:
warnings.warn(
'We recommend to use MMCV >= 1.4.4 for CPU training. '
'See https://github.com/open-mmlab/mmpose/pull/1157 for '
'details.')

# build runner
optimizer = build_optimizers(model, cfg.optimizer)
Expand Down

0 comments on commit 33434b4

Please sign in to comment.