-
Notifications
You must be signed in to change notification settings - Fork 23
/
Fashion_Test.py
140 lines (113 loc) · 4.81 KB
/
Fashion_Test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""
Keras RFCN
Copyright (c) 2018
Licensed under the MIT License (see LICENSE for details)
Written by [email protected]
"""
'''
This is a demo that evaluates an R-FCN model on the DeepFashion dataset.
http://mmlab.ie.cuhk.edu.hk/projects/DeepFashion.html
'''
from KerasRFCN.Model.Model import RFCN_Model
from KerasRFCN.Config import Config
import KerasRFCN.Utils
import os
from keras.preprocessing import image
import pickle
import numpy as np
import argparse
import matplotlib.pyplot as plt
import matplotlib.patches as patches
class RFCNNConfig(Config):
    """Configuration for running the R-FCN model on the DeepFashion dataset.

    Derives from the base ``Config`` class and overrides the values specific
    to this DeepFashion setup (46 categories plus background, ResNet-101
    backbone, 640-768 px images).
    """
    # Give the configuration a recognizable name
    NAME = "Fashion"

    # Backbone model
    # choose one from ['resnet50', 'resnet101', 'resnet50_dilated', 'resnet101_dilated']
    BACKBONE = "resnet101"

    # Run on 1 GPU with 1 image per GPU; the effective batch size is
    # GPU_COUNT * IMAGES_PER_GPU = 1.
    GPU_COUNT = 1
    IMAGES_PER_GPU = 1

    # Number of classes (including background): background + 46 DeepFashion
    # categories.
    C = 1 + 46
    NUM_CLASSES = C

    # Limits for the short side (IMAGE_MIN_DIM) and the long side
    # (IMAGE_MAX_DIM) of the input; together they determine the image shape.
    IMAGE_MIN_DIM = 640
    IMAGE_MAX_DIM = 768

    # RPN anchor side lengths, in pixels.
    RPN_ANCHOR_SCALES = (32, 64, 128, 256, 512)

    # Feature-map strides of the backbone stages. With a dilated ResNet
    # (DetNet-style) backbone, keep the same stride on stages 4-6,
    # e.g. BACKBONE_STRIDES = [4, 8, 16, 16, 16].
    BACKBONE_STRIDES = [4, 8, 16, 32, 64]

    # ROIs sampled per image when training the detection heads.
    TRAIN_ROIS_PER_IMAGE = 200

    # Training steps per epoch.
    STEPS_PER_EPOCH = 100

    # Validation steps per epoch.
    VALIDATION_STEPS = 5

    # NMS IoU threshold applied to RPN proposals.
    RPN_NMS_THRESHOLD = 0.7

    # Minimum confidence for a detection to be kept (presumably lower-scoring
    # detections are dropped by the model's detection layer -- confirm).
    DETECTION_MIN_CONFIDENCE = 0.4

    # ROI pooling grid size (presumably the k x k position-sensitive bins of
    # R-FCN -- confirm against the model code).
    POOL_SIZE = 7
# File extensions (lower-case) that Test() treats as input images.
_IMAGE_EXTENSIONS = ('.bmp', '.jpeg', '.jpg', '.png', '.tif', '.tiff')


def _is_supported_image(path):
    """Return True when *path* ends with one of ``_IMAGE_EXTENSIONS``."""
    return path.lower().endswith(_IMAGE_EXTENSIONS)


def _load_image(path):
    """Decode *path* once and return ``(model_input, display_image)``.

    ``model_input`` is ``image.img_to_array(image.load_img(path))`` (Keras);
    ``display_image`` is the same array divided by 255 for ``imshow``.
    """
    img = image.img_to_array(image.load_img(path))
    # Derive the display copy from the already-decoded pixels instead of
    # re-reading the file with plt.imread: plt.imread returns PNGs as floats
    # in [0, 1] already, so the old "/ 255.0" turned PNG backgrounds almost
    # black (and every image was decoded twice).
    return img, img / 255.0


def Test(model, loadpath, savepath):
    """Load the latest trained weights into *model* and annotate images.

    Args:
        model: inference-mode R-FCN model; the second element of
            ``model.find_last()`` is used as the weights file.
        loadpath: a single image file or a directory of images. Only files
            ending in .bmp/.jpeg/.jpg/.png/.tif/.tiff (case-insensitive) are
            processed.
        savepath: directory that receives the annotated images; must not be
            the same directory as *loadpath*.
    """
    # Compare normalized paths so e.g. "images" and "images/" are caught too.
    assert os.path.abspath(loadpath) != os.path.abspath(savepath), \
        "loadpath should'n same with savepath"

    # find_last() presumably returns (log_dir, newest_checkpoint_path);
    # index 1 is used as the weights file.
    model_path = model.find_last()[1]
    model.load_weights(model_path, by_name=True)
    print("Loading weights from ", model_path)

    if os.path.isdir(loadpath):
        # sorted() makes the processing order deterministic.
        for imgname in sorted(os.listdir(loadpath)):
            if not _is_supported_image(imgname):
                continue
            print(imgname)
            img, image_ori = _load_image(os.path.join(loadpath, imgname))
            TestSinglePic(img, image_ori, model, savepath=savepath, imgname=imgname)
    elif os.path.isfile(loadpath):
        if not _is_supported_image(loadpath):
            print("not image file!")
            return
        print(loadpath)
        img, image_ori = _load_image(loadpath)
        # Use only the base name (extension stripped, as before). Previously
        # the directory part was kept, so for an absolute loadpath
        # os.path.join(savepath, ...) discarded savepath and the result was
        # written next to -- or over -- the input image.
        filename = os.path.splitext(os.path.basename(loadpath))[0]
        TestSinglePic(img, image_ori, model, savepath=savepath, imgname=filename)
    else:
        # Previously a missing loadpath was silently ignored.
        print("loadpath does not exist:", loadpath)
def TestSinglePic(image, image_ori, model, savepath, imgname):
    """Run detection on one image and save it with the predicted ROIs drawn.

    Args:
        image: model input array (``Test`` builds it with
            ``image.img_to_array``). Note this parameter shadows the
            module-level Keras ``image`` import inside this function.
        image_ori: the same picture prepared for ``imshow`` (``Test`` divides
            it by 255).
        model: model exposing ``detect(images, verbose=...)``; the first
            result's ``'rois'`` rows are unpacked as ``(y1, x1, y2, x2)``.
        savepath: output directory (must be non-empty); created if missing.
        imgname: output file name inside *savepath* (must be non-empty).
            Without an extension, matplotlib appends its default format; an
            extension matplotlib cannot write (e.g. ``.bmp``) gets ``.png``
            appended.
    """
    # Validate the cheap arguments before running the expensive detection.
    assert not savepath == "", "empty save path"
    assert not imgname == "", "empty image file name"

    r = model.detect([image], verbose=1)[0]
    print(r)

    fig, ax = plt.subplots(1, 1, figsize=(8, 8))
    try:
        for y1, x1, y2, x2 in r['rois']:
            p = patches.Rectangle((x1, y1), x2 - x1, y2 - y1, linewidth=2,
                                  alpha=0.7, linestyle="dashed",
                                  edgecolor="red", facecolor='none')
            ax.add_patch(p)
        ax.imshow(image_ori)

        out_path = os.path.join(savepath, imgname)
        # savefig infers the format from the extension and raises for ones it
        # cannot write (e.g. .bmp, which Test() accepts as input).
        ext = os.path.splitext(out_path)[1][1:].lower()
        if ext and ext not in fig.canvas.get_supported_filetypes():
            out_path += ".png"
        out_dir = os.path.dirname(out_path)
        if out_dir and not os.path.isdir(out_dir):
            os.makedirs(out_dir)
        fig.savefig(out_path, bbox_inches='tight')
    finally:
        # Close (not just clear) the figure: plt.subplots opens a new figure
        # per call, and plt.clf() alone kept every one alive in pyplot's
        # registry, so memory grew with each processed image.
        plt.close(fig)
if __name__ == '__main__':
    # Command-line entry point: build an inference-mode model whose logs live
    # under ./logs, then annotate the images found at --loadpath into
    # --savepath.
    ROOT_DIR = os.getcwd()
    parser = argparse.ArgumentParser()
    # (flag, default, metavar/help text) for each optional path argument.
    cli_paths = (
        ('--loadpath', "images/", "evaluate images loadpath"),
        ('--savepath', "result/", "evaluate images savepath"),
    )
    for flag, default_path, label in cli_paths:
        parser.add_argument(flag, required=False,
                            default=default_path,
                            metavar=label,
                            help=label)
    config = RFCNNConfig()
    args = parser.parse_args()
    log_dir = os.path.join(ROOT_DIR, "logs")
    model = RFCN_Model(mode="inference", config=config, model_dir=log_dir)
    Test(model, args.loadpath, args.savepath)