-
Notifications
You must be signed in to change notification settings - Fork 0
/
collect_data.py
105 lines (82 loc) · 3.3 KB
/
collect_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import os
import cv2
import mediapipe as mp
import numpy as np
import csv
import copy
import argparse
import itertools
# Short aliases for the MediaPipe solution submodules. Of these, only
# `mp_hands` is referenced by the code visible in this script.
mp_drawing = mp.solutions.drawing_utils
mp_drawing_styles = mp.solutions.drawing_styles
mp_hands = mp.solutions.hands
def update_csv(label, landmark_list, csv_path='data.csv'):
    """Append one training sample to a CSV file.

    The row is written as ``label, v0, v1, ...`` where the values are the
    elements of ``landmark_list`` in order.

    Args:
        label: Class identifier stored in the first column.
        landmark_list: Iterable of feature values stored after the label.
        csv_path: Destination file. It is opened in append mode, so it is
            created if missing and existing rows are kept. Defaults to
            ``'data.csv'`` in the current working directory.

    Returns:
        None.
    """
    # newline="" lets the csv module control line endings itself.
    with open(csv_path, 'a', newline="") as f:
        csv.writer(f).writerow([label, *landmark_list])
def calc_landmark_list(image, landmarks):
    """Convert landmarks to pixel coordinates plus two mirrored copies.

    Each landmark's ``x``/``y`` is scaled by the image width/height,
    truncated to ``int`` and capped at ``width - 1`` / ``height - 1``
    (there is no lower cap).

    Args:
        image: Object whose ``shape`` is ``(height, width, ...)``.
        landmarks: Object exposing an iterable ``landmark`` attribute whose
            items have ``x`` and ``y`` attributes.

    Returns:
        A tuple of three lists of ``[x, y]`` pairs: the pixel coordinates,
        the same points with ``x`` negated, and the same points with ``y``
        negated.
    """
    height, width = image.shape[0], image.shape[1]
    pixels = [
        (min(int(pt.x * width), width - 1),
         min(int(pt.y * height), height - 1))
        for pt in landmarks.landmark
    ]
    plain = [[px, py] for px, py in pixels]
    flipped_x = [[-px, py] for px, py in pixels]
    flipped_y = [[px, -py] for px, py in pixels]
    return plain, flipped_x, flipped_y
def preprocess_landmark(landmark_list):
    """Turn a list of points into a flat, scale-normalized feature vector.

    The first point is used as the origin: its x/y are subtracted from the
    x/y of every point. The points are then flattened into one list and
    divided by the largest absolute value, so results lie in ``[-1, 1]``.

    Args:
        landmark_list: Sequence of ``[x, y]`` points. It is not modified.

    Returns:
        A flat list of floats ``[x0, y0, x1, y1, ...]``. An empty input
        gives ``[]``. If every point coincides with the first one, a list
        of ``0.0`` is returned instead of dividing by zero.
    """
    if not landmark_list:
        return []

    base_x, base_y = landmark_list[0][0], landmark_list[0][1]

    # Only x and y are shifted; any further components are carried over
    # unchanged.
    shifted = (
        [point[0] - base_x, point[1] - base_y, *point[2:]]
        for point in landmark_list
    )
    flat = list(itertools.chain.from_iterable(shifted))

    max_value = max(map(abs, flat))
    if max_value == 0:
        # Degenerate sample (all offsets are zero): nothing to scale.
        return [0.0] * len(flat)

    return [value / max_value for value in flat]
# Build the training CSV from static images of each gesture class.
# For every image in which a hand is detected, three rows are appended:
# the landmarks as detected, an x-mirrored copy and a y-mirrored copy.
IMAGE_FILES = []
classes = ['rock', 'paper', 'scissors']
with mp_hands.Hands(
        static_image_mode=True,
        max_num_hands=1,
        min_detection_confidence=0.1) as hands:
    # count[i] is the number of images of classes[i] with a detected hand.
    count = [0] * len(classes)
    for i, c in enumerate(classes):
        directory = '/home/dtech/Documents/git/vision-rps/dataset/train/' + c
        for filename in os.listdir(directory):
            print(count)  # running progress
            image_path = os.path.join(directory, filename)
            image = cv2.imread(image_path)
            if image is None:
                # cv2.imread returns None for files it cannot decode;
                # passing that on to MediaPipe would crash the whole run.
                print('skipping unreadable file:', image_path)
                continue

            # OpenCV loads images as BGR, while MediaPipe Hands expects RGB,
            # so convert before running detection.
            results = hands.process(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
            if not results.multi_hand_landmarks:
                continue

            count[i] += 1
            for hand_landmarks in results.multi_hand_landmarks:
                variants = calc_landmark_list(image, hand_landmarks)
                # Preprocess all variants before writing any of them, so a
                # failure cannot leave a partial sample in the CSV.
                rows = [preprocess_landmark(points) for points in variants]
                for row in rows:
                    update_csv(i, row)
print(count)