Skip to content

Commit

Permalink
fix import and replace print with warnings.warn
Browse files Browse the repository at this point in the history
  • Loading branch information
ly015 committed Mar 1, 2022
1 parent fe3b2e3 commit cd89d70
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions mmpose/apis/train.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings

import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner, OptimizerHook,
get_dist_info)
from mmcv.utils import digit_version

from mmpose.core import DistEvalHook, EvalHook, build_optimizers
from mmpose.core.distributed_wrapper import DistributedDataParallelWrapper
Expand Down Expand Up @@ -130,15 +132,15 @@ def train_model(model,
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
if digit_version(mmcv.__version__) >= digit_version('1.4.4'):
model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
elif torch.cuda.is_available():
model = MMDataParallel(model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
if digit_version(mmcv.__version__) >= digit_version(
'1.4.4') or torch.cuda.is_available():
model = MMDataParallel(
model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
else:
print('Recommand to use MMCV >= 1.4.4 for CPU training!')
print('Now we are using an earlier version for CPU training!')
model = model.cpu()

warnings.warn(
'We recommend to use MMCV >= 1.4.4 for CPU training. '
'See https://github.com/open-mmlab/mmpose/pull/1157 for '
'details.')

# build runner
optimizer = build_optimizers(model, cfg.optimizer)
Expand Down

0 comments on commit cd89d70

Please sign in to comment.