Skip to content

Commit

Permalink
[Fix] add compatibility for argument return_datasample (#2708)
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Sep 20, 2023
1 parent fafcbec commit aa8ab14
Show file tree
Hide file tree
Showing 4 changed files with 47 additions and 16 deletions.
32 changes: 24 additions & 8 deletions mmpose/apis/inferencers/base_mmpose_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import logging
import mimetypes
import os
import warnings
from collections import defaultdict
from typing import (Callable, Dict, Generator, Iterable, List, Optional,
Sequence, Union)
Expand Down Expand Up @@ -44,7 +43,7 @@ class BaseMMPoseInferencer(BaseInferencer):
'return_vis', 'show', 'wait_time', 'draw_bbox', 'radius', 'thickness',
'kpt_thr', 'vis_out_dir', 'black_background'
}
postprocess_kwargs: set = {'pred_out_dir'}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

def _load_weights_to_model(self, model: nn.Module,
checkpoint: Optional[dict],
Expand All @@ -67,15 +66,20 @@ def _load_weights_to_model(self, model: nn.Module,
# mmpose 1.x
model.dataset_meta = checkpoint_meta['dataset_meta']
else:
warnings.warn(
print_log(
'dataset_meta are not saved in the checkpoint\'s '
'meta data, load via config.')
'meta data, load via config.',
logger='current',
level=logging.WARNING)
model.dataset_meta = dataset_meta_from_config(
cfg, dataset_mode='train')
else:
warnings.warn('Checkpoint is not loaded, and the inference '
'result is calculated by the randomly initialized '
'model!')
print_log(
'Checkpoint is not loaded, and the inference '
'result is calculated by the randomly initialized '
'model!',
logger='current',
level=logging.WARNING)
model.dataset_meta = dataset_meta_from_config(
cfg, dataset_mode='train')

Expand Down Expand Up @@ -178,7 +182,10 @@ def _get_webcam_inputs(self, inputs: str) -> Generator:
# Attempt to open the video capture object.
vcap = cv2.VideoCapture(camera_id)
if not vcap.isOpened():
warnings.warn(f'Cannot open camera (ID={camera_id})')
print_log(
f'Cannot open camera (ID={camera_id})',
logger='current',
level=logging.WARNING)
return []

# Set video input flag and metadata.
Expand Down Expand Up @@ -384,6 +391,7 @@ def postprocess(
self,
preds: List[PoseDataSample],
visualization: List[np.ndarray],
return_datasample=None,
return_datasamples=False,
pred_out_dir: str = '',
) -> dict:
Expand Down Expand Up @@ -416,6 +424,14 @@ def postprocess(
json-serializable dict containing only basic data elements such
as strings and numbers.
"""
if return_datasample is not None:
print_log(
'The `return_datasample` argument is deprecated '
'and will be removed in future versions. Please '
'use `return_datasamples`.',
logger='current',
level=logging.WARNING)
return_datasamples = return_datasample

result_dict = defaultdict(list)

Expand Down
2 changes: 1 addition & 1 deletion mmpose/apis/inferencers/mmpose_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ class MMPoseInferencer(BaseMMPoseInferencer):
'kpt_thr', 'vis_out_dir', 'skeleton_style', 'draw_heatmap',
'black_background', 'num_instances'
}
postprocess_kwargs: set = {'pred_out_dir'}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

def __init__(self,
pose2d: Optional[str] = None,
Expand Down
27 changes: 21 additions & 6 deletions mmpose/apis/inferencers/pose2d_inferencer.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import warnings
from typing import Dict, List, Optional, Sequence, Tuple, Union

import mmcv
import numpy as np
import torch
from mmengine.config import Config, ConfigDict
from mmengine.infer.infer import ModelType
from mmengine.logging import print_log
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope
from mmengine.structures import InstanceData
Expand Down Expand Up @@ -77,7 +78,7 @@ class Pose2DInferencer(BaseMMPoseInferencer):
'draw_heatmap',
'black_background',
}
postprocess_kwargs: set = {'pred_out_dir'}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

def __init__(self,
model: Union[ModelType, str],
Expand Down Expand Up @@ -183,8 +184,19 @@ def preprocess_single(self,

if self.cfg.data_mode == 'topdown':
if self.detector is not None:
det_results = self.detector(
input, return_datasamples=True)['predictions']
try:
det_results = self.detector(
input, return_datasamples=True)['predictions']
except ValueError:
print_log(
'Support for mmpose and mmdet versions up to 3.1.0 '
'will be discontinued in upcoming releases. To '
'ensure ongoing compatibility, please upgrade to '
'mmdet version 3.2.0 or later.',
logger='current',
level=logging.WARNING)
det_results = self.detector(
input, return_datasample=True)['predictions']
pred_instance = det_results[0].pred_instances.cpu().numpy()
bboxes = np.concatenate(
(pred_instance.bboxes, pred_instance.scores[:, None]),
Expand Down Expand Up @@ -301,8 +313,11 @@ def __call__(
inputs = self._get_webcam_inputs(inputs)
batch_size = 1
if not visualize_kwargs.get('show', False):
warnings.warn('The display mode is closed when using webcam '
'input. It will be turned on automatically.')
print_log(
'The display mode is closed when using webcam '
'input. It will be turned on automatically.',
logger='current',
level=logging.WARNING)
visualize_kwargs['show'] = True
else:
inputs = self._inputs_to_list(inputs)
Expand Down
2 changes: 1 addition & 1 deletion mmpose/apis/inferencers/pose3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class Pose3DInferencer(BaseMMPoseInferencer):
'kpt_thr',
'vis_out_dir',
}
postprocess_kwargs: set = {'pred_out_dir'}
postprocess_kwargs: set = {'pred_out_dir', 'return_datasample'}

def __init__(self,
model: Union[ModelType, str],
Expand Down

0 comments on commit aa8ab14

Please sign in to comment.