Skip to content

Commit

Permalink
[Enhancement] api train support cpu training for mmcv<1.4.4 (#1161)
Browse files Browse the repository at this point in the history
* [Enhance] api train support cpu training

[Enhance] api train support cpu training

* update for cpu training and add a note for mmcv<1.4.4

* support for mmcv < 1.4.4 for cpu training

* support cpu training by cuda.is_available

* fix import and replace print with warnings.warn

* fix bug

Co-authored-by: ly015 <[email protected]>
  • Loading branch information
EasonQYS and ly015 authored Mar 2, 2022
1 parent 23d671e commit b618970
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion mmpose/apis/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
get_dist_info)
from mmcv.utils import digit_version

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
Expand Down Expand Up @@ -130,7 +132,14 @@ def train_model(model,
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
if digit_version(mmcv.__version__) >= digit_version(
'1.4.4') or torch.cuda.is_available():
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
else:
warnings.warn(
'We recommend to use MMCV >= 1.4.4 for CPU training. '
'See https://github.com/open-mmlab/mmpose/pull/1157 for '
'details.')

# build runner
optimizer = build_optimizers(model, cfg.optimizer)
Expand Down

0 comments on commit b618970

Please sign in to comment.