forked from kmeng01/yolo-server
-
Notifications
You must be signed in to change notification settings - Fork 0
/
yolo_backend.py
76 lines (58 loc) · 2.03 KB
/
yolo_backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import hashlib
import os

import cv2
import matplotlib
import matplotlib.pyplot
import pandas as pd
import torch
from skimage import io
from ultralytics import YOLO
from ultralytics.nn.modules import DFL
# Module-level model and colormap, created once at import time and used by
# predict() / predict_and_draw() below.
# Weights default to /my_vol/yolo5small.pt; set the YOLO_MODEL_PATH
# environment variable to load a different checkpoint.
# (Earlier revisions loaded YOLOv5 via torch.hub.load("ultralytics/yolov5", ...).)
yolo_model = YOLO(os.environ.get("YOLO_MODEL_PATH", "/my_vol/yolo5small.pt"))

# "jet" colormap, sampled to give each detection class its box color.
# Note: `import matplotlib` alone does not load the pyplot submodule, so
# matplotlib.pyplot must be imported explicitly for this attribute access.
cmap = matplotlib.pyplot.get_cmap("jet")
def hash_to_range(number, N):
    """Deterministically map ``number`` to an integer in ``[1, N]``.

    The value's ``str()`` form is UTF-8 encoded and hashed with SHA-256; the
    digest, read as one big-endian integer, is reduced modulo ``N`` and
    shifted up by one. Because SHA-256 is used (not the builtin ``hash``), the
    mapping is the same in every process.

    Raises:
        ZeroDivisionError: if ``N`` is 0.
    """
    digest = hashlib.sha256(str(number).encode("utf-8")).digest()
    as_int = int.from_bytes(digest, "big")
    return as_int % N + 1
# Class-id -> emotion label lookup; keys are the class ids as strings.
# NOTE(review): not referenced by the code in this module; presumably mirrors
# the label order the checkpoint was trained with — confirm.
names = {
    "0": "anger",
    "1": "fear",
    "2": "happy",
    "3": "neutral",
    "4": "sad",
}
def predict(img_path):
    """Run the module-level ``yolo_model`` on one input and return the first result.

    ``img_path`` is forwarded to the model unchanged. Despite its name,
    ``predict_and_draw`` passes a decoded image array here rather than a path.
    """
    return yolo_model(img_path)[0]
def predict_and_draw(img, out_img_path):
    """Detect objects in an image, save an annotated copy, and return the boxes.

    Args:
        img: Anything ``skimage.io.imread`` accepts (e.g. a file path).
        out_img_path: Destination handed to the result's ``save`` method for
            the annotated image.

    Returns:
        ``{"boxes": [...]}``: one dict per detection, built from the result's
        JSON export, plus a ``"color"`` entry. The colour is an RGB tuple of
        0-255 floats derived from the detection's ``"class"``. The list is
        empty when nothing is detected.
    """
    # Stdlib StringIO; the module-level name ``io`` refers to skimage.io.
    from io import StringIO

    image = io.imread(img)
    result = predict(_to_bgr(image))
    # The annotated image is rendered by the result object itself.
    result.save(out_img_path)
    print(out_img_path)

    detections_json = result.tojson()
    print(detections_json)
    # Passing a literal JSON string to read_json is deprecated (pandas 2.1);
    # wrap it in a file-like object instead.
    detections = pd.read_json(StringIO(detections_json))
    if detections.empty:
        # No detections: there is no "class" column to colour, and assigning
        # an empty apply()/map() result to a new column would fail.
        return {"boxes": []}

    detections["color"] = detections["class"].map(_class_color)
    return {"boxes": list(detections.to_dict(orient="index").values())}


def _to_bgr(image):
    """Convert an ``imread`` array (RGB, RGBA or grayscale) to 3-channel BGR."""
    if image.ndim == 2:
        # Single-channel images cannot go through COLOR_RGB2BGR.
        return cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
    # COLOR_RGB2BGR also accepts 4-channel input and drops the alpha channel.
    return cv2.cvtColor(image, cv2.COLOR_RGB2BGR)


def _class_color(cls, num_colors=5):
    """Map a class id to a stable RGB tuple (0-255 floats) sampled from ``cmap``."""
    rgba = cmap(hash_to_range(cls, num_colors) / num_colors)
    return tuple(channel * 255 for channel in rgba)[:3]