-
Notifications
You must be signed in to change notification settings - Fork 258
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
10 changed files
with
400 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
62 changes: 62 additions & 0 deletions
62
...age_recognition/tensorflow_models/vision_transformer/quantization/ptq/README.md
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
Step-by-Step | ||
============ | ||
|
||
This document list steps of reproducing Vision Transformer model tuning results via Neural Compressor. | ||
|
||
# Prerequisite | ||
|
||
## 1. Environment | ||
|
||
### Install Dependency Package | ||
|
||
``` | ||
pip install -r requirements.txt | ||
``` | ||
|
||
### Install Intel Extension for Tensorflow | ||
|
||
```shell | ||
pip install --upgrade intel-extension-for-tensorflow[cpu] | ||
``` | ||
|
||
## 2. Prepare Pretrained model | ||
|
||
``` | ||
wget https://storage.googleapis.com/intel-optimized-tensorflow/models/2_11_0/HF-ViT-Base16-Img224-frozen.pb | ||
``` | ||
|
||
## 3. Prepare Dataset | ||
|
||
TensorFlow [models](https://github.com/tensorflow/models) repo provides [scripts and instructions](https://github.com/tensorflow/models/tree/master/research/slim#an-automated-script-for-processing-imagenet-data) to download, process and convert the ImageNet dataset to the TF records format. | ||
We also prepared related scripts in ` examples/tensorflow/image_recognition/tensorflow_models/imagenet_prepare` directory. To download the raw images, the user must create an account with image-net.org. If you have downloaded the raw data and preprocessed the validation data by moving the images into the appropriate sub-directory based on the label (synset) of the image. we can use below command ro convert it to tf records format. | ||
|
||
```shell | ||
cd examples/tensorflow/image_recognition/tensorflow_models/ | ||
# convert validation subset | ||
bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/val/ --subset=validation | ||
# convert train subset | ||
bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/train/ --subset=train | ||
``` | ||
|
||
# Run | ||
|
||
## 1. Quantization | ||
|
||
``` | ||
bash run_tuning.sh --input_model <path to HF-ViT-Base16-Img224-frozen.pb> --output_model ./output --dataset_location <path to imagenet> | ||
``` | ||
|
||
|
||
## 2. Benchmark | ||
|
||
### Benchmark the fp32 model | ||
|
||
``` | ||
bash run_benchmark.sh --input_model=<path to HF-ViT-Base16-Img224-frozen.pb> --mode=accuracy --dataset_location=<path to imagenet> --batch_size=32 | ||
``` | ||
|
||
### Benchmark the int8 model | ||
|
||
``` | ||
bash run_benchmark.sh --input_model=./output.pb --mode=accuracy --dataset_location=<path to imagenet> --batch_size=32 --int8=true | ||
``` |
197 changes: 197 additions & 0 deletions
197
...ensorflow/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/main.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,197 @@ | ||
# | ||
# -*- coding: utf-8 -*- | ||
# | ||
# Copyright (c) 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# | ||
|
||
import time | ||
from argparse import ArgumentParser | ||
|
||
import tensorflow as tf | ||
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference | ||
from tensorflow.python.framework import dtypes | ||
from tensorflow.core.protobuf import saved_model_pb2 | ||
|
||
import numpy as np | ||
|
||
INPUTS = 'inputs' | ||
OUTPUTS = 'Identity' | ||
|
||
RESNET_IMAGE_SIZE = 224 | ||
IMAGENET_VALIDATION_IMAGES = 50000 | ||
|
||
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR) | ||
|
||
arg_parser = ArgumentParser(description='Parse args') | ||
arg_parser.add_argument('-g', "--input-graph", | ||
help='Specify the input graph for the transform tool', | ||
dest='input_graph') | ||
arg_parser.add_argument("--output-graph", | ||
help='Specify tune result model save dir', | ||
dest='output_graph') | ||
arg_parser.add_argument('--benchmark', dest='benchmark', action='store_true', help='run benchmark') | ||
arg_parser.add_argument('--mode', dest='mode', default='performance', help='benchmark mode') | ||
arg_parser.add_argument('--tune', dest='tune', action='store_true', help='use neural_compressor to tune.') | ||
arg_parser.add_argument('--diagnose', dest='diagnose', action='store_true', help='use Neural Insights to diagnose tuning and benchmark.') | ||
arg_parser.add_argument('--dataset_location', dest='dataset_location', | ||
help='location of calibration dataset and evaluate dataset') | ||
arg_parser.add_argument('--batch_size', type=int, default=32, dest='batch_size', help='batch_size of benchmark') | ||
arg_parser.add_argument('--iters', type=int, default=100, dest='iters', help='interations') | ||
arg_parser.add_argument('--int8', dest='int8', action='store_true', help='whether to use int8 model for benchmark') | ||
args = arg_parser.parse_args() | ||
|
||
def evaluate(model, eval_dataloader, metric, postprocess=None): | ||
"""Custom evaluate function to estimate the accuracy of the model. | ||
Args: | ||
model (tf.Graph_def): The input model graph | ||
Returns: | ||
accuracy (float): evaluation result, the larger is better. | ||
""" | ||
from neural_compressor.model import Model | ||
model = Model(model) | ||
input_tensor = model.input_tensor | ||
output_tensor = model.output_tensor if len(model.output_tensor)>1 else \ | ||
model.output_tensor[0] | ||
iteration = -1 | ||
if args.benchmark and args.mode == 'performance': | ||
iteration = args.iters | ||
|
||
def eval_func(dataloader): | ||
latency_list = [] | ||
for idx, (inputs, labels) in enumerate(dataloader): | ||
# shift the label and rescale the inputs | ||
inputs, labels = postprocess((inputs, labels)) | ||
# dataloader should keep the order and len of inputs same with input_tensor | ||
inputs = np.array([inputs]) | ||
feed_dict = dict(zip(input_tensor, inputs)) | ||
|
||
start = time.time() | ||
predictions = model.sess.run(output_tensor, feed_dict) | ||
end = time.time() | ||
|
||
if isinstance(predictions, list): | ||
if len(model.output_tensor_names) == 1: | ||
predictions = predictions[0] | ||
elif len(model.output_tensor_names) > 1: | ||
predictions = predictions[1] | ||
metric.update(predictions, labels) | ||
latency_list.append(end-start) | ||
if idx + 1 == iteration: | ||
break | ||
latency = np.array(latency_list).mean() / args.batch_size | ||
return latency | ||
|
||
latency = eval_func(eval_dataloader) | ||
if args.benchmark and args.mode == 'performance': | ||
print("Batch size = {}".format(args.batch_size)) | ||
print("Latency: {:.3f} ms".format(latency * 1000)) | ||
print("Throughput: {:.3f} images/sec".format(1. / latency)) | ||
acc = metric.result() | ||
return acc | ||
|
||
class eval_classifier_optimized_graph: | ||
"""Evaluate image classifier with optimized TensorFlow graph.""" | ||
|
||
def run(self): | ||
"""This is neural_compressor function include tuning, export and benchmark option.""" | ||
from neural_compressor import set_random_seed | ||
set_random_seed(9527) | ||
|
||
if args.tune: | ||
from neural_compressor import quantization | ||
from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion | ||
from neural_compressor.utils.create_obj_from_config import create_dataloader | ||
calib_dataloader_args = { | ||
'batch_size': 10, | ||
'dataset': {"ImageRecord": {'root':args.dataset_location}}, | ||
'transform': {'ResizeCropImagenet':{'height': 224, 'width': 224}, | ||
'TransposeLastChannel':{}}, | ||
'filter': None | ||
} | ||
calib_dataloader = create_dataloader('tensorflow', calib_dataloader_args) | ||
eval_dataloader_args = { | ||
'batch_size': 32, | ||
'dataset': {"ImageRecord": {'root':args.dataset_location}}, | ||
'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224}, | ||
'TransposeLastChannel':{}}, | ||
'filter': None | ||
} | ||
eval_dataloader = create_dataloader('tensorflow', eval_dataloader_args) | ||
|
||
conf = PostTrainingQuantConfig(calibration_sampling_size=[50, 100], | ||
accuracy_criterion = AccuracyCriterion(tolerable_loss=0.01), | ||
op_type_dict={'conv2d':{ 'weight':{'dtype':['fp32']}, 'activation':{'dtype':['fp32']} }}, | ||
backend='itex') | ||
from neural_compressor import METRICS | ||
metrics = METRICS('tensorflow') | ||
top1 = metrics['topk']() | ||
from tensorflow.core.protobuf import saved_model_pb2 | ||
sm = saved_model_pb2.SavedModel() | ||
with tf.io.gfile.GFile(args.input_graph, "rb") as f: | ||
sm.ParseFromString(f.read()) | ||
graph_def = sm.meta_graphs[0].graph_def | ||
from neural_compressor.data import TensorflowShiftRescale | ||
postprocess = TensorflowShiftRescale() | ||
def eval(model): | ||
return evaluate(model, eval_dataloader, top1, postprocess) | ||
q_model = quantization.fit(graph_def, conf=conf, calib_dataloader=calib_dataloader, | ||
# eval_dataloader=eval_dataloader, eval_metric=top1) | ||
eval_func=eval) | ||
q_model.save(args.output_graph) | ||
|
||
if args.benchmark: | ||
from neural_compressor.utils.create_obj_from_config import create_dataloader | ||
dataloader_args = { | ||
'batch_size': args.batch_size, | ||
'dataset': {"ImageRecord": {'root':args.dataset_location}}, | ||
'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224}, | ||
'TransposeLastChannel':{}}, | ||
'filter': None | ||
} | ||
dataloader = create_dataloader('tensorflow', dataloader_args) | ||
from neural_compressor import METRICS | ||
metrics = METRICS('tensorflow') | ||
top1 = metrics['topk']() | ||
|
||
if args.int8 or args.input_graph.endswith("-tune.pb"): | ||
input_graph = args.input_graph | ||
else: | ||
from tensorflow.core.protobuf import saved_model_pb2 | ||
sm = saved_model_pb2.SavedModel() | ||
with tf.io.gfile.GFile(args.input_graph, "rb") as f: | ||
sm.ParseFromString(f.read()) | ||
graph_def = sm.meta_graphs[0].graph_def | ||
input_graph = graph_def | ||
|
||
from neural_compressor.data import TensorflowShiftRescale | ||
postprocess = TensorflowShiftRescale() | ||
def eval(model): | ||
return evaluate(model, dataloader, top1, postprocess) | ||
|
||
if args.mode == 'performance': | ||
from neural_compressor.benchmark import fit | ||
from neural_compressor.config import BenchmarkConfig | ||
conf = BenchmarkConfig(warmup=10, iteration=100, cores_per_instance=4, num_of_instance=1, backend='itex') | ||
fit(input_graph, conf, b_dataloader=dataloader) | ||
elif args.mode == 'accuracy': | ||
acc_result = eval(input_graph) | ||
print("Batch size = %d" % dataloader.batch_size) | ||
print("Accuracy: %.5f" % acc_result) | ||
|
||
if __name__ == "__main__": | ||
evaluate_opt_graph = eval_classifier_optimized_graph() | ||
evaluate_opt_graph.run() |
2 changes: 2 additions & 0 deletions
2
.../image_recognition/tensorflow_models/vision_transformer/quantization/ptq/requirements.txt
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
tensorflow==2.11.0 | ||
neural-compressor |
57 changes: 57 additions & 0 deletions
57
.../image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_benchmark.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
#!/bin/bash | ||
set -x | ||
|
||
function main { | ||
|
||
init_params "$@" | ||
run_benchmark | ||
|
||
} | ||
|
||
# init params | ||
function init_params { | ||
batch_size=32 | ||
iters=100 | ||
|
||
for var in "$@" | ||
do | ||
case $var in | ||
--input_model=*) | ||
input_model=$(echo $var |cut -f2 -d=) | ||
;; | ||
--mode=*) | ||
mode=$(echo $var |cut -f2 -d=) | ||
;; | ||
--dataset_location=*) | ||
dataset_location=$(echo $var |cut -f2 -d=) | ||
;; | ||
--batch_size=*) | ||
batch_size=$(echo $var |cut -f2 -d=) | ||
;; | ||
--iters=*) | ||
iters=$(echo $var |cut -f2 -d=) | ||
;; | ||
--int8=*) | ||
int8=$(echo $var |cut -f2 -d=) | ||
;; | ||
esac | ||
done | ||
|
||
} | ||
|
||
# run_tuning | ||
function run_benchmark { | ||
if [[ ${int8} == "true" ]]; then | ||
extra_cmd=$extra_cmd" --int8" | ||
fi | ||
python main.py \ | ||
--input-graph ${input_model} \ | ||
--mode ${mode} \ | ||
--dataset_location ${dataset_location} \ | ||
--batch_size ${batch_size} \ | ||
--benchmark \ | ||
--iters ${iters} \ | ||
${extra_cmd} | ||
} | ||
|
||
main "$@" |
39 changes: 39 additions & 0 deletions
39
...low/image_recognition/tensorflow_models/vision_transformer/quantization/ptq/run_tuning.sh
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
#!/bin/bash | ||
set -x | ||
|
||
function main { | ||
init_params "$@" | ||
run_tuning | ||
|
||
} | ||
|
||
# init params | ||
function init_params { | ||
|
||
for var in "$@" | ||
do | ||
case $var in | ||
--input_model=*) | ||
input_model=$(echo $var |cut -f2 -d=) | ||
;; | ||
--output_model=*) | ||
output_model=$(echo $var |cut -f2 -d=) | ||
;; | ||
--dataset_location=*) | ||
dataset_location=$(echo $var |cut -f2 -d=) | ||
;; | ||
esac | ||
done | ||
|
||
} | ||
|
||
# run_tuning | ||
function run_tuning { | ||
python main.py \ | ||
--input-graph ${input_model} \ | ||
--output-graph ${output_model} \ | ||
--dataset_location ${dataset_location} \ | ||
--tune | ||
} | ||
|
||
main "$@" |
Oops, something went wrong.