-
Notifications
You must be signed in to change notification settings - Fork 12
/
get_cropped_TigDog.py
132 lines (107 loc) · 4.59 KB
/
get_cropped_TigDog.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
from __future__ import print_function, absolute_import
import os
import numpy as np
import json
import random
import math
import torch
import torch.utils.data as data
from pose.utils.osutils import *
from pose.utils.imutils import *
from pose.utils.transforms import *
from pose.utils.evaluation import final_preds
import pose.models as models
import glob
import cv2
from tqdm import tqdm
import scipy.misc
import scipy.ndimage
from scipy.io import loadmat
import imageio
def load_animal(data_dir='./', animal='horse'):
    """Gather frame file names and keypoint annotations for one TigDog animal.

    Each row of ``ranges`` in
    ``<data_dir>/behaviorDiscovery2.0/ranges/<animal>/ranges.mat`` is read as
    ``(shot_id, start_frame, end_frame)``; the end frame is inclusive. A shot
    is skipped when
    ``<data_dir>/behaviorDiscovery2.0/landmarks/<animal>/<shot_id>.mat``
    does not exist.

    Returns:
        img_list: one ``[img_name, shot_id, frame_id]`` entry per frame, where
            ``img_name`` is the frame number left-padded with zeros to 8
            characters plus ``'.jpg'`` and ``frame_id`` is the frame's offset
            from ``start_frame`` (0 for the first frame of the shot).
        anno_list: one array per ``img_list`` entry holding the first 18 rows
            of the landmark coordinates stacked side by side with their
            visibility values -- expected to be (x, y, visibility) rows;
            confirm against the .mat layout.
    """
    ranges = loadmat(os.path.join(data_dir, 'behaviorDiscovery2.0/ranges', animal, 'ranges.mat'))['ranges']
    landmark_dir = os.path.join(data_dir, 'behaviorDiscovery2.0/landmarks', animal)
    img_list, anno_list = [], []
    for row in ranges:
        shot_id, start_frame, end_frame = row[0], row[1], row[2]
        shot_file = os.path.join(landmark_dir, str(shot_id) + '.mat')
        if not os.path.isfile(shot_file):
            continue
        shot_mat = loadmat(shot_file)
        for frame in range(start_frame, end_frame + 1):
            frame_id = frame - start_frame
            img_list.append([str(frame).rjust(8, '0') + '.jpg', shot_id, frame_id])
            # Peel off the container levels loadmat leaves around this frame's
            # landmark record (nesting depth taken from the dataset layout).
            record = shot_mat['landmarks'][frame_id][0][0][0]
            coord, vis = record[0], record[1]
            anno_list.append(np.hstack((coord, vis))[:18, :])
    return img_list, anno_list
def dataset_filter(anno_list):
    """Select annotations with more than half of their keypoints visible.

    Args:
        anno_list: sequence of 2-D arrays (one per image) whose column 2 holds
            per-keypoint visibility values; those values are summed.

    Returns:
        Ascending list of indices into ``anno_list`` whose visibility sum
        exceeds ``num_kpts // 2``, where ``num_kpts`` is the row count of the
        first annotation. An empty ``anno_list`` yields ``[]``.
    """
    if len(anno_list) == 0:
        # Nothing to filter; reading anno_list[0] would raise IndexError.
        return []
    num_kpts = anno_list[0].shape[0]
    threshold = num_kpts // 2
    return [i for i, anno in enumerate(anno_list) if sum(anno[:, 2]) > threshold]
def im_to_torch(img):
    """Convert a channels-last image array into a float channels-first tensor.

    The array is transposed from H x W x C to C x H x W. Values are divided by
    255 only when the maximum exceeds 1, so inputs already in [0, 1] are kept
    as they are.
    """
    chw = to_torch(np.transpose(img, (2, 0, 1))).float()
    if chw.max() > 1:
        chw /= 255
    return chw
