Skip to content

Commit

Permalink
add a check for input/output type
Browse files Browse the repository at this point in the history
  • Loading branch information
Ben-Louis committed Jul 24, 2023
1 parent c94434b commit 19c397e
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 0 deletions.
9 changes: 9 additions & 0 deletions mmpose/apis/inferencers/pose2d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,15 @@ def __call__(
else:
inputs = self._inputs_to_list(inputs)

# check the compatibility between inputs/outputs
if not self._video_input and len(inputs) > 0:
vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
if vis_out_dir is not None:
_, file_extension = os.path.splitext(vis_out_dir)
assert not file_extension, f'the argument `vis_out_dir` ' \
f'should be a folder while the input contains multiple ' \
f'images, but got {vis_out_dir}'

forward_kwargs['bbox_thr'] = preprocess_kwargs.get('bbox_thr', -1)
inputs = self.preprocess(
inputs, batch_size=batch_size, **preprocess_kwargs)
Expand Down
9 changes: 9 additions & 0 deletions mmpose/apis/inferencers/pose3d_inferencer.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,15 @@ def __call__(
else:
inputs = self._inputs_to_list(inputs)

# check the compatibility between inputs/outputs
if not self._video_input and len(inputs) > 0:
vis_out_dir = visualize_kwargs.get('vis_out_dir', None)
if vis_out_dir is not None:
_, file_extension = os.path.splitext(vis_out_dir)
assert not file_extension, f'the argument `vis_out_dir` ' \
f'should be a folder while the input contains multiple ' \
f'images, but got {vis_out_dir}'

inputs = self.preprocess(
inputs, batch_size=batch_size, **preprocess_kwargs)

Expand Down

0 comments on commit 19c397e

Please sign in to comment.