Skip to content

Commit

Permalink
Add last-minute modifications to the image-based alexnet classifier d…
Browse files Browse the repository at this point in the history
…one at XPrize finals
  • Loading branch information
paolo-viceconte committed Dec 21, 2022
1 parent dc92728 commit 381fcf4
Showing 1 changed file with 111 additions and 75 deletions.
186 changes: 111 additions & 75 deletions classifier_images_alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
norm_threshold = 15

# Training hyperparams
epochs = 100
epochs = 25
batch_size = 32

# Training and testing data
Expand Down Expand Up @@ -385,6 +385,9 @@ def update_plot(i, data, scat):
x_train.append(train_image)
y_train.append(labels[j])

contact = False
contact_images = []
contact_labels = []

if dataset in test_datasets:

Expand All @@ -397,11 +400,14 @@ def update_plot(i, data, scat):
# For the two classes of interest
if labels[j] != -1:

if not contact:
contact = True

test_image = np.zeros((9, 11, 1))
for i in range(len(image_ordered_palm_x)):
test_image[image_ordered_palm_x[i], image_ordered_palm_y[i]] = [data[str(hand) + "_palm"][j][i]/255]
x_test.append(test_image)
y_test.append(labels[j])
contact_images.append(test_image)
contact_labels.append(labels[j])

if visualize_test_set:

Expand All @@ -412,6 +418,13 @@ def update_plot(i, data, scat):
plt.pause(0.3)
plt.close()

elif contact:
contact = False
x_test.append(contact_images)
y_test.append(contact_labels)
contact_images = []
contact_labels = []


# Convert to numpy array
x_train = np.array(x_train)
Expand All @@ -430,24 +443,24 @@ def update_plot(i, data, scat):
# featurewise_center=True,
# featurewise_std_normalization=True,
rotation_range=30,
width_shift_range=0.3,
height_shift_range=0.3,
# horizontal_flip=True,
width_shift_range=0.2,
height_shift_range=0.2,
horizontal_flip=True,
# vertical_flip=True
)
it = datagen.flow(x_train, y_train, batch_size=batch_size)

# Check classes
classes = np.unique(np.concatenate((y_train, y_test), axis=0))
print(classes)
# # Check classes
# classes = np.unique(np.concatenate((y_train, y_test), axis=0))
# print(classes)

##########
# TRAINING
##########

if training:

model = alexnet_model(scale=8)
model = alexnet_model(scale=32)

# shuffle the training set
idx = np.random.permutation(x_train.shape[0])
Expand Down Expand Up @@ -481,84 +494,107 @@ def update_plot(i, data, scat):
plt.show()
plt.close()

# model.save('model_images.h5', include_optimizer=False)
model.save('model_images.h5')
model.save('model_images.h5', include_optimizer=False)
# model.save('model_images.h5')

################
# EVALUATE MODEL
################

model = keras.models.load_model("model_images.h5")
# model = keras.models.load_model("model_images_v0.h5")
# model = keras.models.load_model("model_images.h5")
model = keras.models.load_model("model_alexnet32_v0.h5")

# Check the evolution of the model on the test set
pred_classes = []
contact_wise_data = []
contact_wise_correct = []
correct = 0
all_data = 0
for i in range(len(x_test)):
pred = model(np.reshape(x_test[i], (1, x_test[i].shape[0], x_test[i].shape[1]))).numpy()[0]
pred_class = np.where(pred > 0.5, 1, 0)[0]
pred_classes.append(pred_class)
print(pred_class, " (" + str(round(pred[0], 3)) + ")\t", y_test[i])
test_acc_no_filter = np.count_nonzero(abs(y_test - pred_classes)==0)
print("Accuracy with no filtering: ", round(test_acc_no_filter/len(y_test),2)*100)
pred_classes.append([])
curr_correct = 0
curr_data = 0
for j in range(len(x_test[i])):
pred = model(np.expand_dims(x_test[i][j], axis=0)).numpy()[0]
pred_class = np.where(pred > 0.5, 1, 0)[0]
pred_classes[-1].append(pred_class)
if(pred_class == y_test[i][j]):
curr_correct += 1
curr_data += 1
correct += curr_correct
all_data += curr_data
contact_wise_correct.append(curr_correct)
contact_wise_data.append(curr_data)
print("Accuracy with no filtering: ", round(correct/all_data,2)*100)

