diff --git a/mmpose/apis/inferencers/pose2d_inferencer.py b/mmpose/apis/inferencers/pose2d_inferencer.py index 3f1f20fdc0..99b079d529 100644 --- a/mmpose/apis/inferencers/pose2d_inferencer.py +++ b/mmpose/apis/inferencers/pose2d_inferencer.py @@ -307,6 +307,15 @@ def __call__( else: inputs = self._inputs_to_list(inputs) + # check the compatibility between inputs/outputs + if not self._video_input and len(inputs) > 0: + vis_out_dir = visualize_kwargs.get('vis_out_dir', None) + if vis_out_dir is not None: + _, file_extension = os.path.splitext(vis_out_dir) + assert not file_extension, f'the argument `vis_out_dir` ' \ + f'should be a folder while the input contains multiple ' \ + f'images, but got {vis_out_dir}' + forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1) inputs = self.preprocess( inputs, batch_size=batch_size, **preprocess_kwargs) diff --git a/mmpose/apis/inferencers/pose3d_inferencer.py b/mmpose/apis/inferencers/pose3d_inferencer.py index 0ab7d2e64e..472f43bee2 100644 --- a/mmpose/apis/inferencers/pose3d_inferencer.py +++ b/mmpose/apis/inferencers/pose3d_inferencer.py @@ -392,6 +392,15 @@ def __call__( else: inputs = self._inputs_to_list(inputs) + # check the compatibility between inputs/outputs + if not self._video_input and len(inputs) > 0: + vis_out_dir = visualize_kwargs.get('vis_out_dir', None) + if vis_out_dir is not None: + _, file_extension = os.path.splitext(vis_out_dir) + assert not file_extension, f'the argument `vis_out_dir` ' \ + f'should be a folder while the input contains multiple ' \ + f'images, but got {vis_out_dir}' + inputs = self.preprocess( inputs, batch_size=batch_size, **preprocess_kwargs) diff --git a/tests/test_apis/test_inferencers/test_pose2d_inferencer.py b/tests/test_apis/test_inferencers/test_pose2d_inferencer.py index b59232efac..be00527ff1 100644 --- a/tests/test_apis/test_inferencers/test_pose2d_inferencer.py +++ b/tests/test_apis/test_inferencers/test_pose2d_inferencer.py @@ -144,6 +144,10 @@ def test_call(self): self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'], results3['predictions'][3][0]['keypoints']) + with self.assertRaises(AssertionError): + for res in inferencer(inputs, vis_out_dir=f'{tmp_dir}/1.jpg'): + pass + # `inputs` is path to a video inputs = 'tests/data/posetrack18/videos/000001_mpiinew_test/' \ '000001_mpiinew_test.mp4'