Refactor Quantization Aware Training of TF backend (#250)
Signed-off-by: zehao-intel <[email protected]>
1 parent c61be34, commit 1deb7d2
Showing 36 changed files with 2,064 additions and 300 deletions.
114 changes: 88 additions & 26 deletions in examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/README.md

@@ -1,56 +1,118 @@
Step-by-Step
============

This document describes how to apply QAT to TensorFlow Keras models using Intel® Neural Compressor.
This example can run on Intel CPUs and GPUs.

## Prerequisite

### 1. Installation
```shell
# Install Intel® Neural Compressor
pip install neural-compressor
```
### 2. Install requirements
TensorFlow and intel-extension-for-tensorflow must be installed to run this QAT example.
Intel Extension for TensorFlow for Intel CPUs is installed by default.
```shell
pip install -r requirements.txt
```
> Note: See the supported TensorFlow [versions](../../../../../../../README.md).
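
As a quick sanity check of the environment, the following sketch assumes both packages expose `__version__`; it is not part of the example scripts:
```python
# Verify the default CPU setup installed via requirements.txt.
import tensorflow as tf
import intel_extension_for_tensorflow as itex

print('TensorFlow:', tf.__version__)
print('Intel Extension for TensorFlow:', itex.__version__)
```
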
### 3. Benchmarking the model on Intel GPU (Optional)

To benchmark the model on Intel GPUs, Intel Extension for TensorFlow for Intel GPUs is required.

```shell
pip install --upgrade intel-extension-for-tensorflow[gpu]
```
Please refer to the [Installation Guides](https://dgpu-docs.intel.com/installation-guides/ubuntu/ubuntu-focal-dc.html) for the latest Intel GPU driver installation.
For more details, please follow the procedure in [install-gpu-drivers](https://github.com/intel-innersource/frameworks.ai.infrastructure.intel-extension-for-tensorflow.intel-extension-for-tensorflow/blob/master/docs/install/install_for_gpu.md#install-gpu-drivers).

### 4. Prepare Pretrained model

Run the `train.py` script to get a pretrained fp32 model.

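As a rough illustration of this step, the sketch below trains and saves a simple MNIST Keras classifier; the architecture and output path are assumptions for illustration, not the exact contents of `train.py`:
```python
import tensorflow as tf

# Load and normalize MNIST, matching the data used later in the QAT step.
(train_images, train_labels), (test_images, test_labels) = \
    tf.keras.datasets.mnist.load_data()
train_images = train_images / 255.0
test_images = test_images / 255.0

# A small convolutional classifier (hypothetical layer sizes).
model = tf.keras.Sequential([
    tf.keras.layers.InputLayer(input_shape=(28, 28)),
    tf.keras.layers.Reshape(target_shape=(28, 28, 1)),
    tf.keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation='relu'),
    tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(10)
])
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.fit(train_images, train_labels, epochs=1, validation_split=0.1)

# Save the fp32 baseline in SavedModel format (path is an assumption).
model.save('./baseline_model')
```
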
### 5. Prepare QAT model

Run the `qat.py` script to get the QAT model, which is in fact an fp32 model with quant/dequant pairs inserted. The core of this step is shown in the Tune section below.

## Run Command
```shell
bash run_tuning.sh --input_model=./path/to/model --output_model=./result
bash run_benchmark.sh --input_model=./path/to/model --mode=performance --batch_size=32
```

Details of enabling Intel® Neural Compressor to apply QAT
==========================================================

This is a tutorial on how to apply QAT with Intel® Neural Compressor.
## User Code Analysis
1. The user specifies the fp32 *model* on which to apply quantization; the dataset is downloaded automatically. In this step, QDQ patterns are inserted into the Keras model, but the fp32 model is not yet converted to an int8 model.

2. The user specifies the *model* with QDQ patterns inserted and an evaluation function to run the benchmark. The model obtained from the previous step runs on the ITEX backend, where it is fused and then inferred.

### Quantization Config
The QuantizationAwareTrainingConfig class has default parameter settings for running on Intel CPUs. To run this example on Intel GPUs, set the 'backend' parameter to 'itex' and the 'device' parameter to 'gpu'.

```python
config = QuantizationAwareTrainingConfig(
    device="gpu",
    backend="itex",
    ...
)
```
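
For the default CPU path, no arguments are needed; this mirrors the config used in the Tune snippet below:
```python
from neural_compressor import QuantizationAwareTrainingConfig

# Defaults target Intel CPUs with the plain TensorFlow backend.
config = QuantizationAwareTrainingConfig()
```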

### Code update

After the preparation step is done, we add quantization and benchmark code to generate the quantized model and run the benchmark.

#### Tune
```python
logger.info('start quantizing the model...')
from neural_compressor import training, QuantizationAwareTrainingConfig
config = QuantizationAwareTrainingConfig()
# create a compression_manager instance to implement QAT
compression_manager = training.prepare_compression(FLAGS.input_model, config)
# QDQ patterns will be inserted into the input keras model
compression_manager.callbacks.on_train_begin()
# get the model with QDQ patterns inserted
q_aware_model = compression_manager.model.model

# training code defined by users
q_aware_model.compile(optimizer='adam',
                      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
                      metrics=['accuracy'])
q_aware_model.summary()
train_images_subset = train_images[0:1000]
train_labels_subset = train_labels[0:1000]
q_aware_model.fit(train_images_subset, train_labels_subset,
                  batch_size=500, epochs=1, validation_split=0.1)
_, q_aware_model_accuracy = q_aware_model.evaluate(
    test_images, test_labels, verbose=0)
print('Quant test accuracy:', q_aware_model_accuracy)

# apply some post-processing steps and save the output model
compression_manager.callbacks.on_train_end()
compression_manager.save(FLAGS.output_model)
```
#### Benchmark
```python
from neural_compressor.benchmark import fit
from neural_compressor.experimental import common
from neural_compressor.config import BenchmarkConfig
assert FLAGS.mode == 'performance' or FLAGS.mode == 'accuracy', \
    "Benchmark only supports performance or accuracy mode."

# convert the quantized keras model to graph_def so that it can be fused by ITEX
model = common.Model(FLAGS.input_model).graph_def
if FLAGS.mode == 'performance':
    conf = BenchmarkConfig(cores_per_instance=4, num_of_instance=7)
    fit(model, conf, b_func=evaluate)
elif FLAGS.mode == 'accuracy':
    accuracy = evaluate(model)
    print('Batch size = %d' % FLAGS.batch_size)
    print("Accuracy: %.5f" % accuracy)
```
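
The benchmark snippet relies on an `evaluate` function defined elsewhere in the example. A minimal sketch of such a function, assuming MNIST test data and hypothetical input/output tensor names (`input:0`, `Identity:0`) that would need to match the actual exported graph:
```python
import numpy as np
import tensorflow as tf

def evaluate(graph_def, batch_size=32):
    """Run a frozen graph_def over the MNIST test set and return accuracy."""
    (_, _), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
    test_images = (test_images / 255.0).astype(np.float32)

    graph = tf.Graph()
    with graph.as_default():
        tf.compat.v1.import_graph_def(graph_def, name='')
    input_tensor = graph.get_tensor_by_name('input:0')      # assumed name
    output_tensor = graph.get_tensor_by_name('Identity:0')  # assumed name

    correct = 0
    with tf.compat.v1.Session(graph=graph) as sess:
        for start in range(0, len(test_images), batch_size):
            batch = test_images[start:start + batch_size]
            logits = sess.run(output_tensor, {input_tensor: batch})
            preds = np.argmax(logits, axis=1)
            correct += int(np.sum(preds == test_labels[start:start + len(batch)]))
    return correct / len(test_images)
```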
28 changes: 0 additions & 28 deletions in examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/benchmark.py

This file was deleted.
7 changes: 0 additions & 7 deletions in examples/tensorflow/image_recognition/keras_models/mnist/quantization/qat/convert.py

This file was deleted.