diff --git a/examples/.config/model_params_tensorflow.json b/examples/.config/model_params_tensorflow.json
index 6051e597444..dee09cf20d1 100644
--- a/examples/.config/model_params_tensorflow.json
+++ b/examples/.config/model_params_tensorflow.json
@@ -1800,6 +1800,13 @@
"input_model": "/tf_dataset2/models/tensorflow/facebook-opt-125m",
"main_script": "main.py",
"batch_size": 16
+ },
+ "ViT": {
+ "model_src_dir": "image_recognition/tensorflow_models/vision_transformer/quantization/ptq",
+ "dataset_location": "/tf_dataset/dataset/imagenet",
+ "input_model": "/tf_dataset/tensorflow/vit/HF-ViT-Base16-Img224-frozen.pb",
+ "main_script": "main.py",
+ "batch_size": 32
}
}
}
diff --git a/examples/README.md b/examples/README.md
index 07d5bbd94b3..f0cf367421a 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -270,6 +270,12 @@ IntelĀ® Neural Compressor validated examples with multiple compression technique
Post-Training Static Quantization |
pb (smooth quant) |
+
+ ViT |
+ Image Recognition |
+ Post-Training Static Quantization |
+ pb |
+
diff --git a/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/README.md b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/README.md
new file mode 100644
index 00000000000..40c8987c4e5
--- /dev/null
+++ b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/README.md
@@ -0,0 +1,62 @@
+Step-by-Step
+============
+
+This document list steps of reproducing Vision Transformer model tuning results via Neural Compressor.
+
+# Prerequisite
+
+## 1. Environment
+
+### Install Dependency Package
+
+```
+pip install -r requirements.txt
+```
+
+### Install Intel Extension for Tensorflow
+
+```shell
+pip install --upgrade intel-extension-for-tensorflow[cpu]
+```
+
+## 2. Prepare Pretrained model
+
+```
+wget https://storage.googleapis.com/intel-optimized-tensorflow/models/2_11_0/HF-ViT-Base16-Img224-frozen.pb
+```
+
+## 3. Prepare Dataset
+
+ TensorFlow [models](https://github.com/tensorflow/models) repo provides [scripts and instructions](https://github.com/tensorflow/models/tree/master/research/slim#an-automated-script-for-processing-imagenet-data) to download, process and convert the ImageNet dataset to the TF records format.
+ We also prepared related scripts in ` examples/tensorflow/image_recognition/tensorflow_models/imagenet_prepare` directory. To download the raw images, the user must create an account with image-net.org. If you have downloaded the raw data and preprocessed the validation data by moving the images into the appropriate sub-directory based on the label (synset) of the image. we can use below command ro convert it to tf records format.
+
+ ```shell
+ cd examples/tensorflow/image_recognition/tensorflow_models/
+ # convert validation subset
+ bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/val/ --subset=validation
+ # convert train subset
+ bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/train/ --subset=train
+ ```
+
+# Run
+
+## 1. Quantization
+
+```
+bash run_tuning.sh --input_model --output_model ./output --dataset_location
+```
+
+
+## 2. Benchmark
+
+### Benchmark the fp32 model
+
+```
+bash run_benchmark.sh --input_model= --mode=accuracy --dataset_location= --batch_size=32
+```
+
+### Benchmark the int8 model
+
+```
+bash run_benchmark.sh --input_model=./output.pb --mode=accuracy --dataset_location= --batch_size=32 --int8=true
+```
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/main.py b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/main.py
new file mode 100644
index 00000000000..6d79a277599
--- /dev/null
+++ b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/main.py
@@ -0,0 +1,197 @@
+#
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2023 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import time
+from argparse import ArgumentParser
+
+import tensorflow as tf
+from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
+from tensorflow.python.framework import dtypes
+from tensorflow.core.protobuf import saved_model_pb2
+
+import numpy as np
+
+INPUTS = 'inputs'
+OUTPUTS = 'Identity'
+
+RESNET_IMAGE_SIZE = 224
+IMAGENET_VALIDATION_IMAGES = 50000
+
+tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)
+
+arg_parser = ArgumentParser(description='Parse args')
+arg_parser.add_argument('-g', "--input-graph",
+ help='Specify the input graph for the transform tool',
+ dest='input_graph')
+arg_parser.add_argument("--output-graph",
+ help='Specify tune result model save dir',
+ dest='output_graph')
+arg_parser.add_argument('--benchmark', dest='benchmark', action='store_true', help='run benchmark')
+arg_parser.add_argument('--mode', dest='mode', default='performance', help='benchmark mode')
+arg_parser.add_argument('--tune', dest='tune', action='store_true', help='use neural_compressor to tune.')
+arg_parser.add_argument('--diagnose', dest='diagnose', action='store_true', help='use Neural Insights to diagnose tuning and benchmark.')
+arg_parser.add_argument('--dataset_location', dest='dataset_location',
+ help='location of calibration dataset and evaluate dataset')
+arg_parser.add_argument('--batch_size', type=int, default=32, dest='batch_size', help='batch_size of benchmark')
+arg_parser.add_argument('--iters', type=int, default=100, dest='iters', help='interations')
+arg_parser.add_argument('--int8', dest='int8', action='store_true', help='whether to use int8 model for benchmark')
+args = arg_parser.parse_args()
+
+def evaluate(model, eval_dataloader, metric, postprocess=None):
+ """Custom evaluate function to estimate the accuracy of the model.
+
+ Args:
+ model (tf.Graph_def): The input model graph
+
+ Returns:
+ accuracy (float): evaluation result, the larger is better.
+ """
+ from neural_compressor.model import Model
+ model = Model(model)
+ input_tensor = model.input_tensor
+ output_tensor = model.output_tensor if len(model.output_tensor)>1 else \
+ model.output_tensor[0]
+ iteration = -1
+ if args.benchmark and args.mode == 'performance':
+ iteration = args.iters
+
+ def eval_func(dataloader):
+ latency_list = []
+ for idx, (inputs, labels) in enumerate(dataloader):
+ # shift the label and rescale the inputs
+ inputs, labels = postprocess((inputs, labels))
+ # dataloader should keep the order and len of inputs same with input_tensor
+ inputs = np.array([inputs])
+ feed_dict = dict(zip(input_tensor, inputs))
+
+ start = time.time()
+ predictions = model.sess.run(output_tensor, feed_dict)
+ end = time.time()
+
+ if isinstance(predictions, list):
+ if len(model.output_tensor_names) == 1:
+ predictions = predictions[0]
+ elif len(model.output_tensor_names) > 1:
+ predictions = predictions[1]
+ metric.update(predictions, labels)
+ latency_list.append(end-start)
+ if idx + 1 == iteration:
+ break
+ latency = np.array(latency_list).mean() / args.batch_size
+ return latency
+
+ latency = eval_func(eval_dataloader)
+ if args.benchmark and args.mode == 'performance':
+ print("Batch size = {}".format(args.batch_size))
+ print("Latency: {:.3f} ms".format(latency * 1000))
+ print("Throughput: {:.3f} images/sec".format(1. / latency))
+ acc = metric.result()
+ return acc
+
+class eval_classifier_optimized_graph:
+ """Evaluate image classifier with optimized TensorFlow graph."""
+
+ def run(self):
+ """This is neural_compressor function include tuning, export and benchmark option."""
+ from neural_compressor import set_random_seed
+ set_random_seed(9527)
+
+ if args.tune:
+ from neural_compressor import quantization
+ from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion
+ from neural_compressor.utils.create_obj_from_config import create_dataloader
+ calib_dataloader_args = {
+ 'batch_size': 10,
+ 'dataset': {"ImageRecord": {'root':args.dataset_location}},
+ 'transform': {'ResizeCropImagenet':{'height': 224, 'width': 224},
+ 'TransposeLastChannel':{}},
+ 'filter': None
+ }
+ calib_dataloader = create_dataloader('tensorflow', calib_dataloader_args)
+ eval_dataloader_args = {
+ 'batch_size': 32,
+ 'dataset': {"ImageRecord": {'root':args.dataset_location}},
+ 'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224},
+ 'TransposeLastChannel':{}},
+ 'filter': None
+ }
+ eval_dataloader = create_dataloader('tensorflow', eval_dataloader_args)
+
+ conf = PostTrainingQuantConfig(calibration_sampling_size=[50, 100],
+ accuracy_criterion = AccuracyCriterion(tolerable_loss=0.01),
+ op_type_dict={'conv2d':{ 'weight':{'dtype':['fp32']}, 'activation':{'dtype':['fp32']} }},
+ backend='itex')
+ from neural_compressor import METRICS
+ metrics = METRICS('tensorflow')
+ top1 = metrics['topk']()
+ from tensorflow.core.protobuf import saved_model_pb2
+ sm = saved_model_pb2.SavedModel()
+ with tf.io.gfile.GFile(args.input_graph, "rb") as f:
+ sm.ParseFromString(f.read())
+ graph_def = sm.meta_graphs[0].graph_def
+ from neural_compressor.data import TensorflowShiftRescale
+ postprocess = TensorflowShiftRescale()
+ def eval(model):
+ return evaluate(model, eval_dataloader, top1, postprocess)
+ q_model = quantization.fit(graph_def, conf=conf, calib_dataloader=calib_dataloader,
+ # eval_dataloader=eval_dataloader, eval_metric=top1)
+ eval_func=eval)
+ q_model.save(args.output_graph)
+
+ if args.benchmark:
+ from neural_compressor.utils.create_obj_from_config import create_dataloader
+ dataloader_args = {
+ 'batch_size': args.batch_size,
+ 'dataset': {"ImageRecord": {'root':args.dataset_location}},
+ 'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224},
+ 'TransposeLastChannel':{}},
+ 'filter': None
+ }
+ dataloader = create_dataloader('tensorflow', dataloader_args)
+ from neural_compressor import METRICS
+ metrics = METRICS('tensorflow')
+ top1 = metrics['topk']()
+
+ if args.int8 or args.input_graph.endswith("-tune.pb"):
+ input_graph = args.input_graph
+ else:
+ from tensorflow.core.protobuf import saved_model_pb2
+ sm = saved_model_pb2.SavedModel()
+ with tf.io.gfile.GFile(args.input_graph, "rb") as f:
+ sm.ParseFromString(f.read())
+ graph_def = sm.meta_graphs[0].graph_def
+ input_graph = graph_def
+
+ from neural_compressor.data import TensorflowShiftRescale
+ postprocess = TensorflowShiftRescale()
+ def eval(model):
+ return evaluate(model, dataloader, top1, postprocess)
+
+ if args.mode == 'performance':
+ from neural_compressor.benchmark import fit
+ from neural_compressor.config import BenchmarkConfig
+ conf = BenchmarkConfig(warmup=10, iteration=100, cores_per_instance=4, num_of_instance=1, backend='itex')
+ fit(input_graph, conf, b_dataloader=dataloader)
+ elif args.mode == 'accuracy':
+ acc_result = eval(input_graph)
+ print("Batch size = %d" % dataloader.batch_size)
+ print("Accuracy: %.5f" % acc_result)
+
+if __name__ == "__main__":
+ evaluate_opt_graph = eval_classifier_optimized_graph()
+ evaluate_opt_graph.run()
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/requirements.txt b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/requirements.txt
new file mode 100644
index 00000000000..c8d21e74265
--- /dev/null
+++ b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/requirements.txt
@@ -0,0 +1,2 @@
+tensorflow==2.11.0
+neural-compressor
\ No newline at end of file
diff --git a/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_benchmark.sh b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_benchmark.sh
new file mode 100644
index 00000000000..2348865d66e
--- /dev/null
+++ b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_benchmark.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+set -x
+
+function main {
+
+ init_params "$@"
+ run_benchmark
+
+}
+
+# init params
+function init_params {
+ batch_size=32
+ iters=100
+
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --mode=*)
+ mode=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ --batch_size=*)
+ batch_size=$(echo $var |cut -f2 -d=)
+ ;;
+ --iters=*)
+ iters=$(echo $var |cut -f2 -d=)
+ ;;
+ --int8=*)
+ int8=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function run_benchmark {
+ if [[ ${int8} == "true" ]]; then
+ extra_cmd=$extra_cmd" --int8"
+ fi
+ python main.py \
+ --input-graph ${input_model} \
+ --mode ${mode} \
+ --dataset_location ${dataset_location} \
+ --batch_size ${batch_size} \
+ --benchmark \
+ --iters ${iters} \
+ ${extra_cmd}
+}
+
+main "$@"
diff --git a/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_tuning.sh b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_tuning.sh
new file mode 100644
index 00000000000..6a9e1b859c9
--- /dev/null
+++ b/examples/tensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_tuning.sh
@@ -0,0 +1,39 @@
+#!/bin/bash
+set -x
+
+function main {
+ init_params "$@"
+ run_tuning
+
+}
+
+# init params
+function init_params {
+
+ for var in "$@"
+ do
+ case $var in
+ --input_model=*)
+ input_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --output_model=*)
+ output_model=$(echo $var |cut -f2 -d=)
+ ;;
+ --dataset_location=*)
+ dataset_location=$(echo $var |cut -f2 -d=)
+ ;;
+ esac
+ done
+
+}
+
+# run_tuning
+function run_tuning {
+ python main.py \
+ --input-graph ${input_model} \
+ --output-graph ${output_model} \
+ --dataset_location ${dataset_location} \
+ --tune
+}
+
+main "$@"
diff --git a/neural_compressor/data/__init__.py b/neural_compressor/data/__init__.py
index 4aaaa57c0e5..68fbe546f49 100644
--- a/neural_compressor/data/__init__.py
+++ b/neural_compressor/data/__init__.py
@@ -28,7 +28,7 @@
from .transforms import LabelShift, BilinearImagenetTransform, TensorflowResizeCropImagenetTransform
from .transforms import TFSquadV1PostTransform, TFSquadV1ModelZooPostTransform
from .transforms import TensorflowResizeWithRatio, ResizeTFTransform, RescaleTFTransform, NormalizeTFTransform
-from .transforms import ParseDecodeCocoTransform
+from .transforms import ParseDecodeCocoTransform, TensorflowShiftRescale
from .filters import FILTERS, Filter, filter_registry, LabelBalanceCOCORecordFilter
@@ -51,6 +51,7 @@
'LabelShift',
"ResizeTFTransform",
"RescaleTFTransform",
+ 'TensorflowShiftRescale',
"NormalizeTFTransform",
"ParseDecodeCocoTransform",
'BilinearImagenetTransform',
diff --git a/neural_compressor/data/transforms/__init__.py b/neural_compressor/data/transforms/__init__.py
index cda17595119..baf1581fc6f 100644
--- a/neural_compressor/data/transforms/__init__.py
+++ b/neural_compressor/data/transforms/__init__.py
@@ -23,6 +23,7 @@
from .coco_transform import ParseDecodeCocoTransform
from .postprocess import Postprocess
from .imagenet_transform import LabelShift, BilinearImagenetTransform, TensorflowResizeCropImagenetTransform
+from .imagenet_transform import TensorflowShiftRescale
from os.path import dirname, basename, isfile, join
import glob
@@ -36,4 +37,5 @@
__all__ = ["TRANSFORMS", "BaseTransform", "ComposeTransform", "transform_registry", "ResizeTFTransform",
"Postprocess", "LabelShift", "BilinearImagenetTransform", "TensorflowResizeCropImagenetTransform",
"RescaleTFTransform", "NormalizeTFTransform", "ParseDecodeCocoTransform",
- "TensorflowResizeWithRatio", "TFSquadV1PostTransform", "TFSquadV1ModelZooPostTransform"]
+ "TensorflowResizeWithRatio", "TFSquadV1PostTransform", "TFSquadV1ModelZooPostTransform",
+ "TensorflowShiftRescale"]
diff --git a/neural_compressor/data/transforms/imagenet_transform.py b/neural_compressor/data/transforms/imagenet_transform.py
index 251e1aa667b..287bcd677dc 100644
--- a/neural_compressor/data/transforms/imagenet_transform.py
+++ b/neural_compressor/data/transforms/imagenet_transform.py
@@ -150,6 +150,31 @@ def __call__(self, sample):
"imagenet decoding will be performed automatically from Neural Compressor v1.4.")
return sample
+@transform_registry(transform_type="TransposeLastChannel", process="preprocess", framework="tensorflow")
+class TensorflowTransposeLastChannel(BaseTransform):
+ """Transpose NHWC to NCHW
+
+ Returns:
+ tuple of processed image and label
+ """
+ def __call__(self, sample):
+ image, label = sample
+ image = tf.transpose(image, perm=[2,0,1])
+ return (image, label)
+
+@transform_registry(transform_type="ShiftRescale", process="postprocess", framework="tensorflow")
+class TensorflowShiftRescale(BaseTransform):
+ """label shift by 1 and rescale
+
+ Returns:
+ tuple of processed image and label
+ """
+ def __call__(self, sample):
+ image, label = sample
+ label -= 1
+ image = (image - 127.5) / 127.5
+ return (image, label)
+
@transform_registry(transform_type="ResizeCropImagenet", \
process="preprocess", framework="tensorflow")
class TensorflowResizeCropImagenetTransform(BaseTransform): # pragma: no cover