Skip to content

Commit

Permalink
Merge 80562bf into c94434b
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis authored Jul 24, 2023
2 parents c94434b + 80562bf commit 5bd7338
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mmpose/apis/inferencers/pose2d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,15 @@ def __call__(
else:
inputs = self._inputs_to_list(inputs)

# check the compatibility between inputs/outputs
if not self._video_input and len(inputs) > 0:
vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
if vis_out_dir is not None:
_, file_extension = os.path.splitext(vis_out_dir)
assert not file_extension, f'the argument `vis_out_dir` ' \
f'should be a folder while the input contains multiple ' \
f'images, but got {vis_out_dir}'

forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
inputs = self.preprocess(
inputs, batch_size=batch_size, **preprocess_kwargs)
Expand Down
9 changes: 9 additions & 0 deletions mmpose/apis/inferencers/pose3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,15 @@ def __call__(
else:
inputs = self._inputs_to_list(inputs)

# check the compatibility between inputs/outputs
if not self._video_input and len(inputs) > 0:
vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
if vis_out_dir is not None:
_, file_extension = os.path.splitext(vis_out_dir)
assert not file_extension, f'the argument `vis_out_dir` ' \
f'should be a folder while the input contains multiple ' \
f'images, but got {vis_out_dir}'

inputs = self.preprocess(
inputs, batch_size=batch_size, **preprocess_kwargs)

Expand Down
4 changes: 4 additions & 0 deletions tests/test_apis/test_inferencers/test_pose2d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,10 @@ def test_call(self):
self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'],
results3['predictions'][3][0]['keypoints'])

with self.assertRaises(AssertionError):
for res in inferencer(inputs, vis_out_dir=f'{tmp_dir}/1.jpg'):
pass

# `inputs` is path to a video
inputs = 'tests/data/posetrack18/videos/000001_mpiinew_test/' \
'000001_mpiinew_test.mp4'
Expand Down

0 comments on commit 5bd7338

Please sign in to comment.