-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_cars.py
70 lines (59 loc) · 2.2 KB
/
test_cars.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
# USAGE
# python test_cars.py --checkpoints checkpoints --prefix vggnet --epoch 55
# import the necessary packages
from config import car_config as config
from pyimagesearch.utils.ranked import rank5_accuracy
import mxnet as mx
import argparse
import pickle
import os
# construct the argument parser and parse the command-line arguments;
# all three options are required
ap = argparse.ArgumentParser()
ap.add_argument(
    "-c",
    "--checkpoints",
    help="path to output checkpoint directory",
    required=True,
)
ap.add_argument(
    "-p",
    "--prefix",
    help="name of model prefix",
    required=True,
)
ap.add_argument(
    "-e",
    "--epoch",
    help="epoch # to load",
    required=True,
    type=int,
)
# expose the parsed namespace as a plain dict keyed by option name
args = vars(ap.parse_args())
# load the label encoder; the context manager closes the file handle as
# soon as the object has been deserialized instead of leaving it open
# until garbage collection
# NOTE: unpickling can execute code embedded in the file, so
# config.LABEL_ENCODER_PATH must point at a trusted, locally generated file
with open(config.LABEL_ENCODER_PATH, "rb") as f:
    le = pickle.load(f)
# construct the testing image iterator: it reads the test record file
# named in the config, uses a 3x224x224 data shape and the configured
# batch size, and is handed the per-channel RGB means from the config
testIter = mx.io.ImageRecordIter(
    path_imgrec=config.TEST_MX_REC,
    batch_size=config.BATCH_SIZE,
    data_shape=(3, 224, 224),
    mean_r=config.R_MEAN,
    mean_g=config.G_MEAN,
    mean_b=config.B_MEAN,
)
# load our pre-trained model: the checkpoint is addressed as
# "<checkpoints dir>/<prefix>" plus the requested epoch number
print("[INFO] loading pre-trained model...")
# os.path.join avoids a doubled separator when the directory already ends
# with one, and does not yield a root-anchored path for an empty directory
checkpointsPath = os.path.join(args["checkpoints"], args["prefix"])
(symbol, argParams, auxParams) = mx.model.load_checkpoint(
    checkpointsPath, args["epoch"])

# construct the model
# NOTE: the device is hard-coded to mx.gpu(0), so this script needs a GPU
model = mx.mod.Module(symbol=symbol, context=[mx.gpu(0)])
# this script only runs inference, so bind without training support
# (bind defaults to for_training=True)
model.bind(
    data_shapes=testIter.provide_data,
    label_shapes=testIter.provide_label,
    for_training=False)
model.set_params(argParams, auxParams)
# accumulate the per-sample predictions and ground-truth labels over the
# whole test iterator
print("[INFO] evaluating model...")
predictions = []
targets = []

# loop over the predictions in batches (the second element of each
# yielded tuple is not needed here)
for (outputs, _, dataBatch) in model.iter_predict(testIter):
    # convert the first output and the first label array of the batch to
    # NumPy arrays, casting the labels to integers
    batchPreds = outputs[0].asnumpy()
    batchLabels = dataBatch.label[0].asnumpy().astype("int")

    # append this batch's entries to the running lists
    predictions += list(batchPreds)
    targets += list(batchLabels)

# mxnet will return the next full batch size of labels rather than the
# *actual* number of labels, so drop any surplus trailing targets to make
# both lists the same length
del targets[len(predictions):]

# compute the rank-1 and rank-5 accuracies and report them as percentages
rank1, rank5 = rank5_accuracy(predictions, targets)
print("[INFO] rank-1: {:.2f}%".format(rank1 * 100))
print("[INFO] rank-5: {:.2f}%".format(rank5 * 100))