From 4f04b9e7348a030532d064cdcdca2a1850b09545 Mon Sep 17 00:00:00 2001 From: Ruonan Wang <105281011+rnwang04@users.noreply.github.com> Date: Thu, 15 Sep 2022 12:09:42 +0800 Subject: [PATCH] Nano : reduce time cost for InferenceOptimizer and update demo (#5740) * update readme and reduce time cost for calculating accuracy * add prune type * filter warnings * add check for input_sample * update based on comment * update readme and add ut * update input_sample * delete redundant lines * add forward_args --- .../inference_pipeline/resnet/README.md | 56 +++++++------ .../resnet/inference_pipeline.py | 6 -- .../bigdl/nano/pytorch/inference/optimizer.py | 79 ++++++++++++++----- .../utils/inference/pytorch/model_utils.py | 18 +++-- .../tests/test_inference_pipeline_ipex.py | 28 +++++++ 5 files changed, 130 insertions(+), 57 deletions(-) diff --git a/python/nano/example/pytorch/inference_pipeline/resnet/README.md b/python/nano/example/pytorch/inference_pipeline/resnet/README.md index f4e2646f89b..537200fad00 100644 --- a/python/nano/example/pytorch/inference_pipeline/resnet/README.md +++ b/python/nano/example/pytorch/inference_pipeline/resnet/README.md @@ -1,7 +1,7 @@ # Bigdl-nano InferenceOptimizer example on Cat vs. Dog dataset This example illustrates how to apply InferenceOptimizer to quickly find acceleration method with the minimum inference latency under specific restrictions or without restrictions for a trained model. -For the sake of this example, we first train the proposed network(by default, a ResNet18 is used) on the [cats and dogs dataset](https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip), which consists both [frozen and unfrozen stages](https://github.com/PyTorchLightning/pytorch-lightning/blob/495812878dfe2e31ec2143c071127990afbb082b/pl_examples/domain_templates/computer_vision_fine_tuning.py#L21-L35). Then, by calling `optimize()`, we can obtain all available accelaration combinations provided by BigDL-Nano for inference. By calling `get_best_mdoel()` , we could get an accelerated model whose inference is 7.5x times faster. +For the sake of this example, we first train the proposed network(by default, a ResNet18 is used) on the [cats and dogs dataset](https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip), which consists both [frozen and unfrozen stages](https://github.com/PyTorchLightning/pytorch-lightning/blob/495812878dfe2e31ec2143c071127990afbb082b/pl_examples/domain_templates/computer_vision_fine_tuning.py#L21-L35). Then, by calling `optimize()`, we can obtain all available accelaration combinations provided by BigDL-Nano for inference. By calling `get_best_mdoel()` , we could get an accelerated model whose inference is 5x times faster. ## Prepare the environment @@ -28,18 +28,23 @@ source bigdl-nano-init ``` You may find environment variables set like follows: ``` +OpenMP library found... Setting OMP_NUM_THREADS... Setting OMP_NUM_THREADS specified for pytorch... Setting KMP_AFFINITY... Setting KMP_BLOCKTIME... Setting MALLOC_CONF... +Setting LD_PRELOAD... +nano_vars.sh already exists +++++ Env Variables +++++ -LD_PRELOAD=./../lib/libjemalloc.so -MALLOC_CONF=oversize_threshold:1,background_thread:true,metadata_thp:auto,dirty_decay_ms:-1,muzzy_decay_ms:-1 +LD_PRELOAD=/opt/anaconda3/envs/nano/bin/../lib/libiomp5.so /opt/anaconda3/envs/nano/lib/python3.7/site-packages/bigdl/nano//libs/libtcmalloc.so +MALLOC_CONF= OMP_NUM_THREADS=112 -KMP_AFFINITY=granularity=fine,compact,1,0 +KMP_AFFINITY=granularity=fine KMP_BLOCKTIME=1 -TF_ENABLE_ONEDNN_OPTS= +TF_ENABLE_ONEDNN_OPTS=1 +ENABLE_TF_OPTS=1 +NANO_TF_INTER_OP=1 +++++++++++++++++++++++++ Complete. ``` @@ -56,23 +61,28 @@ python inference_pipeline.py ``` ## Results - -It will take about 2 minutes to run inference optimization. Then you may find the result for inference as follows: +It will take about 1 minute to run inference optimization. Then you may find the result for inference as follows: ``` ==========================Optimization Results========================== -accleration option: original, latency: 54.2669ms, accuracy: 0.9937 -accleration option: fp32_ipex, latency: 40.3075ms, accuracy: 0.9937 -accleration option: bf16_ipex, latency: 115.6182ms, accuracy: 0.9937 -accleration option: int8, latency: 14.4857ms, accuracy: 0.4750 -accleration option: jit_fp32, latency: 39.3361ms, accuracy: 0.9937 -accleration option: jit_fp32_ipex, latency: 39.2949ms, accuracy: 0.9937 -accleration option: jit_fp32_ipex_clast, latency: 24.5715ms, accuracy: 0.9937 -accleration option: openvino_fp32, latency: 14.5771ms, accuracy: 0.9937 -accleration option: openvino_int8, latency: 7.2186ms, accuracy: 0.9937 -accleration option: onnxruntime_fp32, latency: 44.3872ms, accuracy: 0.9937 -accleration option: onnxruntime_int8_qlinear, latency: 10.1866ms, accuracy: 0.9937 -accleration option: onnxruntime_int8_integer, latency: 18.8731ms, accuracy: 0.9875 -When accelerator is onnxruntime, the model with minimal latency is: inc + onnxruntime + qlinear -When accuracy drop less than 5%, the model with minimal latency is: openvino + pot -The model with minimal latency is: openvino + pot -``` + -------------------------------- ---------------------- -------------- ---------------------- +| method | status | latency(ms) | accuracy | + -------------------------------- ---------------------- -------------- ---------------------- +| original | successful | 43.688 | 0.969 | +| fp32_ipex | successful | 33.383 | not recomputed | +| bf16 | fail to forward | None | None | +| bf16_ipex | early stopped | 203.897 | None | +| int8 | successful | 10.74 | 0.969 | +| jit_fp32 | successful | 38.732 | not recomputed | +| jit_fp32_ipex | successful | 35.205 | not recomputed | +| jit_fp32_ipex_channels_last | successful | 19.327 | not recomputed | +| openvino_fp32 | successful | 10.215 | not recomputed | +| openvino_int8 | successful | 8.192 | 0.969 | +| onnxruntime_fp32 | successful | 20.931 | not recomputed | +| onnxruntime_int8_qlinear | successful | 8.274 | 0.969 | +| onnxruntime_int8_integer | fail to convert | None | None | + -------------------------------- ---------------------- -------------- ---------------------- + +Optimization cost 64.3s at all. +===========================Stop Optimization=========================== +When accuracy drop less than 5%, the model with minimal latency is: openvino + int8 +``` \ No newline at end of file diff --git a/python/nano/example/pytorch/inference_pipeline/resnet/inference_pipeline.py b/python/nano/example/pytorch/inference_pipeline/resnet/inference_pipeline.py index aeeb2d47638..372f68dc534 100644 --- a/python/nano/example/pytorch/inference_pipeline/resnet/inference_pipeline.py +++ b/python/nano/example/pytorch/inference_pipeline/resnet/inference_pipeline.py @@ -49,15 +49,9 @@ def accuracy(pred, target): latency_sample_num=30) # 4. Get the best model under specific restrictions or without restrictions - acc_model, option = optimizer.get_best_model(accelerator="onnxruntime") - print("When accelerator is onnxruntime, the model with minimal latency is: ", option) - acc_model, option = optimizer.get_best_model(accuracy_criterion=0.05) print("When accuracy drop less than 5%, the model with minimal latency is: ", option) - acc_model, option = optimizer.get_best_model() - print("The model with minimal latency is: ", option) - # 5. Inference with accelerated model x_input = next(iter(datamodule.train_dataloader(batch_size=1)))[0] output = acc_model(x_input) diff --git a/python/nano/src/bigdl/nano/pytorch/inference/optimizer.py b/python/nano/src/bigdl/nano/pytorch/inference/optimizer.py index eabcdaa20d4..09a370d59bc 100644 --- a/python/nano/src/bigdl/nano/pytorch/inference/optimizer.py +++ b/python/nano/src/bigdl/nano/pytorch/inference/optimizer.py @@ -34,7 +34,14 @@ load_onnxruntime_model from bigdl.nano.deps.neural_compressor.inc_api import load_inc_model, quantize as inc_quantize from bigdl.nano.utils.inference.pytorch.model import AcceleratedLightningModule +from bigdl.nano.utils.inference.pytorch.model_utils import get_forward_args, get_input_example from bigdl.nano.pytorch.utils import TORCH_VERSION_LESS_1_10 +import warnings +# Filter out useless Userwarnings +warnings.filterwarnings('ignore', category=UserWarning, module='pytorch_lightning') +warnings.filterwarnings('ignore', category=DeprecationWarning, module='pytorch_lightning') +warnings.filterwarnings('ignore', category=UserWarning, module='torch') +warnings.filterwarnings('ignore', category=DeprecationWarning, module='torch') import os os.environ['LOGLEVEL'] = 'ERROR' # remove parital output of inc @@ -167,6 +174,17 @@ def optimize(self, model: nn.Module, model.eval() # change model to eval mode + forward_args = get_forward_args(model) + input_sample = get_input_example(model, training_data, forward_args) + st = time.perf_counter() + try: + with torch.no_grad(): + model(*input_sample) + except Exception: + invalidInputError(False, + "training_data is incompatible with your model input.") + baseline_time = time.perf_counter() - st + print("==========================Start Optimization==========================") start_time = time.perf_counter() for idx, (method, available) in enumerate(available_dict.items()): @@ -183,7 +201,6 @@ def optimize(self, model: nn.Module, precision: str = option.get_precision() # if precision is fp32, then we will use trace method if precision == "fp32": - input_sample = tuple(next(iter(training_data))[:-1]) try: if accelerator is None and use_ipex is False: acce_model = model @@ -238,9 +255,13 @@ def func_test(model, input_sample): torch.set_num_threads(thread_num) try: - result_map[method]["latency"] =\ - _throughput_calculate_helper(latency_sample_num, func_test, - acce_model, input_sample) + result_map[method]["latency"], status =\ + _throughput_calculate_helper(latency_sample_num, baseline_time, + func_test, acce_model, input_sample) + if status is False: + result_map[method]["status"] = "early stopped" + torch.set_num_threads(default_threads) + continue except Exception as e: result_map[method]["status"] = "fail to forward" torch.set_num_threads(default_threads) @@ -248,9 +269,14 @@ def func_test(model, input_sample): torch.set_num_threads(default_threads) if self._calculate_accuracy: - result_map[method]["accuracy"] =\ - _accuracy_calculate_helper(acce_model, - metric, validation_data) + # here we suppose trace don't change accuracy, + # so we jump it to reduce time cost of optimize + if precision == "fp32" and method != "original": + result_map[method]["accuracy"] = "not recomputed" + else: + result_map[method]["accuracy"] =\ + _accuracy_calculate_helper(acce_model, + metric, validation_data) else: result_map[method]["accuracy"] = None @@ -329,9 +355,11 @@ def get_best_model(self, continue if accuracy_criterion is not None: - accuracy: float = result["accuracy"] + accuracy = result["accuracy"] compare_acc: float = best_metric.accuracy - if self._direction == "min": + if accuracy == "not recomputed": + pass + elif self._direction == "min": if (accuracy - compare_acc) / compare_acc > accuracy_criterion: continue else: @@ -341,7 +369,11 @@ def get_best_model(self, # After the above conditions are met, the latency comparison is performed if result["latency"] < best_metric.latency: best_model = result["model"] - best_metric = CompareMetric(method, result["latency"], result["accuracy"]) + if result["accuracy"] != "not recomputed": + accuracy = result["accuracy"] + else: + accuracy = self.optimized_model_dict["original"]["accuracy"] + best_metric = CompareMetric(method, result["latency"], accuracy) return best_model, _format_acceleration_option(best_metric.method_name) @@ -647,7 +679,7 @@ def _available_acceleration_combination(): return available_dict -def _throughput_calculate_helper(iterrun, func, *args): +def _throughput_calculate_helper(iterrun, baseline_time, func, *args): ''' A simple helper to calculate average latency ''' @@ -659,6 +691,9 @@ def _throughput_calculate_helper(iterrun, func, *args): func(*args) end = time.perf_counter() time_list.append(end - st) + # if three samples cost more than 4x time than baseline model, prune it + if i == 2 and end - start_time > 12 * baseline_time: + return np.mean(time_list) * 1000, False # at least need 10 iters and try to control calculation # time less than 2 min if i + 1 >= min(iterrun, 10) and (end - start_time) > 2: @@ -667,7 +702,7 @@ def _throughput_calculate_helper(iterrun, func, *args): time_list.sort() # remove top and least 10% data time_list = time_list[int(0.1 * iterrun): int(0.9 * iterrun)] - return np.mean(time_list) * 1000 + return np.mean(time_list) * 1000, True def _accuracy_calculate_helper(model, metric, data): @@ -676,9 +711,10 @@ def _accuracy_calculate_helper(model, metric, data): ''' metric_list = [] sample_num = 0 - for i, (data_input, target) in enumerate(data): - metric_list.append(metric(model(data_input), target).numpy() * data_input.shape[0]) - sample_num += data_input.shape[0] + with torch.no_grad(): + for i, (data_input, target) in enumerate(data): + metric_list.append(metric(model(data_input), target).numpy() * data_input.shape[0]) + sample_num += data_input.shape[0] return np.sum(metric_list) / sample_num @@ -690,7 +726,10 @@ def _format_acceleration_option(method_name: str) -> str: repr_str = "" for key, value in option.__dict__.items(): if value is True: - repr_str = repr_str + key + " + " + if key == "pot": + repr_str = repr_str + "int8" + " + " + else: + repr_str = repr_str + key + " + " elif isinstance(value, str): repr_str = repr_str + value + " + " if len(repr_str) > 0: @@ -705,9 +744,9 @@ def _format_optimize_result(optimize_result_dict: dict, ''' if calculate_accuracy is True: horizontal_line = " {0} {1} {2} {3}\n" \ - .format("-" * 32, "-" * 22, "-" * 14, "-" * 12) + .format("-" * 32, "-" * 22, "-" * 14, "-" * 22) repr_str = horizontal_line - repr_str += "| {0:^30} | {1:^20} | {2:^12} | {3:^10} |\n" \ + repr_str += "| {0:^30} | {1:^20} | {2:^12} | {3:^20} |\n" \ .format("method", "status", "latency(ms)", "accuracy") repr_str += horizontal_line for method, result in optimize_result_dict.items(): @@ -716,10 +755,10 @@ def _format_optimize_result(optimize_result_dict: dict, if latency != "None": latency = round(latency, 3) accuracy = result.get("accuracy", "None") - if accuracy != "None": + if accuracy != "None" and isinstance(accuracy, float): accuracy = round(accuracy, 3) method_str = f"| {method:^30} | {status:^20} | " \ - f"{latency:^12} | {accuracy:^10} |\n" + f"{latency:^12} | {accuracy:^20} |\n" repr_str += method_str repr_str += horizontal_line else: diff --git a/python/nano/src/bigdl/nano/utils/inference/pytorch/model_utils.py b/python/nano/src/bigdl/nano/utils/inference/pytorch/model_utils.py index 9003a752c62..d2e5ec2e616 100644 --- a/python/nano/src/bigdl/nano/utils/inference/pytorch/model_utils.py +++ b/python/nano/src/bigdl/nano/utils/inference/pytorch/model_utils.py @@ -13,7 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # -from typing import Any +from typing import Any, Sequence from bigdl.nano.pytorch.lightning import LightningModule import inspect from torch.utils.data import DataLoader @@ -29,10 +29,11 @@ def get_forward_args(model): return forward_args -def get_input_example(model, input_sample): +def get_input_example(model, input_sample, forward_args): if isinstance(input_sample, DataLoader): - # TODO: This assumpe the last output is y - input_sample = tuple(next(iter(input_sample))[:-1]) + input_sample = next(iter(input_sample)) + if isinstance(input_sample, Sequence): + input_sample = tuple(list(input_sample)[:len(forward_args)]) elif input_sample is None: if getattr(model, "example_input_array", None) is not None: input_sample = model.example_input_array @@ -43,8 +44,9 @@ def get_input_example(model, input_sample): model.val_dataloader]: try: dataloader = dataloader_fn() - # TODO: This assumpe the last output is y - input_sample = tuple(next(iter(dataloader)))[:-1] + input_sample = next(iter(input_sample)) + if isinstance(input_sample, Sequence): + input_sample = tuple(list(input_sample)[:len(forward_args)]) break except Exception as _e: pass @@ -73,13 +75,13 @@ def export_to_onnx(model, input_sample=None, onnx_path="model.onnx", dynamic_axe :param dynamic_axes: If we set the first dim of each input as a dynamic batch_size :param **kwargs: will be passed to torch.onnx.export function. ''' - input_sample = get_input_example(model, input_sample) + forward_args = get_forward_args(model) + input_sample = get_input_example(model, input_sample, forward_args) invalidInputError(input_sample is not None, 'You should implement at least one of model.test_dataloader, ' 'model.train_dataloader, model.val_dataloader and ' 'model.predict_dataloader, ' 'or set one of input_sample and model.example_input_array') - forward_args = get_forward_args(model) if dynamic_axes: dynamic_axes = {} for arg in forward_args: diff --git a/python/nano/test/pytorch/tests/test_inference_pipeline_ipex.py b/python/nano/test/pytorch/tests/test_inference_pipeline_ipex.py index 1d939aa7541..e2845f49d31 100644 --- a/python/nano/test/pytorch/tests/test_inference_pipeline_ipex.py +++ b/python/nano/test/pytorch/tests/test_inference_pipeline_ipex.py @@ -111,3 +111,31 @@ def test_pipeline_without_metric(self): error_msg = e.value.args[0] assert error_msg == "If you want to specify accuracy_criterion, you need "\ "to set metric and validation_data when call 'optimize'." + + def test_summary(self): + inference_opt = InferenceOptimizer() + with pytest.raises(RuntimeError) as e: + inference_opt.summary() + error_msg = e.value.args[0] + assert error_msg == "There is no optimization result. You should call .optimize() "\ + "before summary()" + inference_opt.optimize(model=self.model, + training_data=self.train_loader, + thread_num=1) + inference_opt.summary() + + def test_wrong_data_loader(self): + fake_transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), + transforms.Resize(64), + ]) + fake_train_loader = create_data_loader(self.data_dir, 32, self.num_workers, + fake_transform, subset=10, shuffle=True) + inference_opt = InferenceOptimizer() + with pytest.raises(RuntimeError) as e: + inference_opt.optimize(model=self.model, + training_data=fake_train_loader, + thread_num=1) + error_msg = e.value.args[0] + assert error_msg == "training_data is incompatible with your model input."