Skip to content

Commit

Permalink
add progress bar in bottom-up tracking demo (open-mmlab#1454)
Browse files Browse the repository at this point in the history
  • Loading branch information
ly015 authored Jun 30, 2022
1 parent 5330ae4 commit f938505
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 19 deletions.
30 changes: 12 additions & 18 deletions demo/bottom_up_pose_tracking_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from argparse import ArgumentParser

import cv2
import mmcv

from mmpose.apis import (get_track_id, inference_bottom_up_pose_model,
init_pose_model, vis_pose_tracking_result)
Expand Down Expand Up @@ -86,10 +87,8 @@ def main():
else:
dataset_info = DatasetInfo(dataset_info)

cap = cv2.VideoCapture(args.video_path)
fps = None

assert cap.isOpened(), f'Faild to load video file {args.video_path}'
video = mmcv.VideoReader(args.video_path)
assert video.opened, f'Faild to load video file {args.video_path}'

if args.out_video_root == '':
save_out_video = False
Expand All @@ -98,9 +97,8 @@ def main():
save_out_video = True

if save_out_video:
fps = cap.get(cv2.CAP_PROP_FPS)
size = (int(cap.get(cv2.CAP_PROP_FRAME_WIDTH)),
int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT)))
fps = video.fps
size = (video.width, video.height)
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
videoWriter = cv2.VideoWriter(
os.path.join(args.out_video_root,
Expand Down Expand Up @@ -128,15 +126,12 @@ def main():
output_layer_names = None
next_id = 0
pose_results = []
while (cap.isOpened()):
flag, img = cap.read()
if not flag:
break
for cur_frame in mmcv.track_iter_progress(video):
pose_results_last = pose_results

pose_results, returned_outputs = inference_bottom_up_pose_model(
pose_results, _ = inference_bottom_up_pose_model(
pose_model,
img,
cur_frame,
dataset=dataset,
dataset_info=dataset_info,
pose_nms_thr=args.pose_nms_thr,
Expand All @@ -158,9 +153,9 @@ def main():
pose_results = smoother.smooth(pose_results)

# show the results
vis_img = vis_pose_tracking_result(
vis_frame = vis_pose_tracking_result(
pose_model,
img,
cur_frame,
pose_results,
radius=args.radius,
thickness=args.thickness,
Expand All @@ -170,15 +165,14 @@ def main():
show=False)

if args.show:
cv2.imshow('Image', vis_img)
cv2.imshow('Image', vis_frame)

if save_out_video:
videoWriter.write(vis_img)
videoWriter.write(vis_frame)

if args.show and cv2.waitKey(1) & 0xFF == ord('q'):
break

cap.release()
if save_out_video:
videoWriter.release()
if args.show:
Expand Down
2 changes: 1 addition & 1 deletion demo/top_down_pose_tracking_demo_with_mmdet.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@ def main():
args.online)

# test a single image, with a list of bboxes.
pose_results, returned_outputs = inference_top_down_pose_model(
pose_results, _ = inference_top_down_pose_model(
pose_model,
frames if args.use_multi_frames else cur_frame,
person_results,
Expand Down

0 comments on commit f938505

Please sign in to comment.