forked from wingniuqichao/caffe_Unet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
python.py
65 lines (51 loc) · 1.71 KB
/
python.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import cv2
import numpy as np
import time
import argparse
# Create one resizable window for each view that predict() displays:
# original frame, black-and-white mask, red-annotated frame, and cut-out.
for _window_title in ("原始", "黑白", "标注", "扣取"):
    cv2.namedWindow(_window_title, cv2.WINDOW_NORMAL)
del _window_title

# Load the U-Net model from its Caffe definition and weights. The paths are
# relative, so they are resolved against the current working directory.
net = cv2.dnn.readNetFromCaffe("unet.prototxt", "unet.caffemodel")
def predict(frame, threshold=0.5):
    """Segment *frame* with the module-level U-Net ``net`` and show the results.

    The frame is resized to 256x256 and fed to the network. Four windows are
    then updated: the resized input, the binary mask, the input with the
    foreground tinted red, and the foreground cut out onto a white background.

    Args:
        frame: 3-channel image, e.g. from ``cv2.imread(path, 1)`` or
            ``VideoCapture.read()``. It is not modified; a resized copy is used.
        threshold: pixels whose channel-1 output is strictly greater than this
            value are treated as foreground (default 0.5).

    Returns:
        The 256x256 ``uint8`` mask: 255 for foreground, 0 for background.
    """
    # NOTE(review): 256x256 is assumed to match the input size declared in
    # unet.prototxt; confirm before changing.
    frame = cv2.resize(frame, (256, 256))
    cv2.imshow("原始", frame)

    # blobFromImage subtracts the mean and then scales, so each pixel becomes
    # (pixel - 127.5) / 255. swapRB=False keeps the BGR channel order.
    input_blob = cv2.dnn.blobFromImage(
        frame, 1 / 255.0, (256, 256), (127.5, 127.5, 127.5), False)

    # Run inference, reading the output of the layer named "predict".
    net.setInput(input_blob, 'data')
    pred = net.forward("predict")

    # Take channel 1 of the first image in the batch. Presumably this is the
    # foreground probability map; TODO confirm against the model definition.
    # A single comparison replaces the two in-place masked assignments.
    mask = np.where(pred[0, 1, :, :] > threshold, 255, 0).astype(np.uint8)

    # Cut the foreground out: every background pixel becomes white.
    frame_person = frame.copy()
    frame_person[mask == 0] = [255, 255, 255]

    # Mark the foreground in red by saturating channel 2 (R in BGR order).
    frame[mask == 255, 2] = 255

    cv2.imshow('黑白', mask)
    cv2.imshow('标注', frame)
    cv2.imshow('扣取', frame_person)
    return mask
def main():
    """Run segmentation on the image or video selected on the command line.

    Reads the module-level ``args`` namespace (its ``image`` and ``video``
    attributes). ``--image`` is checked first because ``--video`` always
    carries a default path; if video were checked first, ``--image`` would
    never take effect.

    Image mode shows the result and waits for any key. Video mode processes
    frames until the stream ends or the user presses ``q`` or Esc. The
    capture is always released and all windows are closed at the end.

    Raises:
        SystemExit: if the image cannot be read or the video cannot be opened.
    """
    if args.image:
        img = cv2.imread(args.image, 1)
        # imread signals failure by returning None rather than raising.
        if img is None:
            raise SystemExit("could not read image: {}".format(args.image))
        predict(img)
        cv2.waitKey(0)
    elif args.video:
        cap = cv2.VideoCapture(args.video)
        if not cap.isOpened():
            raise SystemExit("could not open video: {}".format(args.video))
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break  # end of stream or read error
                predict(frame)
                # waitKey also pumps the GUI event loop; mask to the low byte
                # and stop on 'q' or Esc (27).
                if (cv2.waitKey(1) & 0xFF) in (ord('q'), 27):
                    break
        finally:
            cap.release()
    cv2.destroyAllWindows()
if __name__ == '__main__':
    # Build the CLI. The result must be bound to the module-level name `args`
    # because main() reads that global.
    parser = argparse.ArgumentParser(
        description='command for predict unet model',
    )
    parser.add_argument(
        '--image', type=str, default=None,
        help='the image to predict',
    )
    parser.add_argument(
        '--video', type=str, default='test/cxk.mp4',
        help='the video to predict',
    )
    args = parser.parse_args()
    main()