Enable ViT with ITEX backend (#998)
Spycsh authored Jul 4, 2023
1 parent 78181cd commit 94df997
Showing 10 changed files with 400 additions and 2 deletions.
7 changes: 7 additions & 0 deletions examples/.config/model_params_tensorflow.json
@@ -1800,6 +1800,13 @@
"input_model": "/tf_dataset2/models/tensorflow/facebook-opt-125m",
"main_script": "main.py",
"batch_size": 16
},
"ViT": {
"model_src_dir": "image_recognition/tensorflow_models/vision_transformer/quantization/ptq",
"dataset_location": "/tf_dataset/dataset/imagenet",
"input_model": "/tf_dataset/tensorflow/vit/HF-ViT-Base16-Img224-frozen.pb",
"main_script": "main.py",
"batch_size": 32
}
}
}
6 changes: 6 additions & 0 deletions examples/README.md
@@ -270,6 +270,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
<td>Post-Training Static Quantization</td>
<td><a href="./tensorflow/nlp/large_language_models/quantization/ptq/smoothquant">pb (smooth quant)</a></td>
</tr>
<tr>
<td>ViT</td>
<td>Image Recognition</td>
<td>Post-Training Static Quantization</td>
<td><a href="./tensorflow/image_recognition/tensorflow_models/vision_transformer/">pb</a></td>
</tr>
</tbody>
</table>

@@ -0,0 +1,62 @@
Step-by-Step
============

This document lists the steps to reproduce the Vision Transformer model tuning results with Neural Compressor.

# Prerequisite

## 1. Environment

### Install Dependency Package

```
pip install -r requirements.txt
```

### Install Intel Extension for TensorFlow

```shell
pip install --upgrade intel-extension-for-tensorflow[cpu]
```

## 2. Prepare Pretrained Model

```
wget https://storage.googleapis.com/intel-optimized-tensorflow/models/2_11_0/HF-ViT-Base16-Img224-frozen.pb
```

## 3. Prepare Dataset

The TensorFlow [models](https://github.com/tensorflow/models) repo provides [scripts and instructions](https://github.com/tensorflow/models/tree/master/research/slim#an-automated-script-for-processing-imagenet-data) to download, process, and convert the ImageNet dataset to the TF records format.
We also provide related scripts in the `examples/tensorflow/image_recognition/tensorflow_models/imagenet_prepare` directory. To download the raw images, you must create an account with image-net.org. Once you have downloaded the raw data and preprocessed the validation data by moving each image into the sub-directory that matches its label (synset), use the commands below to convert it to the TF records format.

```shell
cd examples/tensorflow/image_recognition/tensorflow_models/
# convert validation subset
bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/val/ --subset=validation
# convert train subset
bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/train/ --subset=train
```

# Run

## 1. Quantization

```
bash run_tuning.sh --input_model=<path to HF-ViT-Base16-Img224-frozen.pb> --output_model=./output --dataset_location=<path to imagenet>
```
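
`main.py` passes `AccuracyCriterion(tolerable_loss=0.01)` to the tuner. The snippet below is a minimal sketch of what such a check amounts to, assuming the criterion is relative (a 1% drop measured against the fp32 accuracy); verify this against the Neural Compressor version you have installed. The helper name `within_tolerance` is hypothetical and is not part of the library.

```python
def within_tolerance(fp32_acc: float, int8_acc: float,
                     tolerable_loss: float = 0.01) -> bool:
    """Return True if the relative accuracy drop is within `tolerable_loss`.

    Assumption: the drop is measured relative to the fp32 baseline,
    i.e. (fp32_acc - int8_acc) / fp32_acc.
    """
    if fp32_acc <= 0:
        raise ValueError("fp32_acc must be positive")
    relative_loss = (fp32_acc - int8_acc) / fp32_acc
    return relative_loss <= tolerable_loss


if __name__ == "__main__":
    # Example values only; they are not measured results for ViT.
    print(within_tolerance(0.80, 0.795))
    print(within_tolerance(0.80, 0.78))
```

Under this assumption, a quantized model that scores 0.795 against a 0.80 baseline would be accepted, while one that scores 0.78 would trigger further tuning.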


## 2. Benchmark

### Benchmark the fp32 model

```
bash run_benchmark.sh --input_model=<path to HF-ViT-Base16-Img224-frozen.pb> --mode=accuracy --dataset_location=<path to imagenet> --batch_size=32
```
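
In accuracy mode, `main.py` feeds the predictions and labels of every batch into Neural Compressor's `topk` metric, which the script names `top1`. As an illustration only, and not the library's implementation, top-1 accuracy can be sketched in plain Python as follows. The function name `top1_accuracy` is hypothetical, and the sketch assumes the labels are already in the same index space as the model outputs (`main.py` shifts the labels in a postprocess step before updating the metric).

```python
from typing import Sequence


def top1_accuracy(scores: Sequence[Sequence[float]],
                  labels: Sequence[int]) -> float:
    """Fraction of samples whose highest-scoring class equals the label."""
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    if not labels:
        raise ValueError("at least one sample is required")
    correct = 0
    for row, label in zip(scores, labels):
        # Index of the largest score in this row (argmax).
        predicted = max(range(len(row)), key=row.__getitem__)
        if predicted == label:
            correct += 1
    return correct / len(labels)


if __name__ == "__main__":
    # Toy scores for three samples and two classes.
    print(top1_accuracy([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]], [1, 0, 0]))
```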

### Benchmark the int8 model

```
bash run_benchmark.sh --input_model=./output.pb --mode=accuracy --dataset_location=<path to imagenet> --batch_size=32 --int8=true
```
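
The `evaluate` helper in `main.py` also computes a per-image latency: it averages the wall-clock time of each batch, divides by the batch size, and reports throughput as the reciprocal. The sketch below restates that arithmetic with the standard library only; the function name `summarize_latency` is hypothetical, and the timings are made-up example values rather than measurements.

```python
from statistics import mean
from typing import Sequence, Tuple


def summarize_latency(batch_seconds: Sequence[float],
                      batch_size: int) -> Tuple[float, float]:
    """Return (latency in ms per image, throughput in images/sec)."""
    if not batch_seconds:
        raise ValueError("at least one batch timing is required")
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    # Mean time per batch, divided by images per batch.
    per_image_seconds = mean(batch_seconds) / batch_size
    return per_image_seconds * 1000.0, 1.0 / per_image_seconds


if __name__ == "__main__":
    latency_ms, throughput = summarize_latency([0.32, 0.32], batch_size=32)
    print("Latency: {:.3f} ms".format(latency_ms))
    print("Throughput: {:.3f} images/sec".format(throughput))
```

Note that this divides by the nominal batch size, as the script does, so a smaller final batch slightly skews the average.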
@@ -0,0 +1,197 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
from argparse import ArgumentParser

import tensorflow as tf
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.python.framework import dtypes
from tensorflow.core.protobuf import saved_model_pb2

import numpy as np

INPUTS = 'inputs'
OUTPUTS = 'Identity'

RESNET_IMAGE_SIZE = 224
IMAGENET_VALIDATION_IMAGES = 50000

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

arg_parser = ArgumentParser(description='Parse args')
arg_parser.add_argument('-g', "--input-graph",
                        help='Specify the input graph for the transform tool',
                        dest='input_graph')
arg_parser.add_argument("--output-graph",
                        help='Specify tune result model save dir',
                        dest='output_graph')
arg_parser.add_argument('--benchmark', dest='benchmark', action='store_true', help='run benchmark')
arg_parser.add_argument('--mode', dest='mode', default='performance', help='benchmark mode')
arg_parser.add_argument('--tune', dest='tune', action='store_true', help='use neural_compressor to tune.')
arg_parser.add_argument('--diagnose', dest='diagnose', action='store_true', help='use Neural Insights to diagnose tuning and benchmark.')
arg_parser.add_argument('--dataset_location', dest='dataset_location',
                        help='location of calibration dataset and evaluate dataset')
arg_parser.add_argument('--batch_size', type=int, default=32, dest='batch_size', help='batch_size of benchmark')
arg_parser.add_argument('--iters', type=int, default=100, dest='iters', help='iterations')
arg_parser.add_argument('--int8', dest='int8', action='store_true', help='whether to use int8 model for benchmark')
args = arg_parser.parse_args()

def evaluate(model, eval_dataloader, metric, postprocess=None):
    """Custom evaluate function to estimate the accuracy of the model.
    Args:
        model (tf.Graph_def): The input model graph
    Returns:
        accuracy (float): evaluation result, the larger is better.
    """
    from neural_compressor.model import Model
    model = Model(model)
    input_tensor = model.input_tensor
    output_tensor = model.output_tensor if len(model.output_tensor)>1 else \
                        model.output_tensor[0]
    iteration = -1
    if args.benchmark and args.mode == 'performance':
        iteration = args.iters

    def eval_func(dataloader):
        latency_list = []
        for idx, (inputs, labels) in enumerate(dataloader):
            # shift the label and rescale the inputs
            inputs, labels = postprocess((inputs, labels))
            # dataloader should keep the order and len of inputs same with input_tensor
            inputs = np.array([inputs])
            feed_dict = dict(zip(input_tensor, inputs))

            start = time.time()
            predictions = model.sess.run(output_tensor, feed_dict)
            end = time.time()

            if isinstance(predictions, list):
                if len(model.output_tensor_names) == 1:
                    predictions = predictions[0]
                elif len(model.output_tensor_names) > 1:
                    predictions = predictions[1]
            metric.update(predictions, labels)
            latency_list.append(end-start)
            if idx + 1 == iteration:
                break
        latency = np.array(latency_list).mean() / args.batch_size
        return latency

    latency = eval_func(eval_dataloader)
    if args.benchmark and args.mode == 'performance':
        print("Batch size = {}".format(args.batch_size))
        print("Latency: {:.3f} ms".format(latency * 1000))
        print("Throughput: {:.3f} images/sec".format(1. / latency))
    acc = metric.result()
    return acc

class eval_classifier_optimized_graph:
    """Evaluate image classifier with optimized TensorFlow graph."""

    def run(self):
        """Run Neural Compressor tuning, export, and benchmarking as selected by the command-line options."""
        from neural_compressor import set_random_seed
        set_random_seed(9527)

        if args.tune:
            from neural_compressor import quantization
            from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion
            from neural_compressor.utils.create_obj_from_config import create_dataloader
            calib_dataloader_args = {
                'batch_size': 10,
                'dataset': {"ImageRecord": {'root':args.dataset_location}},
                'transform': {'ResizeCropImagenet':{'height': 224, 'width': 224},
                              'TransposeLastChannel':{}},
                'filter': None
            }
            calib_dataloader = create_dataloader('tensorflow', calib_dataloader_args)
            eval_dataloader_args = {
                'batch_size': 32,
                'dataset': {"ImageRecord": {'root':args.dataset_location}},
                'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224},
                              'TransposeLastChannel':{}},
                'filter': None
            }
            eval_dataloader = create_dataloader('tensorflow', eval_dataloader_args)

            conf = PostTrainingQuantConfig(calibration_sampling_size=[50, 100],
                                           accuracy_criterion = AccuracyCriterion(tolerable_loss=0.01),
                                           op_type_dict={'conv2d':{ 'weight':{'dtype':['fp32']}, 'activation':{'dtype':['fp32']} }},
                                           backend='itex')
            from neural_compressor import METRICS
            metrics = METRICS('tensorflow')
            top1 = metrics['topk']()
            from tensorflow.core.protobuf import saved_model_pb2
            sm = saved_model_pb2.SavedModel()
            with tf.io.gfile.GFile(args.input_graph, "rb") as f:
                sm.ParseFromString(f.read())
            graph_def = sm.meta_graphs[0].graph_def
            from neural_compressor.data import TensorflowShiftRescale
            postprocess = TensorflowShiftRescale()
            def eval(model):
                return evaluate(model, eval_dataloader, top1, postprocess)
            q_model = quantization.fit(graph_def, conf=conf, calib_dataloader=calib_dataloader,
                        # eval_dataloader=eval_dataloader, eval_metric=top1)
                        eval_func=eval)
            q_model.save(args.output_graph)

        if args.benchmark:
            from neural_compressor.utils.create_obj_from_config import create_dataloader
            dataloader_args = {
                'batch_size': args.batch_size,
                'dataset': {"ImageRecord": {'root':args.dataset_location}},
                'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224},
                              'TransposeLastChannel':{}},
                'filter': None
            }
            dataloader = create_dataloader('tensorflow', dataloader_args)
            from neural_compressor import METRICS
            metrics = METRICS('tensorflow')
            top1 = metrics['topk']()

            if args.int8 or args.input_graph.endswith("-tune.pb"):
                input_graph = args.input_graph
            else:
                from tensorflow.core.protobuf import saved_model_pb2
                sm = saved_model_pb2.SavedModel()
                with tf.io.gfile.GFile(args.input_graph, "rb") as f:
                    sm.ParseFromString(f.read())
                graph_def = sm.meta_graphs[0].graph_def
                input_graph = graph_def

            from neural_compressor.data import TensorflowShiftRescale
            postprocess = TensorflowShiftRescale()
            def eval(model):
                return evaluate(model, dataloader, top1, postprocess)

            if args.mode == 'performance':
                from neural_compressor.benchmark import fit
                from neural_compressor.config import BenchmarkConfig
                conf = BenchmarkConfig(warmup=10, iteration=100, cores_per_instance=4, num_of_instance=1, backend='itex')
                fit(input_graph, conf, b_dataloader=dataloader)
            elif args.mode == 'accuracy':
                acc_result = eval(input_graph)
                print("Batch size = %d" % dataloader.batch_size)
                print("Accuracy: %.5f" % acc_result)

if __name__ == "__main__":
    evaluate_opt_graph = eval_classifier_optimized_graph()
    evaluate_opt_graph.run()
@@ -0,0 +1,2 @@
tensorflow==2.11.0
neural-compressor
@@ -0,0 +1,57 @@
#!/bin/bash
set -x

function main {

  init_params "$@"
  run_benchmark

}

# init params
function init_params {
  batch_size=32
  iters=100

  for var in "$@"
  do
    case $var in
      --input_model=*)
          input_model=$(echo $var |cut -f2 -d=)
      ;;
      --mode=*)
          mode=$(echo $var |cut -f2 -d=)
      ;;
      --dataset_location=*)
          dataset_location=$(echo $var |cut -f2 -d=)
      ;;
      --batch_size=*)
          batch_size=$(echo $var |cut -f2 -d=)
      ;;
      --iters=*)
          iters=$(echo $var |cut -f2 -d=)
      ;;
      --int8=*)
          int8=$(echo $var |cut -f2 -d=)
      ;;
    esac
  done

}

# run_benchmark
function run_benchmark {
    if [[ ${int8} == "true" ]]; then
        extra_cmd=$extra_cmd" --int8"
    fi
    python main.py \
            --input-graph ${input_model} \
            --mode ${mode} \
            --dataset_location ${dataset_location} \
            --batch_size ${batch_size} \
            --benchmark \
            --iters ${iters} \
            ${extra_cmd}
}

main "$@"
@@ -0,0 +1,39 @@
#!/bin/bash
set -x

function main {
  init_params "$@"
  run_tuning

}

# init params
function init_params {

  for var in "$@"
  do
    case $var in
      --input_model=*)
          input_model=$(echo $var |cut -f2 -d=)
      ;;
      --output_model=*)
          output_model=$(echo $var |cut -f2 -d=)
      ;;
      --dataset_location=*)
          dataset_location=$(echo $var |cut -f2 -d=)
      ;;
    esac
  done

}

# run_tuning
function run_tuning {
    python main.py \
            --input-graph ${input_model} \
            --output-graph ${output_model} \
            --dataset_location ${dataset_location} \
            --tune
}

main "$@"