# Plot the evolution of the model on the test set
fig = plt.figure()
plt.plot(np.array(range(len(pred_classes))),
pred_classes,
label="Prediction",
color='blue')
plt.fill_between(np.array(range(len(pred_classes))),
0,
abs(y_test - pred_classes),
label="Errors",
color='red')
plt.plot(np.array(range(len(y_test))),
y_test,
label="Ground-truth",
color='black')
plt.xlabel("measurements")
plt.ylabel("label")
plt.grid()
plt.legend()
plt.title("Prediction VS ground truth - test set", fontsize=16)
plt.show()

# Filtering
filtered_pred_class = medfilt(pred_classes, kernel_size=9)
test_acc_filtered = np.count_nonzero(abs(y_test - filtered_pred_class)==0)
print("Accuracy with filtering: ", round(test_acc_filtered/len(y_test),2)*100)

# Plot the evolution of the model on the test set after filtering
fig = plt.figure()
plt.plot(np.array(range(len(filtered_pred_class))),
filtered_pred_class,
label="Filtered Prediction",
color='blue')
plt.fill_between(np.array(range(len(filtered_pred_class))),
0,
abs(y_test - filtered_pred_class),
label="Errors",
color='red')
plt.plot(np.array(range(len(y_test))),
y_test,
label="Ground-truth",
color='black')
plt.xlabel("measurements")
plt.ylabel("label")
plt.grid()
plt.legend()
plt.title("Filtered Prediction VS ground truth - test set", fontsize=16)
fig, axs = plt.subplots(6, 6, figsize=(3, 4))
for i in range(6):
for j in range(6):
index_element = i * 6 + j
if index_element < 37:
axs[i,j].plot(np.array(range(len(y_test[index_element]))),
y_test[index_element],
label="Ground-truth",
color='black')
axs[i,j].plot(np.array(range(len(pred_classes[index_element]))),
pred_classes[index_element],
label="Predictions",
color='blue')
if y_test[index_element][0] == 0:
axs[i,j].fill_between(np.array(range(len(pred_classes[index_element]))),
0,
abs(np.array(y_test[index_element]) - np.array(pred_classes[index_element])),
label="Errors",
color='red')
elif y_test[index_element][0] == 1:
axs[i, j].fill_between(np.array(range(len(pred_classes[index_element]))),
1,
np.array(pred_classes[index_element]),
label="Errors",
color='red')
axs[i,j].tick_params(axis='x', colors='white')
axs[i,j].tick_params(axis='y', colors='white')
axs[i,j].set_title("acc="+str(round(contact_wise_correct[index_element]/contact_wise_data[index_element],2)*100), fontsize=10)
plt.suptitle("Test set accuracy: "+str(round(correct/all_data,2)*100), fontsize=16)
plt.show()

# Check the accuracy on the whole test set
test_loss, test_acc = model.evaluate(x_test, y_test)
print("Test accuracy", test_acc)
print("Test loss", test_loss)

# Confusion matrix
y_test_prob = model.predict(x_test)
y_test_pred = np.where(y_test_prob > 0.5, 1, 0)
print(confusion_matrix(y_test, y_test_pred))
# # Filtering
# filtered_pred_class = medfilt(pred_classes, kernel_size=9)
# test_acc_filtered = np.count_nonzero(abs(y_test - filtered_pred_class)==0)
# print("Accuracy with filtering: ", round(test_acc_filtered/len(y_test),2)*100)
#
# # Plot the evolution of the model on the test set after filtering
# fig = plt.figure()
# plt.plot(np.array(range(len(filtered_pred_class))),
# filtered_pred_class,
# label="Filtered Prediction",
# color='blue')
# plt.fill_between(np.array(range(len(filtered_pred_class))),
# 0,
# abs(y_test - filtered_pred_class),
# label="Errors",
# color='red')
# plt.plot(np.array(range(len(y_test))),
# y_test,
# label="Ground-truth",
# color='black')
# plt.xlabel("measurements")
# plt.ylabel("label")
# plt.grid()
# plt.legend()
# plt.title("Filtered Prediction VS ground truth - test set", fontsize=16)
# plt.show()

# # Check the accuracy on the whole test set
# test_loss, test_acc = model.evaluate(x_test, y_test)
# print("Test accuracy", test_acc)
# print("Test loss", test_loss)
#
# # Confusion matrix
# y_test_prob = model.predict(x_test)
# y_test_pred = np.where(y_test_prob > 0.5, 1, 0)
# print(confusion_matrix(y_test, y_test_pred))

# Debug saved data for the C++ implementation
# import json
Expand Down

0 comments on commit 381fcf4

Please sign in to comment.