def _bbox_center_scale(anno, scale_mult=1.5):
    """Return (center, scale) of the box spanned by an annotation's positive coordinates.

    The x-extent uses only entries with x > 0 and the y-extent only entries
    with y > 0 (unlabelled keypoints are presumably stored as non-positive
    coordinates -- confirm against the dataset). The scale is the longer box
    side divided by 200 (presumably the pytorch-pose convention used by
    ``crop``), enlarged by ``scale_mult``.
    """
    xs = anno[:, 0]
    ys = anno[:, 1]
    xs = xs[xs > 0]
    ys = ys[ys > 0]
    x_min, x_max = float(np.min(xs)), float(np.max(xs))
    y_min, y_max = float(np.min(ys)), float(np.max(ys))
    center = torch.Tensor(((x_min + x_max) / 2.0, (y_min + y_max) / 2.0))
    scale = max(x_max - x_min, y_max - y_min) / 200.0 * scale_mult
    return center, scale


def get_cropped_dataset(img_folder, img_list, anno_list, img_idxs, animal):
    """Crop each selected TigDog frame around its keypoints and save it.

    Each index in ``img_idxs`` is processed once, in ascending order. Indices
    outside ``img_list`` are ignored. For each one, the frame
    ``<img_folder>/behaviorDiscovery2.0/<animal>/<img_name>`` is read as RGB
    and cropped to 256x256 around its keypoint bounding box. The result is
    written as ``crop_<img_name>`` into
    ``<img_folder>/real_animal_crop/real_<animal>_crop/``, which is created if
    it does not exist.

    Args:
        img_folder: dataset root; also the root of the output directory.
        img_list: entries whose first element is the image file name.
        anno_list: per-image keypoint arrays; columns 0 and 1 are x and y.
        img_idxs: indices of the images to crop (e.g. from ``dataset_filter``).
        animal: animal sub-folder name, e.g. ``'horse'``.

    Returns:
        None. Prints the number of cropped images.
    """
    # Derive the output location from img_folder instead of a hard-coded
    # './animal_data/' so it matches the directory prepared by the caller.
    out_dir = os.path.join(img_folder, 'real_animal_crop', 'real_' + animal + '_crop')
    if not os.path.isdir(out_dir):
        os.makedirs(out_dir)
    # Iterate only the selected frames (unselected images are never decoded).
    # Use each frame's own annotation rather than pairing through a running
    # counter, which assumed img_idxs was sorted and duplicate-free.
    selected = sorted({i for i in img_idxs if 0 <= i < len(img_list)})
    count = 0
    for i in tqdm(selected):
        img_name = img_list[i][0]
        # scipy.misc.imread was removed in SciPy 1.2; imageio is already used here.
        frame = np.asarray(imageio.imread(
            os.path.join(img_folder, 'behaviorDiscovery2.0/', animal, img_name),
            pilmode='RGB'))
        c, s = _bbox_center_scale(anno_list[i])
        # crop takes a C x H x W tensor. Its output is rescaled by 255 below,
        # presumably because crop returns values in [0, 1].
        cropped = crop(torch.Tensor(frame.transpose(2, 0, 1)), c, s, [256, 256], 0)
        cropped = np.uint8(cropped.numpy().transpose(1, 2, 0) * 255)
        count += 1
        imageio.imwrite(os.path.join(out_dir, 'crop_' + img_name), cropped)
    print('number of cropped '+animal+': ', count)
    return None
if __name__ == "__main__":
    # Crop the keypoint-filtered frames of each TigDog animal under img_folder.
    img_folder = './animal_data/'
    animals = ['horse', 'tiger']
    for animal in animals:
        img_list, anno_list = load_animal(data_dir=img_folder, animal=animal)
        img_idxs = dataset_filter(anno_list)
        # Number of frames that pass the visibility filter.
        print(len(img_idxs))
        crop_dir = os.path.join(img_folder, 'real_animal_crop', 'real_' + animal + '_crop')
        if not os.path.exists(crop_dir):
            os.makedirs(crop_dir)
        get_cropped_dataset(img_folder, img_list, anno_list, img_idxs, animal)