This Hello World example demonstrates how to get started quickly with Intel® Neural Compressor. It uses a Keras model on the mnist dataset, defined by helloworld/train.py; we will implement a customized metric and a customized dataloader for quantization and evaluation.
pip install -r requirements.txt
cd <WORK_DIR>/examples/helloworld
python train.py
The command to quantize the model and run prediction with the quantized model:
python test.py
This example demonstrates the steps to quantize a Keras-generated saved model with a customized dataloader and metric.
class Dataset(object):
    """Indexable dataset over the Fashion-MNIST test split.

    Images are cast to ``float32`` and divided by 255.0; labels are stored
    exactly as returned by ``keras.datasets.fashion_mnist.load_data()``.
    """

    def __init__(self):
        # Only the test split is needed here; the training split is discarded.
        _, (test_images, test_labels) = keras.datasets.fashion_mnist.load_data()
        self.test_images = test_images.astype(np.float32) / 255.0
        self.labels = test_labels

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.test_images[index], self.labels[index]

    def __len__(self):
        """Return the number of test images."""
        return len(self.test_images)
This customized metric will calculate accuracy.
class MyMetric(object):
    """Accumulating accuracy metric.

    Predictions and labels are collected batch by batch through ``update``;
    ``result`` reports the fraction of collected predictions that equal
    their labels, and ``reset`` clears everything collected so far.
    """

    def __init__(self, *args):
        # Positional arguments are accepted but not used.
        self.pred_list = []
        self.label_list = []
        self.samples = 0

    def update(self, predict, label):
        """Record one batch.

        ``predict`` is reduced with ``np.argmax(..., axis=1)`` to one class
        index per row; ``label`` contributes one entry per sample.
        """
        self.pred_list.extend(np.argmax(predict, axis=1))
        self.label_list.extend(label)
        self.samples += len(label)

    def reset(self):
        """Discard all recorded predictions and labels."""
        self.pred_list = []
        self.label_list = []
        self.samples = 0

    def result(self):
        """Return the accuracy over all recorded samples.

        Returns 0.0 when nothing has been recorded, so the call never
        divides by zero.
        """
        if self.samples == 0:
            return 0.0
        predictions = np.array(self.pred_list)
        labels = np.array(self.label_list)
        correct_num = np.sum(predictions == labels)
        return correct_num / self.samples
# Wrap the custom dataset in a Neural Compressor dataloader (one sample per batch).
dataset = Dataset()
dataloader = DataLoader(
    framework='tensorflow',
    dataset=dataset,
    batch_size=1,
)

# Post-training quantization config, constructed with no arguments.
config = PostTrainingQuantConfig()

# The same dataloader is used for both calibration and evaluation;
# MyMetric supplies the evaluation metric.
q_model = fit(
    model='../models/saved_model',
    conf=config,
    calib_dataloader=dataloader,
    eval_dataloader=dataloader,
    eval_metric=MyMetric(),
)
Please get the input and output op names from nc_workspace/tensorflow/hello_world/deploy.yaml
Run inference on the quantized model
import tensorflow as tf
# Import the quantized graph into a fresh graph and run it in a session.
# 'input:0' is the tensor fed with the test images and 'output:0' is the
# tensor fetched.
with tf.compat.v1.Graph().as_default():
    with tf.compat.v1.Session() as sess:
        tf.compat.v1.import_graph_def(q_model.graph_def, name='')
        styled_image = sess.run(
            ['output:0'],
            feed_dict={'input:0': dataset.test_images},
        )
        print("Inference is done.")