Nano : reduce time cost for InferenceOptimizer and update demo #5740

Merged: 13 commits, Sep 15, 2022
55 changes: 33 additions & 22 deletions python/nano/example/pytorch/inference_pipeline/resnet/README.md
@@ -1,7 +1,7 @@
# Bigdl-nano InferenceOptimizer example on Cat vs. Dog dataset

This example illustrates how to apply InferenceOptimizer to quickly find the acceleration method with the minimum inference latency, under specific restrictions or without restrictions, for a trained model.
For the sake of this example, we first train the proposed network (by default, a ResNet18 is used) on the [cats and dogs dataset](https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip), which consists of both [frozen and unfrozen stages](https://github.com/PyTorchLightning/pytorch-lightning/blob/495812878dfe2e31ec2143c071127990afbb082b/pl_examples/domain_templates/computer_vision_fine_tuning.py#L21-L35). Then, by calling `optimize()`, we can obtain all available acceleration combinations provided by BigDL-Nano for inference. By calling `get_best_model()`, we can get an accelerated model whose inference is 7.5x faster.
For the sake of this example, we first train the proposed network (by default, a ResNet18 is used) on the [cats and dogs dataset](https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip), which consists of both [frozen and unfrozen stages](https://github.com/PyTorchLightning/pytorch-lightning/blob/495812878dfe2e31ec2143c071127990afbb082b/pl_examples/domain_templates/computer_vision_fine_tuning.py#L21-L35). Then, by calling `optimize()`, we can obtain all available acceleration combinations provided by BigDL-Nano for inference. By calling `get_best_model()`, we can get an accelerated model whose inference is 6.5x faster.
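
As a quick orientation, the sketch below outlines this workflow. It is a simplified stand-in rather than the actual `inference_pipeline.py` script: the model, dataloaders, and metric are placeholders, and the keyword arguments follow the example code shown later in this PR and may differ slightly across BigDL-Nano versions.

```python
# Simplified sketch of the InferenceOptimizer workflow (illustrative stand-ins,
# not the example's actual ResNet18 / cats-vs-dogs pipeline).
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset
from bigdl.nano.pytorch import InferenceOptimizer  # import path may vary by version

model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 2))  # stand-in model
x, y = torch.rand(64, 3, 32, 32), torch.randint(0, 2, (64,))
loader = DataLoader(TensorDataset(x, y), batch_size=8)

def accuracy(pred, target):
    # metric(pred, target) -> tensor, matching the signature used by the example
    return (pred.argmax(dim=-1) == target).float().mean()

optimizer = InferenceOptimizer()
# Try all available acceleration methods and record their latency (and accuracy).
optimizer.optimize(model=model,
                   training_data=loader,
                   validation_data=loader,
                   metric=accuracy,
                   latency_sample_num=30)

# Fastest variant whose accuracy drop stays within 5% of the original model.
acc_model, option = optimizer.get_best_model(accuracy_criterion=0.05)
print("When accuracy drop less than 5%, the model with minimal latency is:", option)

x_input = next(iter(loader))[0][:1]
output = acc_model(x_input)  # inference with the accelerated model
```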


## Prepare the environment
@@ -25,21 +25,27 @@ pip install --upgrade numpy==1.21.6
Initialize environment variables with the `bigdl-nano-init` script installed with bigdl-nano.
```
source bigdl-nano-init
unset KMP_AFFINITY
```
Contributor

We should not have the user manually manage this. @MeouSker77 @TheaperDeng

Contributor @MeouSker77, Sep 14, 2022

It seems the default value of KMP_AFFINITY in bigdl-nano-init is wrong; currently it causes the program to use only one core during inference, so we unset it here to use more cores. I'll fix the default value to use all cores by default.

Contributor

This could be a really tricky one. The default KMP_AFFINITY value conflicts with onnxruntime's core resource control (which makes it behave strangely), and we cannot unset or reset this environment variable once the user has started their Python script. One solution we have used in training is to create a new process (with a different KMP_AFFINITY) to handle the work, but that is not reasonable here either, since creating a process costs 1-10 ms and is very unfriendly to low-latency requirements.

One possible solution is:

  1. Make KMP_AFFINITY unset by default (that is, do not set this value in bigdl-nano-init).
  2. Only set this value during multi-process training when creating new processes (since this variable does not give much positive effect for inference or single-process training).

What do you think? @jason-dai @MeouSker77

Contributor @MeouSker77, Sep 14, 2022

Currently the default value of KMP_AFFINITY is `granularity=fine,compact,1,0`. In my test, removing the `compact` option, i.e. setting KMP_AFFINITY to `granularity=fine,1,0`, avoids the conflict with onnxruntime, so that is also a possible solution.

Here is the explanation of the `compact` option (screenshot omitted).

Shall we take this solution? @jason-dai @TheaperDeng

By the way, multi-process training sets KMP_AFFINITY for its sub-processes automatically, so it won't be affected by this default value.


I think that should be OK.

Contributor

Looks good; we need to test with our performance suite.
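
As a purely illustrative aside (not part of this PR), the workaround discussed above could be tried from Python roughly like this, assuming it runs before anything initializes the OpenMP runtime:

```python
# Illustrative sketch only: apply the affinity setting discussed above before
# torch/onnxruntime initialize OpenMP. Whether this takes effect depends on the
# OpenMP runtime not having been initialized yet in this process.
import os

# Drop the conflicting "compact" placement, as suggested in the thread...
os.environ["KMP_AFFINITY"] = "granularity=fine,1,0"
# ...or remove the variable entirely, which is what `unset KMP_AFFINITY` does.
# os.environ.pop("KMP_AFFINITY", None)

import torch  # imported only after the environment has been adjusted

print("KMP_AFFINITY =", os.environ.get("KMP_AFFINITY"))
```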

You may find environment variables set as follows:
```
OpenMP library found...
Setting OMP_NUM_THREADS...
Setting OMP_NUM_THREADS specified for pytorch...
Setting KMP_AFFINITY...
Setting KMP_BLOCKTIME...
Setting MALLOC_CONF...
Setting LD_PRELOAD...
nano_vars.sh already exists
+++++ Env Variables +++++
LD_PRELOAD=./../lib/libjemalloc.so
MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1
LD_PRELOAD=/opt/anaconda3/envs/nano/bin/../lib/libiomp5.so /opt/anaconda3/envs/nano/lib/python3.7/site-packages/bigdl/nano//libs/libtcmalloc.so
MALLOC_CONF=
OMP_NUM_THREADS=112
KMP_AFFINITY=granularity=fine,compact,1,0
KMP_BLOCKTIME=1
TF_ENABLE_ONEDNN_OPTS=
TF_ENABLE_ONEDNN_OPTS=1
ENABLE_TF_OPTS=1
NANO_TF_INTER_OP=1
+++++++++++++++++++++++++
Complete.
```
@@ -56,23 +62,28 @@ python inference_pipeline.py
```

## Results

It will take about 2 minutes to run the inference optimization. Then you may find the inference results as follows:
It will take about 1 minute to run the inference optimization. Then you may find the inference results as follows:
```
==========================Optimization Results==========================
acceleration option: original, latency: 54.2669ms, accuracy: 0.9937
acceleration option: fp32_ipex, latency: 40.3075ms, accuracy: 0.9937
acceleration option: bf16_ipex, latency: 115.6182ms, accuracy: 0.9937
acceleration option: int8, latency: 14.4857ms, accuracy: 0.4750
acceleration option: jit_fp32, latency: 39.3361ms, accuracy: 0.9937
acceleration option: jit_fp32_ipex, latency: 39.2949ms, accuracy: 0.9937
acceleration option: jit_fp32_ipex_clast, latency: 24.5715ms, accuracy: 0.9937
acceleration option: openvino_fp32, latency: 14.5771ms, accuracy: 0.9937
acceleration option: openvino_int8, latency: 7.2186ms, accuracy: 0.9937
acceleration option: onnxruntime_fp32, latency: 44.3872ms, accuracy: 0.9937
acceleration option: onnxruntime_int8_qlinear, latency: 10.1866ms, accuracy: 0.9937
acceleration option: onnxruntime_int8_integer, latency: 18.8731ms, accuracy: 0.9875
When accelerator is onnxruntime, the model with minimal latency is: inc + onnxruntime + qlinear
When accuracy drop less than 5%, the model with minimal latency is: openvino + pot
The model with minimal latency is: openvino + pot
```
-------------------------------- ---------------------- -------------- ----------------------
| method | status | latency(ms) | accuracy |
-------------------------------- ---------------------- -------------- ----------------------
| original | successful | 43.688 | 0.969 |
| fp32_ipex | successful | 33.383 | not recomputed |
| bf16 | fail to forward | None | None |
| bf16_ipex | early stopped | 203.897 | None |
| int8 | successful | 10.74 | 0.969 |
| jit_fp32 | successful | 38.732 | not recomputed |
| jit_fp32_ipex | successful | 35.205 | not recomputed |
| jit_fp32_ipex_channels_last | successful | 19.327 | not recomputed |
| openvino_fp32 | successful | 10.215 | not recomputed |
| openvino_int8 | successful | 8.192 | 0.969 |
| onnxruntime_fp32 | successful | 20.931 | not recomputed |
| onnxruntime_int8_qlinear | successful | 8.274 | 0.969 |
| onnxruntime_int8_integer | fail to convert | None | None |
-------------------------------- ---------------------- -------------- ----------------------

Optimization cost 64.3s at all.
===========================Stop Optimization===========================
When accuracy drop less than 5%, the model with minimal latency is: openvino + int8
```
@@ -49,15 +49,9 @@ def accuracy(pred, target):
latency_sample_num=30)

# 4. Get the best model under specific restrictions or without restrictions
acc_model, option = optimizer.get_best_model(accelerator="onnxruntime")
print("When accelerator is onnxruntime, the model with minimal latency is: ", option)

acc_model, option = optimizer.get_best_model(accuracy_criterion=0.05)
print("When accuracy drop less than 5%, the model with minimal latency is: ", option)

acc_model, option = optimizer.get_best_model()
print("The model with minimal latency is: ", option)

# 5. Inference with accelerated model
x_input = next(iter(datamodule.train_dataloader(batch_size=1)))[0]
output = acc_model(x_input)
78 changes: 58 additions & 20 deletions python/nano/src/bigdl/nano/pytorch/inference/optimizer.py
@@ -35,6 +35,12 @@
from bigdl.nano.deps.neural_compressor.inc_api import load_inc_model, quantize as inc_quantize
from bigdl.nano.utils.inference.pytorch.model import AcceleratedLightningModule
from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10
import warnings
# Filter out useless UserWarnings
warnings.filterwarnings('ignore', category=UserWarning, module='pytorch_lightning')
warnings.filterwarnings('ignore', category=DeprecationWarning, module='pytorch_lightning')
warnings.filterwarnings('ignore', category=UserWarning, module='torch')
warnings.filterwarnings('ignore', category=DeprecationWarning, module='torch')

import os
os.environ['LOGLEVEL'] = 'ERROR' # remove partial output of inc
@@ -167,6 +173,17 @@ def optimize(self, model: nn.Module,

model.eval() # change model to eval mode

input_sample = tuple(next(iter(training_data))[:-1])
Contributor

This basically assumes users have only one output and one or more inputs for the model's forward; maybe we can inspect the model and set a more accurate parameter count.

Contributor Author

Good catch, will consider this.
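
A rough sketch of that suggestion (an illustration, not code from this PR): inspect `model.forward` to decide how many leading elements of each batch are inputs, instead of assuming everything except the last element.

```python
# Illustration of the suggestion above: count the required positional
# parameters of model.forward and split the batch accordingly.
import inspect
import torch
from torch import nn

def split_batch_by_forward(model: nn.Module, batch):
    params = inspect.signature(model.forward).parameters.values()
    n_inputs = sum(
        1 for p in params
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
    )
    return tuple(batch[:n_inputs]), tuple(batch[n_inputs:])

class TwoInputModel(nn.Module):
    def forward(self, x1, x2):
        return x1 + x2

batch = (torch.rand(2, 3), torch.rand(2, 3), torch.zeros(2))  # (x1, x2, target)
inputs, targets = split_batch_by_forward(TwoInputModel(), batch)
assert len(inputs) == 2 and len(targets) == 1
```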

st = time.perf_counter()
try:
with torch.no_grad():
model(*input_sample)
except Exception:
invalidInputError(False,
"training_data is incompatible with your model input.")
exit(1)
baseline_time = time.perf_counter() - st

print("==========================Start Optimization==========================")
start_time = time.perf_counter()
for idx, (method, available) in enumerate(available_dict.items()):
@@ -183,7 +200,6 @@
precision: str = option.get_precision()
# if precision is fp32, then we will use trace method
if precision == "fp32":
input_sample = tuple(next(iter(training_data))[:-1])
try:
if accelerator is None and use_ipex is False:
acce_model = model
@@ -238,19 +254,28 @@ def func_test(model, input_sample):

torch.set_num_threads(thread_num)
try:
result_map[method]["latency"] =\
_throughput_calculate_helper(latency_sample_num, func_test,
acce_model, input_sample)
result_map[method]["latency"], status =\
_throughput_calculate_helper(latency_sample_num, baseline_time,
func_test, acce_model, input_sample)
if status is False:
result_map[method]["status"] = "early stopped"
torch.set_num_threads(default_threads)
continue
except Exception as e:
result_map[method]["status"] = "fail to forward"
torch.set_num_threads(default_threads)
continue

torch.set_num_threads(default_threads)
if self._calculate_accuracy:
result_map[method]["accuracy"] =\
_accuracy_calculate_helper(acce_model,
metric, validation_data)
# here we suppose trace don't change accuracy,
# so we jump it to reduce time cost of optimize
if precision == "fp32" and method != "original":
result_map[method]["accuracy"] = "not recomputed"
else:
result_map[method]["accuracy"] =\
_accuracy_calculate_helper(acce_model,
metric, validation_data)
else:
result_map[method]["accuracy"] = None

@@ -329,9 +354,11 @@ def get_best_model(self,
continue

if accuracy_criterion is not None:
accuracy: float = result["accuracy"]
accuracy = result["accuracy"]
compare_acc: float = best_metric.accuracy
if self._direction == "min":
if accuracy == "not recomputed":
pass
elif self._direction == "min":
if (accuracy - compare_acc) / compare_acc > accuracy_criterion:
continue
else:
@@ -341,7 +368,11 @@
# After the above conditions are met, the latency comparison is performed
if result["latency"] < best_metric.latency:
best_model = result["model"]
best_metric = CompareMetric(method, result["latency"], result["accuracy"])
if result["accuracy"] != "not recomputed":
accuracy = result["accuracy"]
else:
accuracy = self.optimized_model_dict["original"]["accuracy"]
best_metric = CompareMetric(method, result["latency"], accuracy)

return best_model, _format_acceleration_option(best_metric.method_name)

@@ -647,7 +678,7 @@ def _available_acceleration_combination():
return available_dict


def _throughput_calculate_helper(iterrun, func, *args):
def _throughput_calculate_helper(iterrun, baseline_time, func, *args):
'''
A simple helper to calculate average latency
'''
@@ -659,6 +690,9 @@ def _throughput_calculate_helper(iterrun, func, *args):
func(*args)
end = time.perf_counter()
time_list.append(end - st)
# if the first three samples cost more than 4x the baseline time on average, prune this method
if i == 2 and end - start_time > 12 * baseline_time:
return np.mean(time_list) * 1000, False
# at least need 10 iters and try to control calculation
# time less than 2 min
if i + 1 >= min(iterrun, 10) and (end - start_time) > 2:
@@ -667,7 +701,7 @@
time_list.sort()
# remove top and least 10% data
time_list = time_list[int(0.1 * iterrun): int(0.9 * iterrun)]
return np.mean(time_list) * 1000
return np.mean(time_list) * 1000, True
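
For a concrete sense of this pruning rule, the numbers from the README above work out as follows (a small illustrative check, not code from this PR):

```python
# Worked example of the early-stopping rule, using the README's numbers
# (~43.7 ms baseline vs ~204 ms per bf16_ipex run); values are illustrative.
baseline_time = 0.043688              # seconds per run of the original model
elapsed_after_three_runs = 3 * 0.204  # seconds spent on the first three runs

# Prune when three runs together exceed 12x the baseline, i.e. the candidate
# is on average more than 4x slower than the original model.
print(elapsed_after_three_runs > 12 * baseline_time)  # True -> "early stopped"
```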


def _accuracy_calculate_helper(model, metric, data):
@@ -676,9 +710,10 @@
'''
metric_list = []
sample_num = 0
for i, (data_input, target) in enumerate(data):
metric_list.append(metric(model(data_input), target).numpy() * data_input.shape[0])
sample_num += data_input.shape[0]
with torch.no_grad():
for i, (data_input, target) in enumerate(data):
metric_list.append(metric(model(data_input), target).numpy() * data_input.shape[0])
sample_num += data_input.shape[0]
return np.sum(metric_list) / sample_num


@@ -690,7 +725,10 @@ def _format_acceleration_option(method_name: str) -> str:
repr_str = ""
for key, value in option.__dict__.items():
if value is True:
repr_str = repr_str + key + " + "
if key == "pot":
repr_str = repr_str + "int8" + " + "
else:
repr_str = repr_str + key + " + "
elif isinstance(value, str):
repr_str = repr_str + value + " + "
if len(repr_str) > 0:
@@ -705,9 +743,9 @@ def _format_optimize_result(optimize_result_dict: dict,
'''
if calculate_accuracy is True:
horizontal_line = " {0} {1} {2} {3}\n" \
.format("-" * 32, "-" * 22, "-" * 14, "-" * 12)
.format("-" * 32, "-" * 22, "-" * 14, "-" * 22)
repr_str = horizontal_line
repr_str += "| {0:^30} | {1:^20} | {2:^12} | {3:^10} |\n" \
repr_str += "| {0:^30} | {1:^20} | {2:^12} | {3:^20} |\n" \
.format("method", "status", "latency(ms)", "accuracy")
repr_str += horizontal_line
for method, result in optimize_result_dict.items():
@@ -716,10 +754,10 @@
if latency != "None":
latency = round(latency, 3)
accuracy = result.get("accuracy", "None")
if accuracy != "None":
if accuracy != "None" and isinstance(accuracy, float):
accuracy = round(accuracy, 3)
method_str = f"| {method:^30} | {status:^20} | " \
f"{latency:^12} | {accuracy:^10} |\n"
f"{latency:^12} | {accuracy:^20} |\n"
repr_str += method_str
repr_str += horizontal_line
else: