Enable ViT #998

Merged
merged 10 commits into from
Jul 4, 2023
7 changes: 7 additions & 0 deletions examples/.config/model_params_tensorflow.json
@@ -1800,6 +1800,13 @@
      "input_model": "/tf_dataset2/models/tensorflow/facebook-opt-125m",
      "main_script": "main.py",
      "batch_size": 16
    },
    "ViT": {
      "model_src_dir": "image_recognition/tensorflow_models/vision_transformer/quantization/ptq",
      "dataset_location": "/tf_dataset/dataset/imagenet",
      "input_model": "/tf_dataset/tensorflow/vit/HF-ViT-Base16-Img224-frozen.pb",
      "main_script": "main.py",
      "batch_size": 32
    }
  }
}
6 changes: 6 additions & 0 deletions examples/README.md
@@ -270,6 +270,12 @@ Intel® Neural Compressor validated examples with multiple compression technique
<td>Post-Training Static Quantization</td>
<td><a href="./tensorflow/nlp/large_language_models/quantization/ptq/smoothquant">pb (smooth quant)</a></td>
</tr>
<tr>
<td>ViT</td>
<td>Image Recognition</td>
<td>Post-Training Static Quantization</td>
<td><a href="./tensorflow/image_recognition/tensorflow_models/vision_transformer/">pb</a></td>
</tr>
</tbody>
</table>

@@ -0,0 +1,62 @@
Step-by-Step
============

This document lists the steps to reproduce the Vision Transformer model tuning results with Neural Compressor.

# Prerequisite

## 1. Environment

### Install Dependency Package

```
pip install -r requirements.txt
```

### Install Intel Extension for TensorFlow

```shell
pip install --upgrade intel-extension-for-tensorflow[cpu]
```

## 2. Prepare Pretrained Model

```
wget https://storage.googleapis.com/intel-optimized-tensorflow/models/2_11_0/HF-ViT-Base16-Img224-frozen.pb
```

## 3. Prepare Dataset

The TensorFlow [models](https://github.com/tensorflow/models) repo provides [scripts and instructions](https://github.com/tensorflow/models/tree/master/research/slim#an-automated-script-for-processing-imagenet-data) to download, process, and convert the ImageNet dataset to the TF records format.
We also provide related scripts in the `examples/tensorflow/image_recognition/tensorflow_models/imagenet_prepare` directory. To download the raw images, you must create an account with image-net.org. Once you have downloaded the raw data and preprocessed the validation data by moving the images into the appropriate sub-directory based on the label (synset) of each image, use the commands below to convert it to the TF records format.

```shell
cd examples/tensorflow/image_recognition/tensorflow_models/
# convert validation subset
bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/val/ --subset=validation
# convert train subset
bash prepare_dataset.sh --output_dir=./vision_transformer/quantization/ptq/data --raw_dir=/PATH/TO/img_raw/train/ --subset=train
```
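
The label-based directory layout described above can be sketched in Python. This is a minimal illustration only, not the script shipped in `imagenet_prepare`; the function name `sort_into_synset_dirs` and the mapping format (image file name to synset ID) are assumptions made for the example.

```python
import shutil
from pathlib import Path


def sort_into_synset_dirs(raw_dir, labels):
    """Move each image in raw_dir into a sub-directory named after its synset.

    `labels` is a hypothetical mapping of image file name -> synset ID.
    Returns the number of files that were moved.
    """
    raw_dir = Path(raw_dir)
    moved = 0
    for file_name, synset in labels.items():
        src = raw_dir / file_name
        if not src.is_file():
            continue  # no matching image for this entry, skip it
        dst_dir = raw_dir / synset
        dst_dir.mkdir(exist_ok=True)
        shutil.move(str(src), str(dst_dir / file_name))
        moved += 1
    return moved
```

After such a step, each synset directory under `--raw_dir` holds only the images carrying that label, which is the layout the conversion commands expect.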

# Run

## 1. Quantization

```
bash run_tuning.sh --input_model=<path to HF-ViT-Base16-Img224-frozen.pb> --output_model=./output --dataset_location=<path to imagenet>
```


## 2. Benchmark

### Benchmark the fp32 model

```
bash run_benchmark.sh --input_model=<path to HF-ViT-Base16-Img224-frozen.pb> --mode=accuracy --dataset_location=<path to imagenet> --batch_size=32
```

### Benchmark the int8 model

```
bash run_benchmark.sh --input_model=./output.pb --mode=accuracy --dataset_location=<path to imagenet> --batch_size=32 --int8=true
```
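
Once both accuracy numbers are available, they can be compared against the tolerance used during tuning (`main.py` passes `AccuracyCriterion(tolerable_loss=0.01)`). The sketch below shows one way to do that check by hand. It assumes the loss is measured relative to the fp32 accuracy; the function name and the sample accuracies are hypothetical, not measured results.

```python
def meets_accuracy_criterion(fp32_acc, int8_acc, tolerable_loss=0.01, relative=True):
    """Return True if the int8 accuracy drop stays within the tolerated loss.

    With relative=True the drop is expressed as a fraction of the fp32
    accuracy (assumed interpretation); otherwise it is an absolute difference.
    """
    if fp32_acc <= 0:
        raise ValueError("fp32_acc must be positive")
    drop = fp32_acc - int8_acc
    loss = drop / fp32_acc if relative else drop
    return loss <= tolerable_loss


if __name__ == "__main__":
    # Hypothetical accuracies, for illustration only.
    print(meets_accuracy_criterion(0.80, 0.795))
    print(meets_accuracy_criterion(0.80, 0.78))
```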
@@ -0,0 +1,197 @@
#
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import time
from argparse import ArgumentParser

import tensorflow as tf
from tensorflow.python.tools.optimize_for_inference_lib import optimize_for_inference
from tensorflow.python.framework import dtypes
from tensorflow.core.protobuf import saved_model_pb2

import numpy as np

INPUTS = 'inputs'
OUTPUTS = 'Identity'

RESNET_IMAGE_SIZE = 224
IMAGENET_VALIDATION_IMAGES = 50000

tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)

arg_parser = ArgumentParser(description='Parse args')
arg_parser.add_argument('-g', "--input-graph",
help='Specify the input graph for the transform tool',
dest='input_graph')
arg_parser.add_argument("--output-graph",
help='Specify tune result model save dir',
dest='output_graph')
arg_parser.add_argument('--benchmark', dest='benchmark', action='store_true', help='run benchmark')
arg_parser.add_argument('--mode', dest='mode', default='performance', help='benchmark mode')
arg_parser.add_argument('--tune', dest='tune', action='store_true', help='use neural_compressor to tune.')
arg_parser.add_argument('--diagnose', dest='diagnose', action='store_true', help='use Neural Insights to diagnose tuning and benchmark.')
arg_parser.add_argument('--dataset_location', dest='dataset_location',
help='location of calibration dataset and evaluate dataset')
arg_parser.add_argument('--batch_size', type=int, default=32, dest='batch_size', help='batch_size of benchmark')
arg_parser.add_argument('--iters', type=int, default=100, dest='iters', help='iterations')
arg_parser.add_argument('--int8', dest='int8', action='store_true', help='whether to use int8 model for benchmark')
args = arg_parser.parse_args()

def evaluate(model, eval_dataloader, metric, postprocess=None):
    """Custom evaluate function to estimate the accuracy of the model.

    Args:
        model (tf.compat.v1.GraphDef or str): The input model graph or a path to it
        eval_dataloader: Dataloader yielding (inputs, labels) batches
        metric: Metric object providing update() and result()
        postprocess (callable, optional): Applied to each (inputs, labels) batch

    Returns:
        accuracy (float): evaluation result; higher is better.
    """
    from neural_compressor.model import Model
    model = Model(model)
    input_tensor = model.input_tensor
    output_tensor = model.output_tensor if len(model.output_tensor)>1 else \
                        model.output_tensor[0]
    # -1 means "run over the whole dataloader"
    iteration = -1
    if args.benchmark and args.mode == 'performance':
        iteration = args.iters

    def eval_func(dataloader):
        latency_list = []
        for idx, (inputs, labels) in enumerate(dataloader):
            # shift the label and rescale the inputs
            if postprocess is not None:
                inputs, labels = postprocess((inputs, labels))
            # dataloader should keep the order and len of inputs same with input_tensor
            inputs = np.array([inputs])
            feed_dict = dict(zip(input_tensor, inputs))

            start = time.time()
            predictions = model.sess.run(output_tensor, feed_dict)
            end = time.time()

            if isinstance(predictions, list):
                if len(model.output_tensor_names) == 1:
                    predictions = predictions[0]
                elif len(model.output_tensor_names) > 1:
                    predictions = predictions[1]
            metric.update(predictions, labels)
            latency_list.append(end-start)
            if idx + 1 == iteration:
                break
        # mean time per batch divided by batch size gives seconds per image
        latency = np.array(latency_list).mean() / args.batch_size
        return latency

    latency = eval_func(eval_dataloader)
    if args.benchmark and args.mode == 'performance':
        print("Batch size = {}".format(args.batch_size))
        print("Latency: {:.3f} ms".format(latency * 1000))
        print("Throughput: {:.3f} images/sec".format(1. / latency))
    acc = metric.result()
    return acc

class eval_classifier_optimized_graph:
    """Evaluate image classifier with optimized TensorFlow graph."""

    def run(self):
        """Run tuning and/or benchmarking with neural_compressor, depending on the command-line flags."""
        from neural_compressor import set_random_seed
        set_random_seed(9527)

        if args.tune:
            from neural_compressor import quantization
            from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion
            from neural_compressor.utils.create_obj_from_config import create_dataloader
            calib_dataloader_args = {
                'batch_size': 10,
                'dataset': {"ImageRecord": {'root':args.dataset_location}},
                'transform': {'ResizeCropImagenet':{'height': 224, 'width': 224},
                              'TransposeLastChannel':{}},
                'filter': None
            }
            calib_dataloader = create_dataloader('tensorflow', calib_dataloader_args)
            eval_dataloader_args = {
                'batch_size': 32,
                'dataset': {"ImageRecord": {'root':args.dataset_location}},
                'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224},
                              'TransposeLastChannel':{}},
                'filter': None
            }
            eval_dataloader = create_dataloader('tensorflow', eval_dataloader_args)

            # keep conv2d in fp32; everything else is left to the tuning strategy
            conf = PostTrainingQuantConfig(calibration_sampling_size=[50, 100],
                                           accuracy_criterion = AccuracyCriterion(tolerable_loss=0.01),
                                           op_type_dict={'conv2d':{ 'weight':{'dtype':['fp32']}, 'activation':{'dtype':['fp32']} }},
                                           backend='itex')
            from neural_compressor import METRICS
            metrics = METRICS('tensorflow')
            top1 = metrics['topk']()
            from tensorflow.core.protobuf import saved_model_pb2
            # the input file is parsed as a SavedModel proto to extract its graph_def
            sm = saved_model_pb2.SavedModel()
            with tf.io.gfile.GFile(args.input_graph, "rb") as f:
                sm.ParseFromString(f.read())
            graph_def = sm.meta_graphs[0].graph_def
            from neural_compressor.data import TensorflowShiftRescale
            postprocess = TensorflowShiftRescale()
            def eval(model):
                return evaluate(model, eval_dataloader, top1, postprocess)
            q_model = quantization.fit(graph_def, conf=conf, calib_dataloader=calib_dataloader,
                        # eval_dataloader=eval_dataloader, eval_metric=top1)
                        eval_func=eval)
            q_model.save(args.output_graph)

        if args.benchmark:
            from neural_compressor.utils.create_obj_from_config import create_dataloader
            dataloader_args = {
                'batch_size': args.batch_size,
                'dataset': {"ImageRecord": {'root':args.dataset_location}},
                'transform': {'ResizeCropImagenet': {'height': 224, 'width': 224},
                              'TransposeLastChannel':{}},
                'filter': None
            }
            dataloader = create_dataloader('tensorflow', dataloader_args)
            from neural_compressor import METRICS
            metrics = METRICS('tensorflow')
            top1 = metrics['topk']()

            if args.int8 or args.input_graph.endswith("-tune.pb"):
                # the quantized model is passed on as a file path
                input_graph = args.input_graph
            else:
                from tensorflow.core.protobuf import saved_model_pb2
                sm = saved_model_pb2.SavedModel()
                with tf.io.gfile.GFile(args.input_graph, "rb") as f:
                    sm.ParseFromString(f.read())
                graph_def = sm.meta_graphs[0].graph_def
                input_graph = graph_def

            from neural_compressor.data import TensorflowShiftRescale
            postprocess = TensorflowShiftRescale()
            def eval(model):
                return evaluate(model, dataloader, top1, postprocess)

            if args.mode == 'performance':
                from neural_compressor.benchmark import fit
                from neural_compressor.config import BenchmarkConfig
                conf = BenchmarkConfig(warmup=10, iteration=100, cores_per_instance=4, num_of_instance=1, backend='itex')
                fit(input_graph, conf, b_dataloader=dataloader)
            elif args.mode == 'accuracy':
                acc_result = eval(input_graph)
                print("Batch size = %d" % dataloader.batch_size)
                print("Accuracy: %.5f" % acc_result)

if __name__ == "__main__":
    evaluate_opt_graph = eval_classifier_optimized_graph()
    evaluate_opt_graph.run()
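
The arithmetic inside the evaluation loop of `main.py` can be illustrated without TensorFlow. The following plain-Python sketch mirrors three steps: the label shift and input rescale, a top-1 accuracy count, and the per-image latency and throughput figures. It assumes that the shift moves 1-based labels to 0-based and that the rescale maps pixel values from [0, 255] to [-1, 1]; the helper names are hypothetical and are not part of Neural Compressor.

```python
def shift_rescale(pixels, label):
    """Assumed preprocessing: rescale pixels to [-1, 1] and shift the label by one."""
    return [(p - 127.5) / 127.5 for p in pixels], label - 1


def top1_accuracy(predictions, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = 0
    for scores, label in zip(predictions, labels):
        best = max(range(len(scores)), key=scores.__getitem__)
        if best == label:
            correct += 1
    return correct / len(labels)


def latency_and_throughput(batch_times, batch_size):
    """Mean seconds per image and images per second, as printed in performance mode."""
    per_image = sum(batch_times) / len(batch_times) / batch_size
    return per_image, 1.0 / per_image
```

Dividing the mean batch time by the batch size is what turns a per-batch measurement into the per-image latency reported by the script; throughput is simply its reciprocal.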
@@ -0,0 +1,2 @@
tensorflow==2.11.0
neural-compressor
@@ -0,0 +1,57 @@
#!/bin/bash
set -x

function main {

  init_params "$@"
  run_benchmark

}

# init params
function init_params {
  batch_size=32
  iters=100

  for var in "$@"
  do
    case $var in
      --input_model=*)
          input_model=$(echo $var |cut -f2 -d=)
      ;;
      --mode=*)
          mode=$(echo $var |cut -f2 -d=)
      ;;
      --dataset_location=*)
          dataset_location=$(echo $var |cut -f2 -d=)
      ;;
      --batch_size=*)
          batch_size=$(echo $var |cut -f2 -d=)
      ;;
      --iters=*)
          iters=$(echo $var |cut -f2 -d=)
      ;;
      --int8=*)
          int8=$(echo $var |cut -f2 -d=)
      ;;
    esac
  done

}

# run_benchmark
function run_benchmark {
    if [[ ${int8} == "true" ]]; then
        extra_cmd=$extra_cmd" --int8"
    fi
    python main.py \
            --input-graph ${input_model} \
            --mode ${mode} \
            --dataset_location ${dataset_location} \
            --batch_size ${batch_size} \
            --benchmark \
            --iters ${iters} \
            ${extra_cmd}
}

main "$@"
@@ -0,0 +1,39 @@
#!/bin/bash
set -x

function main {
  init_params "$@"
  run_tuning

}

# init params
function init_params {

  for var in "$@"
  do
    case $var in
      --input_model=*)
          input_model=$(echo $var |cut -f2 -d=)
      ;;
      --output_model=*)
          output_model=$(echo $var |cut -f2 -d=)
      ;;
      --dataset_location=*)
          dataset_location=$(echo $var |cut -f2 -d=)
      ;;
    esac
  done

}

# run_tuning
function run_tuning {
    python main.py \
            --input-graph ${input_model} \
            --output-graph ${output_model} \
            --dataset_location ${dataset_location} \
            --tune
}

main "$@"
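
Both shell scripts only recognise options written as `--key=value`; a token without `=` matches none of the `case` patterns and is silently ignored. The Python sketch below mimics that parsing style to make the behaviour easy to see. The function name is hypothetical, and one detail differs from the scripts: `str.partition` keeps everything after the first `=`, whereas `cut -f2 -d=` returns only the second `=`-separated field.

```python
def parse_key_value_args(argv, known=("input_model", "output_model", "dataset_location")):
    """Collect --key=value options for the known keys; ignore everything else."""
    params = {}
    for arg in argv:
        if not arg.startswith("--") or "=" not in arg:
            continue  # space-separated forms fall through, as in the case statement
        key, _, value = arg[2:].partition("=")
        if key in known:
            params[key] = value
    return params


if __name__ == "__main__":
    # The space-separated --input_model is dropped; only --dataset_location is kept.
    print(parse_key_value_args(["--input_model", "model.pb", "--dataset_location=/data"]))
```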