# image_demo.py -- forked from rwightman/posenet-pytorch.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import cv2
import time
import argparse
import os
import torch
import posenet
# Command-line interface. Parsed once at import time; main() reads the result
# from the module-level ``args`` namespace.
#   --model         value handed to posenet.load_model()
#   --scale_factor  value handed to posenet.read_imgfile()
#   --notxt         when set, the per-image text report is not printed
#   --image_dir     directory scanned for input images
#   --output_dir    directory for annotated images (empty string disables it)
parser = argparse.ArgumentParser()
for _flag, _options in (
    ('--model', {'type': int, 'default': 101}),
    ('--scale_factor', {'type': float, 'default': 1.0}),
    ('--notxt', {'action': 'store_true'}),
    ('--image_dir', {'type': str, 'default': './images'}),
    ('--output_dir', {'type': str, 'default': './output'}),
):
    parser.add_argument(_flag, **_options)
args = parser.parse_args()
def main():
    """Run PoseNet on every image in ``args.image_dir`` and report the results.

    Configuration comes from the module-level ``args`` namespace. For each
    ``.png``/``.jpg`` file (extension compared case-insensitively) found
    directly inside ``args.image_dir`` this:

    * runs the model and decodes at most 10 poses with a minimum pose score
      of 0.25;
    * if ``args.output_dir`` is non-empty, draws the skeletons/keypoints and
      writes the annotated image there under the same file name;
    * unless ``args.notxt`` is set, prints every pose and its keypoints,
      stopping at the first pose whose score is 0.

    Finally prints the average number of images processed per second.
    """
    # Prefer the GPU, but fall back to the CPU so the demo still runs on
    # machines without CUDA instead of crashing in ``.cuda()``.
    # NOTE(review): assumes the object returned by posenet.load_model is a
    # torch.nn.Module (it already supported ``.cuda()``) -- confirm.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = posenet.load_model(args.model).to(device)
    output_stride = model.output_stride

    if args.output_dir:
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(args.output_dir, exist_ok=True)

    with os.scandir(args.image_dir) as entries:
        filenames = [
            entry.path for entry in entries
            # Lower-case the name so files such as ``photo.JPG`` are not skipped.
            if entry.is_file() and entry.name.lower().endswith(('.png', '.jpg'))
        ]

    # perf_counter is monotonic and high resolution, unlike time.time().
    start = time.perf_counter()
    for path in filenames:
        input_image, draw_image, output_scale = posenet.read_imgfile(
            path, scale_factor=args.scale_factor, output_stride=output_stride)

        with torch.no_grad():
            input_image = torch.Tensor(input_image).to(device)

            heatmaps_result, offsets_result, displacement_fwd_result, displacement_bwd_result = model(input_image)

            pose_scores, keypoint_scores, keypoint_coords = posenet.decode_multiple_poses(
                heatmaps_result.squeeze(0),
                offsets_result.squeeze(0),
                displacement_fwd_result.squeeze(0),
                displacement_bwd_result.squeeze(0),
                output_stride=output_stride,
                max_pose_detections=10,
                min_pose_score=0.25)

        # Scale the decoded coordinates by the factor read_imgfile returned.
        keypoint_coords *= output_scale

        if args.output_dir:
            draw_image = posenet.draw_skel_and_kp(
                draw_image, pose_scores, keypoint_scores, keypoint_coords,
                min_pose_score=0.25, min_part_score=0.25)

            out_path = os.path.join(args.output_dir, os.path.relpath(path, args.image_dir))
            # cv2.imwrite signals failure through its return value; surface it
            # rather than silently producing no output file.
            if not cv2.imwrite(out_path, draw_image):
                print('Warning: could not write %s' % out_path)

        if not args.notxt:
            print()
            print("Results for image: %s" % path)
            for pi, pose_score in enumerate(pose_scores):
                # A zero score ends the listing for this image.
                if pose_score == 0.:
                    break
                print('Pose #%d, score = %f' % (pi, pose_score))
                for ki, (s, c) in enumerate(zip(keypoint_scores[pi, :], keypoint_coords[pi, :, :])):
                    print('Keypoint %s, score = %f, coord = %s' % (posenet.PART_NAMES[ki], s, c))

    elapsed = time.perf_counter() - start
    # Guard against a zero interval (e.g. an empty directory on a coarse clock).
    average_fps = len(filenames) / elapsed if elapsed > 0 else 0.0
    print('Average FPS:', average_fps)


if __name__ == "__main__":
    main()