Skip to content

Commit

Permalink
add some extra comments
Browse files Browse the repository at this point in the history
  • Loading branch information
mycprotein committed Sep 6, 2022
1 parent ad515c3 commit 2d28961
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions python/nano/tutorial/inference/tensorflow/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from keras import layers
from keras import losses, metrics

# keras.Model is injected with customized functions
from bigdl.nano.tf.keras import Model

# Model / data parameters
Expand Down Expand Up @@ -75,7 +76,8 @@
]
)

model = Model(inputs=model.inputs, outputs=model.outputs)
# this line is optional
# model = Model(inputs=model.inputs, outputs=model.outputs)

model.summary()

Expand All @@ -102,16 +104,22 @@
# Execute post-training quantization, calibrated on tune_dataset.
# NOTE(review): `quantize` is not a stock keras.Model method — presumably one of
# the customized functions injected by bigdl.nano.tf.keras.Model; confirm
# against the BigDL-Nano API docs.
q_model = model.quantize(calib_dataset=tune_dataset)

# Inference using quantized model (called directly, like a keras.Model)
y_test_hat = q_model(x_test)

# Evaluate the quantized model: mean categorical cross-entropy over the test set.
# assumes y_test is one-hot encoded (required by categorical_crossentropy and
# CategoricalAccuracy) — TODO confirm against the preprocessing code above
loss = float(tf.reduce_mean(
losses.categorical_crossentropy(y_test, y_test_hat)))
# Accuracy via a stateful Keras metric: accumulate, then read out as a numpy scalar
categorical_accuracy = metrics.CategoricalAccuracy()
categorical_accuracy.update_state(y_test, y_test_hat)
accuracy = categorical_accuracy.result().numpy()

print("Quantization test loss:", loss)
print("Quantization test accuracy:", accuracy)
# Raw model test loss: 0.024767747148871422
# Raw model test accuracy: 0.9918000102043152
# Quantized model test loss: 0.02494174614548683
# Quantized model test accuracy: 0.9917
# Accuracy loss: 0.01%
# Accuracy loss: about 0.01% in this case (0.9918 -> 0.9917)
# Note: accuracy loss varies from different tasks and situations,
# but you can set a quantization threshold when making a quantization model.

0 comments on commit 2d28961

Please sign in to comment.