diff --git a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh
index 80212d56bf7..307b32440df 100644
--- a/.azure-pipelines/scripts/codeScan/pylint/pylint.sh
+++ b/.azure-pipelines/scripts/codeScan/pylint/pylint.sh
@@ -35,7 +35,11 @@ pip install torch==1.12.0 \
     accelerate \
     flask==2.1.3 \
     xgboost \
-    datasets
+    datasets \
+    prettytable \
+    psutil \
+    py-cpuinfo \
+    pyyaml
 
 if [ "${scan_module}" = "neural_solution" ]; then
     cd /neural-compressor
diff --git a/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh b/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh
index 495a0fa31f8..1743af5cdbd 100644
--- a/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh
+++ b/.azure-pipelines/scripts/ut/3x/run_3x_tf.sh
@@ -14,6 +14,9 @@ inc_path=$(python -c 'import neural_compressor; print(neural_compressor.__path__
 cd /neural-compressor/test || exit 1
 find ./3x/tensorflow/* -name "test*.py" | sed 's,\.\/,coverage run --source='"${inc_path}"' --append ,g' | sed 's/$/ --verbose/'> run.sh
 find ./3x/common/* -name "test*.py" | sed 's,\.\/,coverage run --source='"${inc_path}"' --append ,g' | sed 's/$/ --verbose/'>> run.sh
+sed -i '/tensorflow\/keras\//d' run.sh
+
+find ./3x/tensorflow/keras/* -name "test*.py" | sed 's,\.\/,coverage run --source='"${inc_path}"' --append ,g' | sed 's/$/ --verbose/'> run_keras.sh
 
 LOG_DIR=/neural-compressor/log_dir
 mkdir -p ${LOG_DIR}
@@ -22,8 +25,13 @@ ut_log_name=${LOG_DIR}/ut_3x_tf.log
 echo "cat run.sh..."
 sort run.sh -o run.sh
 cat run.sh | tee ${ut_log_name}
+echo "cat run_keras.sh..."
+sort run_keras.sh -o run_keras.sh
+cat run_keras.sh | tee -a ${ut_log_name}
 echo "------UT start-------"
 bash -x run.sh 2>&1 | tee -a ${ut_log_name}
+pip install intel-extension-for-tensorflow[cpu]
+bash -x run_keras.sh 2>&1 | tee -a ${ut_log_name}
 cp .coverage ${LOG_DIR}/.coverage
 echo "------UT end -------"
diff --git a/neural_compressor/tensorflow/__init__.py b/neural_compressor/tensorflow/__init__.py
index 11705f145ed..e1e987f0a0d 100644
--- a/neural_compressor/tensorflow/__init__.py
+++ b/neural_compressor/tensorflow/__init__.py
@@ -12,6 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from neural_compressor.tensorflow.utils import register_algo
-from neural_compressor.tensorflow.algorithms import static_quantize_entry
-from neural_compressor.tensorflow.quantization import quantize_model, StaticQuantConfig, get_default_static_quant_config
+from neural_compressor.tensorflow.utils import register_algo, Model
+from neural_compressor.tensorflow.quantization import (
+    quantize_model,
+    StaticQuantConfig,
+    SmoothQuantConfig,
+    get_default_sq_config,
+    get_default_static_quant_config,
+)
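After this change the TF 3x package root exports both quantization configs, the default-config helpers, and the model wrapper; everything below is importable directly (illustrative listing of the new surface):

```python
from neural_compressor.tensorflow import (
    Model,
    SmoothQuantConfig,
    StaticQuantConfig,
    get_default_sq_config,
    get_default_static_quant_config,
    quantize_model,
    register_algo,
)
```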
diff --git a/neural_compressor/tensorflow/algorithms/__init__.py b/neural_compressor/tensorflow/algorithms/__init__.py
index bd274c80c18..c48b6ae5332 100644
--- a/neural_compressor/tensorflow/algorithms/__init__.py
+++ b/neural_compressor/tensorflow/algorithms/__init__.py
@@ -13,4 +13,5 @@
 # limitations under the License.
 
-from neural_compressor.tensorflow.algorithms.static_quantize import static_quantize_entry
+from neural_compressor.tensorflow.algorithms.smoother import SmoothQuant
+from neural_compressor.tensorflow.algorithms.static_quant import KerasAdaptor
diff --git a/neural_compressor/tensorflow/algorithms/smoother/__init__.py b/neural_compressor/tensorflow/algorithms/smoother/__init__.py
new file mode 100644
index 00000000000..fd91d612289
--- /dev/null
+++ b/neural_compressor/tensorflow/algorithms/smoother/__init__.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from neural_compressor.tensorflow.algorithms.smoother.core import SmoothQuant
+from neural_compressor.tensorflow.algorithms.smoother.scaler import (
+    SmoothQuantScaler,
+    SmoothQuantScalerLLM,
+)
+from neural_compressor.tensorflow.algorithms.smoother.calibration import (
+    SmoothQuantCalibration,
+    SmoothQuantCalibrationLLM,
+)
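The new `smoother` package splits smooth quant into a calibration pass and a scaling pass; `SmoothQuant` in `core.py` wires them together. A rough sketch of that flow, using only the signatures defined in this patch (the `model` and `dataloader` objects are placeholders):

```python
from neural_compressor.tensorflow.algorithms.smoother import (
    SmoothQuantCalibration,
    SmoothQuantScaler,
)

# 1) Sample activations and reduce them to per-channel maxima.
calibration = SmoothQuantCalibration(model, dataloader, 5, ["MatMul", "Conv2D"], 99.999)
max_vals_per_channel, sq_weight_node_names = calibration()

# 2) Scale weights up and activations down around each op.
scaler = SmoothQuantScaler(model, dataloader, 0.5, True)
```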
+"""Tensorflow model calibration process for Smooth Quantization.""" + +import copy +import logging +import os +import tempfile +import time +from collections import OrderedDict, UserDict + +import numpy as np +import tensorflow as tf +from tensorflow.core.framework import attr_value_pb2, graph_pb2 +from tensorflow.python.framework import dtypes, tensor_util +from tensorflow.python.saved_model import load, tag_constants + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.quantization.utils.quantize_graph_common import QuantizeGraphHelper +from neural_compressor.tensorflow.quantization.utils.utility import ( + iterator_sess_run, + parse_saved_model, + reconstruct_saved_model, +) +from neural_compressor.tensorflow.utils import CaptureOutputToFile, TensorflowLLMModel + +logger = logging.getLogger("neural_compressor") +debug = bool(logger.level == logging.DEBUG) + + +class SmoothQuantCalibration: + """A class for performing smooth quantization calibration on a Tensorflow model. + + Args: + model (Model): The Tensorflow wrapper model to be calibrated. + dataloader (DataLoader): The data loader for the calibration dataset. + iterations (int): The number of iterations to run the calibration process. + op_types (List[str]): The types of operations to be quantized. + percentile (float): The percentile of calibration to remove outliers. + """ + + def __init__(self, model, dataloader, iterations, op_types, percentile): + """Initializes a SmoothQuantCalibration object.""" + self.model = model + self.dataloader = dataloader + self.iterations = iterations + # self.iterations = 3 + self.op_types = op_types + self.percentile = percentile + self._sq_input_node_names = [] + self._sq_output_tensor_dict = {} + self._sq_weight_node_names = {} # mapping from its weight node name to the concrete output node name + + def _inference_for_calibration(self, model): + """Run the calibration on the input graph. + + Args: + model(TensorflowBaseModel): input TensorflowBaseModel + """ + # ITEX optimization has broken INC calibration process. + # INC needs turn off ITEX optimization pass in calibration stage. + # TODO ITEX will provide API to replace setting environment variable. 
+ os.environ["ITEX_REMAPPER"] = "0" + sess = model.sess + iter_op = model.iter_op + input_tensor = model.input_tensor + output_tensor = [item + ":0" for item in self._sq_input_node_names] + # TF table initialization: https://github.com/tensorflow/tensorflow/issues/8665 + node_names = [node.name for node in sess.graph.as_graph_def().node] + if "init_all_tables" in node_names: # pragma: no cover + init_table_op = sess.graph.get_operation_by_name("init_all_tables") + sess.run(init_table_op) + + logger.info("Start sampling on calibration dataset for Smooth Quantization.") + if hasattr(self.dataloader, "__len__") and len(self.dataloader) == 0: # pragma: no cover + feed_dict = {} + for output_idx, output in enumerate( + sess.run(output_tensor, feed_dict) + if iter_op == [] + else iterator_sess_run(sess, iter_op, feed_dict, output_tensor, self.iterations) + ): + self._sq_output_tensor_dict.setdefault(self._sq_input_node_names[output_idx], []).append(output) + for idx, (inputs, labels) in enumerate(self.dataloader): + if len(input_tensor) == 1: + feed_dict = {} + if ( + isinstance(inputs, dict) or isinstance(inputs, OrderedDict) or isinstance(inputs, UserDict) + ): # pragma: no cover + for name in inputs: + for tensor in input_tensor: + pos = tensor.name.rfind(":") + t_name = tensor.name if pos < 0 else tensor.name[:pos] + if name == t_name: + feed_dict[tensor] = inputs[name] + break + else: + feed_dict = {input_tensor[0]: inputs} # get raw tensor using index [0] + else: # pragma: no cover + assert len(input_tensor) == len(inputs), "inputs len must equal with input_tensor" + feed_dict = {} + if isinstance(inputs, dict) or isinstance(inputs, OrderedDict) or isinstance(inputs, UserDict): + for name in inputs: + for tensor in input_tensor: + pos = tensor.name.rfind(":") + t_name = tensor.name if pos < 0 else tensor.name[:pos] + if name in [tensor.name, t_name]: + feed_dict[tensor] = inputs[name] + break + else: + # sometimes the input_tensor is not the same order with inputs + # we should check and pair them + def check_shape(tensor, data): + # scalar or 1 dim default True + if ( + tensor.shape is None + or tensor.shape.dims is None + or len(tensor.shape.dims) == 1 + or not hasattr(data, "shape") + ): + return True + tensor_shape = tuple(tensor.shape) + data_shape = tuple(data.shape) + for tensor_dim, data_dim in zip(tensor_shape, data_shape): + if tensor_dim is not None and tensor_dim != data_dim: + return False + return True + + disorder_tensors = [] + disorder_inputs = [] + for idx, sort_tensor in enumerate(input_tensor): + sort_input = inputs[idx] + if check_shape(sort_tensor, sort_input): + feed_dict.update({sort_tensor: sort_input}) + else: + disorder_tensors.append(sort_tensor) + disorder_inputs.append(sort_input) + for i, dis_tensor in enumerate(disorder_tensors): + for j, dis_input in enumerate(disorder_inputs): + if check_shape(dis_tensor, dis_input): + feed_dict.update({dis_tensor: dis_input}) + break + for output_idx, output in enumerate( + sess.run(output_tensor, feed_dict) + if iter_op == [] + else iterator_sess_run(sess, iter_op, feed_dict, output_tensor, self.iterations) + ): + self._sq_output_tensor_dict.setdefault(self._sq_input_node_names[output_idx], []).append(output) + if idx + 1 == self.iterations: + break + os.environ["ITEX_REMAPPER"] = "1" + + def _generate_calibration_data(self): + """Generate the calibration data.""" + sorted_graph = QuantizeGraphHelper().get_sorted_graph( + self.model.graph_def, self.model.input_node_names, self.model.output_node_names + ) + + for node in 
+            if node.op not in self.op_types:
+                continue
+            # Fix the "retval already been set" issue
+            if "while" in node.input[0]:  # pragma: no cover
+                continue
+            self._sq_input_node_names.append(node.input[0])
+            self._sq_weight_node_names[node.input[1]] = node.name
+
+        self._inference_for_calibration(self.model)
+
+    def _get_maxval_per_channel(self, tensor_data, percentile):
+        """Get the max values per input channel.
+
+        Args:
+            tensor_data: The input tensors
+            percentile: The percentile of calibration to remove outliers
+
+        Returns:
+            The max values per input channel
+        """
+        permute_datas = []
+        for data in tensor_data:  # iteration_num * (N, H, W, C)
+            if len(data.shape) == 3:  # pragma: no cover
+                # TODO: matmul batchsize*seq*inchannel
+                tensor = np.abs(np.reshape(data, (-1, data.shape[-1])))
+                permute_datas.append(tensor)
+            elif len(data.shape) == 4:  # already NHWC
+                tensor = np.abs(np.reshape(data, (-1, data.shape[-1])))
+                permute_datas.append(tensor)
+            elif len(data.shape) == 2:  # (?, ic)
+                permute_datas.append(np.abs(data))
+            else:  # pragma: no cover
+                assert False, "not supported"
+        permute_datas = np.concatenate(permute_datas, axis=0)
+        permute_datas = permute_datas.reshape(-1, permute_datas.shape[-1])
+        max_per_channels = np.percentile(permute_datas, percentile, axis=0)
+        max_per_channels = max_per_channels.astype(np.single)
+        return max_per_channels
+
+    def __call__(self):
+        """Generate calibration data and calculate the maximum values per channel.
+
+        Returns:
+            max_vals_per_channel (dict): A dictionary containing the maximum values per channel.
+            sq_weight_node_names (dict): A dictionary mapping from weight names to target node names.
+        """
+        self._generate_calibration_data()
+        max_vals_per_channel = {}
+        for key in self._sq_output_tensor_dict.keys():
+            max_val_per_channel = self._get_maxval_per_channel(
+                self._sq_output_tensor_dict[key], percentile=self.percentile
+            )
+            max_vals_per_channel[key] = max_val_per_channel
+        return max_vals_per_channel, self._sq_weight_node_names
+
+
+class SmoothQuantCalibrationLLM(SmoothQuantCalibration):
+    """A class for performing smooth quantization calibration on a Tensorflow LLM model.
+
+    Args:
+        model (str): A path to the original Tensorflow model.
+        dataloader (DataLoader): The data loader for the calibration dataset.
+        iterations (int): The number of iterations to run the calibration process.
+        op_types (List[str]): The types of operations to be quantized.
+        percentile (float): The percentile of calibration to remove outliers.
+        temp_path (str): The temporary path to store the intermediate model.
+        weight_name_mapping (Callable): A function that converts weight tensor names
+            in an AutoTrackable object to node names in graph_def.
+    """
+
+    def __init__(self, model_path, dataloader, iterations, op_types, percentile, temp_path, weight_name_mapping):
+        """Initializes a SmoothQuantCalibrationLLM object."""
+        self.func = None
+        self.graph_def = None
+        self.frozen_func = None
+        self._saved_model = None
+        self.model = model_path
+        self.dataloader = dataloader
+        self.iterations = iterations
+        self.op_types = op_types
+        self.percentile = percentile
+        self.temp_path = temp_path
+        self.weight_name_mapping = weight_name_mapping
+        self.print_node_list = []
+        self._sq_input_node_names = []
+        self._sq_target_node_names = {}
+        self._sq_output_tensor_dict = {}
+        self._sq_weight_tensor_dict = {}
+
+    def _parse_calibration_logs(self, tmp_dump_file):
+        """Parse the calibration logs for the LLM saved_model."""
+        valid_data = []
+        with open(tmp_dump_file) as file:
+            for i in file.readlines():
+                if i.startswith(";"):
+                    valid_data.append(i.strip())
+
+        for activation in valid_data:
+            [key, value] = activation.rsplit(":")
+            activation_name = key[1:-9]
+            import json
+
+            value = value.replace(" ", ",")
+            value = value.replace("][", "],[")
+            data = json.loads(value)
+            if activation_name not in self._sq_output_tensor_dict:
+                self._sq_output_tensor_dict[activation_name] = [np.array(data)]
+            else:
+                self._sq_output_tensor_dict[activation_name].append(np.array(data))
+
+    def _insert_print_for_activation(self, graph_def):
+        """Insert a Print node into the graph to do the calibration for the LLM saved_model."""
+        cur_graph = GraphAnalyzer()
+        cur_graph.graph = graph_def
+
+        graph_info = cur_graph.parse_graph()
+        for cur_list in self.print_node_list:
+            pre_node_name = cur_list[0]
+            post_node_name = cur_list[-1]
+            insert_node_pairs = []
+            top_node = graph_info[pre_node_name].node
+            if top_node.op == "ConcatV2":
+                for i in range(top_node.attr["N"].i):
+                    insert_node_pairs.append([top_node.input[i], post_node_name])
+            elif top_node.op in ("BatchMatMul", "BatchMatMulV2"):
+                insert_node_pairs.append([top_node.input[0], post_node_name])
+                if graph_info[top_node.input[1]].node.op != "Const":
+                    insert_node_pairs.append([top_node.input[1], post_node_name])
+            elif top_node.op in ("Conv2DBackpropInput", "Conv3DBackpropInputV2"):
+                insert_node_pairs.append([top_node.input[2], post_node_name])
+            else:
+                refresh_pre_node_name = graph_info[pre_node_name].node.input[0]
+                # Check whether the Conv2D can be fused with a previous Pad.
+                # If so, update the pre-node name correspondingly.
+                refresh_pre_node = graph_info[Helper.node_name_from_input(refresh_pre_node_name)].node
+                if refresh_pre_node.op == "Pad" and top_node.op in ("Conv2D", "Conv3D"):
+                    insert_node_pairs.append([refresh_pre_node_name, post_node_name])
+                    refresh_pre_node_name = refresh_pre_node.input[0]
+
+                insert_node_pairs.append([refresh_pre_node_name, post_node_name])
+
+            output_names = []
+            for node_pair_names in insert_node_pairs:
+                for index, each_node_name in enumerate(node_pair_names):
+                    name_with_sig = each_node_name
+                    node_name_prefix = name_with_sig.replace(":", "__port__").replace("^", "__hat__")
+                    print_node = Helper.create_node(
+                        "Print",
+                        node_name_prefix + "_print__{}".format(index),
+                        [each_node_name + ":0", each_node_name + ":0"],
+                    )
+
+                    if index == 0:
+                        msg = ";{}__print__:".format(each_node_name)
+                        # workaround for swish_f32, attribute T is not in the op definition
+                        if "swish_f32" in graph_info[pre_node_name].node.name:
+                            src_dt = attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum)
+                        else:
+                            src_dt = graph_info[pre_node_name].node.attr["T"]
+                    else:
+                        break
+
+                    print_node.attr["T"].CopyFrom(src_dt)
+
+                    print_node.attr["message"].s = msg.encode()
+                    print_node.attr["first_n"].i = -1
+                    print_node.attr["summarize"].i = 102400000
+
+                    attr_u = [dtypes.as_dtype(src_dt.type).as_datatype_enum]
+                    print_node.attr["U"].list.CopyFrom(attr_value_pb2.AttrValue.ListValue(type=attr_u))
+                    post_node_names = graph_info[Helper.node_name_from_input(each_node_name)].outputs
+                    if post_node_names:
+                        for post_node_name in post_node_names:
+                            post_node = graph_info[post_node_name].node
+                            if each_node_name not in post_node.input:
+                                continue
+                            if (
+                                post_node.op == "FusedBatchNormV3"
+                                and "_print_identity"
+                                not in graph_info[Helper.node_name_from_input(post_node.name)].node.input[0]
+                            ):
+                                identity_node = Helper.create_node(
+                                    "Identity",
+                                    post_node.name + "_print_identity",
+                                    [graph_info[Helper.node_name_from_input(post_node.name)].node.input[0]],
+                                )
+                                identity_node.attr["T"].CopyFrom(src_dt)
+                                cur_graph.add_node(
+                                    identity_node,
+                                    graph_info[Helper.node_name_from_input(post_node.name)].node.input[0],
+                                    [post_node.name],
+                                )
+                                identity_node.input.append("^" + print_node.name)
+                            else:
+                                post_node.input.append("^" + print_node.name)
+
+                        cur_graph.add_node(print_node, each_node_name, [])
+                    else:
+                        identity_node1 = Helper.create_node(
+                            "Identity", print_node.name + "_identity", [print_node.name]
+                        )
+                        identity_node1.attr["T"].CopyFrom(src_dt)
+                        cur_graph.add_node(print_node, each_node_name, [identity_node1.name])
+                        cur_graph.add_node(identity_node1, print_node.name, [])
+                        output_names.append(identity_node1.name)
+
+        return cur_graph.dump_graph()
+
+    def evaluate(self, model):
+        """Run inference on the model to apply calibration.
+
+        Args:
+            model (tf.python.trackable.autotrackable): The model to be evaluated.
+                The object is usually obtained through the tf.saved_model.load(model_dir) API.
+ """ + input_tensor_names = model.input_tensor_names + auto_trackable = model.model + infer = auto_trackable.signatures["serving_default"] + for idx, (inputs, _) in enumerate(self.dataloader): + feed_dict = {} + if len(input_tensor_names) == 1: + feed_dict[input_tensor_names[0]] = inputs + else: + assert len(input_tensor_names) == len(inputs), "inputs len must equal with input_tensor" + for i, input_tensor_name in enumerate(input_tensor_names): + feed_dict[input_tensor_name] = inputs[i] + + _ = infer(**feed_dict) + + if idx >= self.iterations: + break + + def _inference(self, sampling_graph_def): + """Inference the model to apply calibration. + + Args: + sampling_graph_def: The temporary graph_def for inference. + """ + logger.info("Start sampling on calibration dataset for Smooth Quantization.") + # reconstruct graph_def that inserted print node to saved_model + reconstruct_saved_model(sampling_graph_def, self.func, self.frozen_func, self._saved_model, self.temp_path) + model = TensorflowLLMModel(self.temp_path) + + input_tensor_names = model.input_tensor_names + auto_trackable = model.model + infer = auto_trackable.signatures["serving_default"] + for idx, (inputs, _) in enumerate(self.dataloader): + feed_dict = {} + if len(input_tensor_names) == 1: + feed_dict[input_tensor_names[0]] = inputs + else: + assert len(input_tensor_names) == len(inputs), "inputs len must equal with input_tensor" + for i, input_tensor_name in enumerate(input_tensor_names): + feed_dict[input_tensor_name] = inputs[i] + + _ = infer(**feed_dict) + + if idx >= self.iterations: + break + + def _inference_for_calibration(self, model): + """Run the calibration on the input graph.""" + sampling_graph_def = self._insert_print_for_activation(model) + tmp_dump_file = tempfile.mkstemp(suffix=".log")[1] + with CaptureOutputToFile(tmp_dump_file): + self._inference(sampling_graph_def) + self._parse_calibration_logs(tmp_dump_file) + del sampling_graph_def + + def _get_weight_tensors(self): + model = load.load(self.model, [tag_constants.SERVING]) + for weight_tensor in model.variables: + parsed_name = self.weight_name_mapping(weight_tensor.name) + if parsed_name in self._sq_target_node_names: + self._sq_weight_tensor_dict[parsed_name] = weight_tensor.numpy() + + assert len(self._sq_weight_tensor_dict) == len( + self._sq_target_node_names + ), "Failed to get weights for some nodes, please check variables" + + def _generate_calibration_data(self, input_node_names, output_node_names): + """Generate the calibration data.""" + sorted_graph = QuantizeGraphHelper().get_sorted_graph( + self.graph_def, + input_node_names, + output_node_names, + ) + + for node in sorted_graph.node: + if node.op not in self.op_types: + continue + # Fix retval already been set issue + if "while" in node.input[0]: # pragma: no cover + continue + self._sq_input_node_names.append(node.input[0]) + self.print_node_list.append([node.name]) + self._sq_target_node_names[node.input[1]] = node.name + self._get_weight_tensors() + sampling_graph_def = copy.deepcopy(self.graph_def) + self._inference_for_calibration(sampling_graph_def) + + def __call__(self, input_node_names, output_node_names): + """Generates calibration data and calculate the maximum values per channel. + + Args: + input_node_names: (list): A list of names for input nodes. + output_node_names: (list): A list of names for output nodes. + + Returns: + max_vals_per_channel (dict): A dictionary containing the maximum values per channel. 
+            sq_target_node_names (dict): A dictionary mapping from weight names to target node names.
+            sq_weight_tensor_dict (dict): A dictionary containing the tensors of weights.
+        """
+        self.graph_def, self._saved_model, self.func, self.frozen_func, _, _ = parse_saved_model(self.model)
+        self._generate_calibration_data(input_node_names, output_node_names)
+        max_vals_per_channel = {}
+        for activation_name, output_tensor in self._sq_output_tensor_dict.items():
+            max_val_per_channel = self._get_maxval_per_channel(output_tensor, percentile=self.percentile)
+            max_vals_per_channel[activation_name] = max_val_per_channel
+        return max_vals_per_channel, self._sq_target_node_names, self._sq_weight_tensor_dict, self.graph_def
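For intuition, the reduction performed by `_get_maxval_per_channel` above boils down to flattening every calibrated batch to `(samples, in_channels)` and taking a high percentile per column; a standalone numpy sketch with random stand-in data:

```python
import numpy as np

batches = [np.random.rand(8, 16).astype(np.float32) for _ in range(4)]  # iteration_num x (N, ic)
flat = np.concatenate([np.abs(b).reshape(-1, b.shape[-1]) for b in batches], axis=0)
max_per_channel = np.percentile(flat, 99.999, axis=0).astype(np.single)  # shape: (ic,)
```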
diff --git a/neural_compressor/tensorflow/algorithms/smoother/core.py b/neural_compressor/tensorflow/algorithms/smoother/core.py
new file mode 100644
index 00000000000..425b05bdcca
--- /dev/null
+++ b/neural_compressor/tensorflow/algorithms/smoother/core.py
@@ -0,0 +1,191 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Callable, Dict
+
+import tensorflow as tf
+
+from neural_compressor.common import logger
+from neural_compressor.common.utils import DEFAULT_WORKSPACE
+from neural_compressor.tensorflow.algorithms.smoother.calibration import (
+    SmoothQuantCalibration,
+    SmoothQuantCalibrationLLM,
+)
+from neural_compressor.tensorflow.algorithms.smoother.scaler import SmoothQuantScaler, SmoothQuantScalerLLM
+from neural_compressor.tensorflow.quantization.config import SmoothQuantConfig
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer
+from neural_compressor.tensorflow.utils import SPR_BASE_VERSIONS, BaseModel, TensorflowLLMModel, framework_specific_info
+
+
+class SmoothQuant:
+    """The class that performs smooth quantization."""
+
+    def __init__(
+        self,
+        config: SmoothQuantConfig,
+        calib_dataloader: Callable,
+        calib_iteration: int = 1,
+    ):
+        """Initialize the converter that applies smooth quant to a model.
+
+        Args:
+            config: the SmoothQuantConfig class used to set this class
+            calib_dataloader: the calibration dataloader
+            calib_iteration: how many steps of iterations on the dataloader to move forward
+        """
+        self.config = config
+        self.calib_dataloader = calib_dataloader
+        self.calib_iteration = calib_iteration
+
+        self.new_api = tf.version.VERSION in SPR_BASE_VERSIONS
+        self.device = framework_specific_info["device"]
+        self.itex_mode = framework_specific_info["backend"] == "itex"
+
+        # smooth quant currently uses one global config, so the first entry is enough
+        for _, value in self.config.items():
+            single_config = value
+            break
+
+        self.alpha = single_config.alpha
+        self.folding = single_config.folding
+        self.percentile = single_config.percentile
+        self.op_types = single_config.op_types
+        self.scales_per_op = single_config.scales_per_op
+        self.record_max_info = single_config.record_max_info
+        self.weight_clip = single_config.weight_clip
+        self.auto_alpha_args = single_config.auto_alpha_args
+
+    def get_weight_from_input_tensor(self, model, input_tensor_names):
+        """Extract weight tensors and their associated nodes from a smooth quant node's input tensor.
+
+        Args:
+            model: A TensorFlow model containing a `graph_def` attribute.
+            input_tensor_names: A list of input tensor names to search for weight tensors.
+
+        Returns:
+            A tuple of two dictionaries:
+            - sq_weight_tensors: A dictionary mapping each input tensor name
+              to a dict of its associated weight tensors with weight name.
+            - sq_weights_nodes: A dictionary mapping each input tensor name
+              to a dict of its associated weight nodes with weight name.
+        """
+        g_analyzer = GraphAnalyzer()
+        g_analyzer.graph = model.graph_def
+        graph_info = g_analyzer.parse_graph()
+
+        sq_weight_tensors = {}
+        sq_weights_nodes = {}
+
+        from tensorflow.python.framework import tensor_util
+
+        for name in input_tensor_names:
+            # Use dict rather than list to fix the QKV/VQK misorder issue
+            curr_weight_tensors = {}
+            curr_weights_nodes = {}
+            next_node_names = graph_info[name].outputs
+            for node_name in next_node_names:
+                curr_node = graph_info[node_name].node
+                if curr_node.op not in self.op_types:
+                    continue
+                if len(curr_node.input) >= 2:
+                    weight_name = curr_node.input[1]
+                    weight_node = graph_info[weight_name].node
+                    weight_tensor = tensor_util.MakeNdarray(weight_node.attr["value"].tensor)
+                    curr_weight_tensors[weight_name] = weight_tensor
+                    curr_weights_nodes[weight_name] = weight_node
+            # {input node -> {xxx_q_proj_matmul: value1, xxx_v_proj_matmul: value2, ...}, ...}
+            sq_weight_tensors[name] = curr_weight_tensors
+            sq_weights_nodes[name] = curr_weights_nodes
+        return sq_weight_tensors, sq_weights_nodes
+
+    def apply_smooth_quant(self, model: BaseModel):
+        """Apply smooth quant to the model."""
+        logger.info("Start Smoothing process for Smooth Quantization.")
+
+        # Do a pre-optimization before smooth quant
+        from neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.pre_optimize import PreOptimization
+
+        pre_optimizer_handle = PreOptimization(model, self.new_api, self.device)
+        pre_optimized_model = pre_optimizer_handle.get_optimized_model(self.itex_mode)
+        model.graph_def = pre_optimized_model.graph_def
+
+        # Run calibration to get max values per channel
+        calibration = SmoothQuantCalibration(
+            model, self.calib_dataloader, self.calib_iteration, self.op_types, self.percentile
+        )
+        max_vals_per_channel, sq_weight_node_names = calibration()
+
+        # Get weight tensors and weight nodes based on the input tensor
+        sq_weight_tensors, sq_weights_nodes = self.get_weight_from_input_tensor(model, max_vals_per_channel.keys())
+
+        # Calculate the smooth quant scaler and insert Mul op into the graph
+        scaler = SmoothQuantScaler(model, self.calib_dataloader, self.alpha, self.scales_per_op)
+        model, mul_list = scaler.transform(
+            max_vals_per_channel, sq_weight_tensors, sq_weights_nodes, sq_weight_node_names
+        )
+
+        return model
+
+    def apply_smooth_quant_LLM(self, model: BaseModel):
+        """Apply smooth quant to the LLM model."""
+        # Do a pre-optimization before smooth quant
+        from neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.pre_optimize import PreOptimization
+
+        pre_optimizer_handle = PreOptimization(model, self.new_api, self.device)
+        pre_optimized_model = pre_optimizer_handle.get_optimized_model(self.itex_mode)
+        model.graph_def = pre_optimized_model.graph_def
+
+        llm_temp_dir = DEFAULT_WORKSPACE + "/temp_saved_model"
+        # Run calibration to get max values per channel
+        calibration = SmoothQuantCalibrationLLM(
+            model._model,
+            self.calib_dataloader,
+            self.calib_iteration,
+            self.op_types,
+            self.percentile,
+            llm_temp_dir,
+            model.weight_name_mapping,
+        )
+        max_vals_per_channel, sq_target_node_names, sq_weight_tensor_dict, sq_graph_def = calibration(
+            model.input_node_names, model.output_node_names
+        )
+
+        # Calculate the smooth quant scaler and insert Mul op into the graph
+        scaler = SmoothQuantScalerLLM(sq_graph_def, self.alpha, self.scales_per_op, self.op_types)
+        sq_graph_def, sq_weight_scale_dict, mul_list = scaler.transform(
+            max_vals_per_channel, sq_weight_tensor_dict, sq_target_node_names
+        )
+        model.graph_def = sq_graph_def
+        model.model_path = llm_temp_dir
+        model.sq_weight_scale_dict = sq_weight_scale_dict
+        return model
+
+    def __call__(self, model: BaseModel):
+        """Convert the model by smooth quant.
+
+        Args:
+            model: original model
+
+        Returns:
+            model: A smoothed Tensorflow model
+        """
+        apply_func = self.apply_smooth_quant_LLM if isinstance(model, TensorflowLLMModel) else self.apply_smooth_quant
+
+        return apply_func(model)
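`core.py` gives smooth quant a single entry point. Its constructor iterates `config.items()`, so it expects a dict-like mapping of op keys to `SmoothQuantConfig` objects, which is how `smooth_quant_entry` (added later in this patch) drives it. A hedged usage sketch with placeholder model and dataloader objects:

```python
from neural_compressor.tensorflow.algorithms.smoother import SmoothQuant
from neural_compressor.tensorflow.quantization.config import SmoothQuantConfig

config_mapping = {("matmul_1", "MatMul"): SmoothQuantConfig(alpha=0.5)}  # illustrative key
sq = SmoothQuant(config_mapping, calib_dataloader, calib_iteration=5)
smoothed_model = sq(fp32_model)  # dispatches to the LLM path for TensorflowLLMModel
```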
diff --git a/neural_compressor/tensorflow/algorithms/smoother/scaler.py b/neural_compressor/tensorflow/algorithms/smoother/scaler.py
new file mode 100644
index 00000000000..839c5e2332a
--- /dev/null
+++ b/neural_compressor/tensorflow/algorithms/smoother/scaler.py
@@ -0,0 +1,268 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tensorflow scaling model weights and activations for Smooth Quantization."""
+
+import logging
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.framework import dtypes, tensor_util
+
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer
+
+logger = logging.getLogger("neural_compressor")
+
+
+class SmoothQuantScaler:
+    """A class for scaling model weights using the Smooth Quantization method.
+
+    Args:
+        model: Tensorflow model to be scaled
+        dataloader: Tensorflow dataloader for the dataset
+        alpha: float, the scaling factor
+        scales_per_op: bool, if True, each op will have an individual scale;
+            otherwise, ops with the same input will share a scale
+    """
+
+    def __init__(self, model, dataloader, alpha, scales_per_op):
+        """Initialization."""
+        self.model = model
+        self.dataloader = dataloader
+        self.alpha = alpha
+        self.scales_per_op = scales_per_op
+        self.mul_list = []
+        self.g_analyzer = GraphAnalyzer()
+        self.g_analyzer.graph = self.model
+
+    def _adjust_activation(self, scale, input_node_name, output_node_name, w_i):
+        """Insert a Mul node between the activation and the node that consumes it.
+
+        Args:
+            scale: smooth scale with the shape (ic,)
+            input_node_name: the parent input node
+            output_node_name: the concrete output weight node name
+            w_i: distinguishes between different output weight nodes on different branches when naming
+        """
+        from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper
+
+        node_suffix = str(w_i)
+        mul_const_node = Helper.create_constant_node(input_node_name + "/scale_mul" + node_suffix, scale, tf.float32)
+        mul_node = Helper.create_node(
+            "Mul",
+            input_node_name + "_mul" + node_suffix,
+            [input_node_name + "/scale_mul" + node_suffix, input_node_name],
+        )
+        Helper.set_attr_dtype(mul_node, "T", dtypes.float32)
+        self.mul_list.append(mul_node.name)
+        self.g_analyzer.add_node(mul_node, input_node_name, [output_node_name])
+        self.g_analyzer.add_node(mul_const_node, None, [input_node_name + "_mul" + node_suffix])
+
+    def _adjust_weight(self, scale, weight_node, original_weight):
+        """In-place adjust weight by scale.
+
+        Args:
+            scale: smooth scale with the shape (ic,)
+            weight_node: reference to the original const weight node
+            original_weight: numpy value of the original const weight node
+        """
+        # scale: (ic,)
+        original_shape = original_weight.shape
+        if len(original_shape) == 4:  # (fh, fw, ic, oc)
+            W = np.transpose(original_weight, [0, 1, 3, 2])  # move input channel to last dimension
+            W *= scale
+            W = np.transpose(W, [0, 1, 3, 2])  # move input channel back
+            weight_node.attr["value"].tensor.CopyFrom(tensor_util.make_tensor_proto(W))
+        elif len(original_shape) == 2:  # (ic, oc) if transpose_a == transpose_b == false
+            W = np.transpose(original_weight, [1, 0])
+            W *= scale
+            W = np.transpose(W, [1, 0])
+            weight_node.attr["value"].tensor.CopyFrom(tensor_util.make_tensor_proto(W))
+
+    def transform(self, max_vals_per_channel, sq_weight_tensors, sq_weights_nodes, sq_weight_node_names):
+        """Apply scaling to weights and activations based on the maximum values per channel.
+
+        Args:
+            max_vals_per_channel (dict): A dictionary containing the maximum values per channel for each input node.
+            sq_weight_tensors (dict): A dictionary containing the name -> weight tensors mapping for each input node.
+            sq_weights_nodes (dict): A dictionary containing the name -> constant nodes mapping for each input node.
+            sq_weight_node_names (dict): A dictionary from a weight node name to its concrete output node name.
+
+        Returns:
+            tuple: A tuple containing the modified model and a list of the inserted multiplication nodes.
+        """
+        logger.info("Start scaling on model graph for Smooth Quantization.")
+        if self.scales_per_op:
+            # 1. obtain the smooth scale per op
+            # 2. adjust weight
+            # 3. adjust activation
+            for idx, input_node_name in enumerate(max_vals_per_channel):
+                A_max_per_in_channel = max_vals_per_channel[input_node_name]
+                W_dict = sq_weight_tensors[input_node_name]
+                # Use the const nodes before to get weight values
+                W_const_node_dict = sq_weights_nodes[input_node_name]
+                # Get the concrete weight node as the output of Mul insertion
+                for w_i, W_name in enumerate(W_dict):
+                    W = W_dict[W_name]
+                    if len(W.shape) == 4:
+                        # https://www.tensorflow.org/api_docs/python/tf/nn/conv2d
+                        # weight: [filter_height, filter_width, in_channels, out_channels]
+                        # activation: NHWC, also batch_shape + [in_height, in_width, in_channels]
+                        tensor = np.abs(np.transpose(W, [0, 1, 3, 2]))
+                        # reduce weight max to (in_channel, ), aligned with activation max
+                        W_max_per_in_channel = np.max(np.reshape(tensor, (-1, tensor.shape[-1])), axis=0)
+                    elif len(W.shape) == 2:  # matmul
+                        # reduce weight max to (in_channel, ), aligned with activation max
+                        tensor = np.abs(W)
+                        W_max_per_in_channel = np.max(tensor, axis=1)
+                    else:  # pragma: no cover
+                        assert False, "not supported"
+                    cur_const_node = W_const_node_dict[W_name]
+                    try:
+                        scale = np.power(A_max_per_in_channel, self.alpha) / np.power(
+                            W_max_per_in_channel, (1 - self.alpha)
+                        )
+                    except ValueError as e:  # pragma: no cover
+                        logger.info(e)
+                        logger.info("Skip smoothing the node: {}".format(cur_const_node.name))
+                        continue
+                    # clip the scales that are too small
+                    scale = np.clip(scale, a_min=1e-5, a_max=1e8)
+                    self._adjust_weight(scale, cur_const_node, W)
+                    self._adjust_activation(1 / scale, input_node_name, sq_weight_node_names[cur_const_node.name], w_i)
+        else:
+            pass
+        sq_graph_def = self.g_analyzer.dump_graph()
+        sq_graph_def.library.CopyFrom(self.model.graph_def.library)
+        self.model.graph_def = sq_graph_def
+        return self.model, self.mul_list
+
+
+class SmoothQuantScalerLLM(SmoothQuantScaler):
+    """A class for scaling the weights of TF LLM models using the Smooth Quantization method.
+
+    Args:
+        graph_def: graph_def of the model to be scaled
+        alpha: float, the scaling factor
+        scales_per_op: bool, if True, each op will have an individual scale;
+            otherwise, ops with the same input will share a scale
+        op_types: the op types to be smooth quantized
+    """
+
+    def __init__(self, graph_def, alpha, scales_per_op, op_types):
+        """Initialization."""
+        self.graph_def = graph_def
+        self.alpha = alpha
+        self.scales_per_op = scales_per_op
+        self.op_types = op_types
+
+        self.graph_info = None
+        self.mul_list = []
+        self.sq_weight_scale_dict = {}
+
+    def _parse_weight_dict(self, max_vals_per_channel, sq_weight_tensor_dict):
+        """Parse the weight-related dictionaries into the two required dictionaries.
+
+        Args:
+            max_vals_per_channel (dict): A dictionary containing the maximum values per channel.
+            sq_weight_tensor_dict (dict): A dictionary containing the tensors of weights.
+
+        Returns:
+            sq_weight_tensors: A dictionary whose structure is like {input_node_name: weight_tensor}.
+            sq_weight_node_names: A dictionary whose structure is like {input_node_name: weight_node_name}.
+ """ + sq_weight_tensors = {} + sq_weight_node_names = {} + for input_node_name in max_vals_per_channel: + curr_weight_tensors = [] + curr_weights_node_names = [] + next_node_names = self.graph_info[input_node_name].outputs + for node_name in next_node_names: + curr_node = self.graph_info[node_name].node + if curr_node.op not in self.op_types: + continue + if len(curr_node.input) >= 2: + weight_name = curr_node.input[1] + weight_tensor = sq_weight_tensor_dict[weight_name] + curr_weight_tensors.append(weight_tensor) + curr_weights_node_names.append(weight_name) + sq_weight_tensors[input_node_name] = curr_weight_tensors + sq_weight_node_names[input_node_name] = curr_weights_node_names + return sq_weight_tensors, sq_weight_node_names + + def transform(self, max_vals_per_channel, sq_weight_tensor_dict, sq_target_node_names): + """Apply scaling to weights and activations based on the maximum values per channel. + + Args: + max_vals_per_channel (dict): A dictionary containing the maximum values per channel for each input node. + sq_weight_tensor_dict (dict): A dictionary whose structure is like {input_node_name: weight_tensor}. + sq_target_node_names (dict): A dictionary whose structure is like {weight_node_name: target_node_name}. + """ + self.g_analyzer = GraphAnalyzer() + self.g_analyzer.graph = self.graph_def + self.graph_info = self.g_analyzer.parse_graph() + sq_weight_tensors, sq_weight_node_names = self._parse_weight_dict(max_vals_per_channel, sq_weight_tensor_dict) + logger.info("Start scaling on model graph for Smooth Quantization.") + if self.scales_per_op: + # 1. obtain the smooth scale per op + # 2. adjust weight + # 3. adjust activation + for _, input_node_name in enumerate(max_vals_per_channel): + activation_max_per_in_channel = max_vals_per_channel[input_node_name] + W_lst = sq_weight_tensors[input_node_name] # VQK weight value + # Use the const nodes before to get weight values, VQK ReadVariable + W_node_name_lst = sq_weight_node_names[input_node_name] + # Get the concrete weight node as the output of Mul insertion, QKV ReadVariable + for w_i, W in enumerate(W_lst): + if len(W.shape) == 4: + # https://www.tensorflow.org/api_docs/python/tf/nn/conv2d + # weight: [filter_height, filter_width, in_channels, out_channels] + # activation: NHWC, also batch_shape + [in_height, in_width, in_channels] + tensor = np.abs(np.transpose(W, [0, 1, 3, 2])) + # reduce weight max to (in_channel, ), aligned with activation max + W_max_per_in_channel = np.max(np.reshape(tensor, (-1, tensor.shape[-1])), axis=0) + elif len(W.shape) == 2: # matmul + # reduce weight max to (in_channel, ), aligned with activation max + tensor = np.abs(W) + W_max_per_in_channel = np.max(tensor, axis=1) + else: # pragma: no cover + assert False, "not supported" + cur_weight_node_name = W_node_name_lst[w_i] + try: + scale = np.power(activation_max_per_in_channel, self.alpha) / np.power( + W_max_per_in_channel, (1 - self.alpha) + ) + except ValueError as e: # pragma: no cover + logger.info(e) + logger.info("Skip smoothing the node: {}".format(cur_weight_node_name)) + continue + # clip the scales that are too small + scale = np.clip(scale, a_min=1e-5, a_max=1e8) + # skip smoothing the op where scale has elements that less than 1 + # if np.any(scale < 1): + # logger.info("skip smooth quant: {}".format(input_node_name)) + # continue + self.sq_weight_scale_dict[cur_weight_node_name] = scale + self._adjust_activation(1 / scale, input_node_name, sq_target_node_names[cur_weight_node_name], w_i) + else: + pass + sq_graph_def = 
+        sq_graph_def.library.CopyFrom(self.graph_def.library)
+        return sq_graph_def, self.sq_weight_scale_dict, self.mul_list
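A quick numeric check of the invariant both scalers rely on: multiplying the weight's input-channel dimension by `scale` and the activation by `1 / scale` leaves the op's output unchanged (standalone sketch, random data, MatMul case):

```python
import numpy as np

x = np.random.rand(2, 3).astype(np.float32)              # activation (N, ic)
W = np.random.uniform(-1, 1, (3, 4)).astype(np.float32)  # weight (ic, oc)

alpha = 0.5
scale = np.power(np.abs(x).max(axis=0), alpha) / np.power(np.abs(W).max(axis=1), 1 - alpha)
scale = np.clip(scale, 1e-5, 1e8)

y_ref = x @ W
y_sq = (x * (1 / scale)) @ (W * scale[:, None])
assert np.allclose(y_ref, y_sq, atol=1e-5)
```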
""" + if not isinstance(model, tf.keras.Model): + model = model.model fp32_config = {"weight": {"dtype": "fp32"}, "activation": {"dtype": "fp32"}} bf16_config = {"weight": {"dtype": "bf16"}, "activation": {"dtype": "bf16"}} int8_type = self.query_handler.get_op_types_by_precision(precision="int8") @@ -744,3 +751,57 @@ def get_op_types_by_precision(self, precision): """ assert precision in list(self.cur_config["ops"].keys()) return self.cur_config["ops"][precision] + + +class KerasConfigConverter: + """Convert `StaticQuantConfig` to the format used by static quant algo.""" + + support_int8_weight = {"Dense", "Conv2d", "DepthwiseConv2D", "SeparableConv2D"} + + def __init__(self, quant_config: StaticQuantConfig, calib_iteration: int): + """Init parser for keras static quant config. + + Args: + quant_config: the keras static quant config. + calib_iteration: the iteration of calibration. + """ + self.quant_config = quant_config + self.calib_iteration = calib_iteration + + def update_config(self, quant_config, op_key): + """Update op-wise config. + + Args: + quant_config: the keras static quant config. + op_key: a tuple such as (layer type, layer name). + """ + op_value = {"activation": {}} + op_value["activation"].update( + { + "dtype": quant_config.act_dtype, + "quant_mode": "static", + "scheme": ("sym" if quant_config.act_sym else "asym"), + "granularity": quant_config.act_granularity, + "algorithm": "minmax", + } + ) + if op_key[1] not in self.support_int8_weight: + return op_value + + op_value["weight"] = { + "dtype": quant_config.weight_dtype, + "scheme": "sym" if quant_config.weight_sym else "asym", + "granularity": quant_config.weight_granularity, + "algorithm": "minmax", + } + return op_value + + def parse_to_tune_cfg(self) -> Dict: + """The function that parses StaticQuantConfig to keras tuning config.""" + tune_cfg = {"op": OrderedDict()} + for op_key, config in self.quant_config.items(): + op_value = self.update_config(config, op_key) + tune_cfg["op"].update({op_key: op_value}) + tune_cfg["calib_iteration"] = self.calib_iteration + + return tune_cfg diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras.yaml b/neural_compressor/tensorflow/algorithms/static_quant/keras.yaml similarity index 100% rename from neural_compressor/tensorflow/algorithms/static_quantize/keras.yaml rename to neural_compressor/tensorflow/algorithms/static_quant/keras.yaml diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/__init__.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/__init__.py similarity index 100% rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/__init__.py rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/__init__.py diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/conv2d.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/conv2d.py similarity index 100% rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/conv2d.py rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/conv2d.py diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/dense.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/dense.py similarity index 100% rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/dense.py rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/dense.py diff --git 
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras.yaml b/neural_compressor/tensorflow/algorithms/static_quant/keras.yaml
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras.yaml
rename to neural_compressor/tensorflow/algorithms/static_quant/keras.yaml
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/__init__.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/__init__.py
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/__init__.py
rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/__init__.py
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/conv2d.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/conv2d.py
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/conv2d.py
rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/conv2d.py
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/dense.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/dense.py
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/dense.py
rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/dense.py
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/depthwise_conv2d.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/depthwise_conv2d.py
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/depthwise_conv2d.py
rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/depthwise_conv2d.py
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/pool2d.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/pool2d.py
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/pool2d.py
rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/pool2d.py
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/quantizer.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/quantizer.py
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/quantizer.py
rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/quantizer.py
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/separable_conv2d.py b/neural_compressor/tensorflow/algorithms/static_quant/keras_utils/separable_conv2d.py
similarity index 100%
rename from neural_compressor/tensorflow/algorithms/static_quantize/keras_utils/separable_conv2d.py
rename to neural_compressor/tensorflow/algorithms/static_quant/keras_utils/separable_conv2d.py
diff --git a/neural_compressor/tensorflow/algorithms/static_quantize/quantize_entry.py b/neural_compressor/tensorflow/algorithms/static_quantize/quantize_entry.py
deleted file mode 100644
index a76bf0b0e82..00000000000
--- a/neural_compressor/tensorflow/algorithms/static_quantize/quantize_entry.py
+++ /dev/null
@@ -1,124 +0,0 @@
-# Copyright (c) 2023 Intel Corporation
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-from collections import OrderedDict
-from typing import Callable, Dict
-
-import tensorflow as tf
-
-from neural_compressor.common.utils import STATIC_QUANT
-from neural_compressor.tensorflow.algorithms.static_quantize.keras import KerasAdaptor
-from neural_compressor.tensorflow.quantization.config import StaticQuantConfig
-from neural_compressor.tensorflow.utils import register_algo
-
-framework_specific_info = {
-    "device": "cpu",
-    "backend": "itex",
-    "approach": "post_training_static_quant",
-}
-
-support_int8_weight = {"Dense", "Conv2d", "DepthwiseConv2D", "SeparableConv2D"}
-
-support_int8_activation = {
-    "Dense",
-    "Conv2d",
-    "DepthwiseConv2D",
-    "SeparableConv2D",
-    "AvgPool2D",
-    "AveragePooling2D",
-    "MaxPool2D",
-    "MaxPooling2D",
-}
-
-
-def update_config(op_value: Dict, quant_config: StaticQuantConfig, layer_class: str):
-    """Update op-wise config from global config or operator name config or operator type config."""
-    op_value["activation"].update(
-        {
-            "dtype": quant_config.act_dtype,
-            "quant_mode": "static",
-            "scheme": ("sym" if quant_config.act_sym else "asym"),
-            "granularity": quant_config.act_granularity,
-            "algorithm": "minmax",
-        }
-    )
-    if layer_class not in support_int8_weight:
-        return
-    op_value["weight"] = {
-        "dtype": quant_config.weight_dtype,
-        "scheme": "sym" if quant_config.weight_sym else "asym",
-        "granularity": quant_config.weight_granularity,
-        "algorithm": "minmax",
-    }
-
-
-def parse_to_keras_tune_cfg(model: tf.keras.Model, quant_config: StaticQuantConfig, calib_iteration: int) -> Dict:
-    """The function that parses StaticQuantConfig to keras tuning config.
-
-    Args:
-        model: a fp32 model to be quantized.
-        quant_config: a quantization configuration.
-        calib_iteration: the iteration of calibration.
-
-    Returns:
-        tune_cfg: the tuning config for keras adaptor.
-    """
-    tune_cfg = {"op": OrderedDict()}
-    for layer in model.layers:
-        layer_class = layer.__class__.__name__
-        if layer_class not in support_int8_activation:
-            continue
-        op_key = (layer.name, layer_class)
-        op_value = {"activation": {}}
-
-        local_config = None
-        # priority local > global
-        if quant_config.local_config and layer.name in quant_config.local_config.keys():
-            local_config = quant_config.local_config[layer.name]
-
-        if local_config:
-            update_config(op_value, local_config, layer_class)
-        else:
-            update_config(op_value, quant_config, layer_class)
-
-        tune_cfg["op"].update({op_key: op_value})
-    tune_cfg["calib_iteration"] = calib_iteration
-
-    return tune_cfg
-
-
-@register_algo(name=STATIC_QUANT)
-def static_quantize_entry(
-    model: tf.keras.Model,
-    quant_config: StaticQuantConfig,
-    calib_dataloader: Callable = None,
-    calib_iteration: int = 100,
-) -> tf.keras.Model:
-    """The main entry to apply static quantization.
-
-    Args:
-        model: a fp32 model to be quantized.
-        quant_config: a quantization configuration.
-        calib_dataloader: a data loader for calibration.
-        calib_iteration: the iteration of calibration.
-
-    Returns:
-        q_model: the quantized model.
- """ - - keras_adaptor = KerasAdaptor(framework_specific_info) - keras_adaptor.query_fw_capability(model) - tune_cfg = parse_to_keras_tune_cfg(model, quant_config, calib_iteration) - q_model = keras_adaptor.quantize(tune_cfg, model, calib_dataloader) - return q_model diff --git a/neural_compressor/tensorflow/quantization/__init__.py b/neural_compressor/tensorflow/quantization/__init__.py index 2b0ee3b19ff..c79e8933ee0 100644 --- a/neural_compressor/tensorflow/quantization/__init__.py +++ b/neural_compressor/tensorflow/quantization/__init__.py @@ -13,4 +13,10 @@ # limitations under the License. from neural_compressor.tensorflow.quantization.quantize import quantize_model -from neural_compressor.tensorflow.quantization.config import StaticQuantConfig, get_default_static_quant_config +from neural_compressor.tensorflow.quantization.algorithm_entry import static_quantize_entry, smooth_quant_entry +from neural_compressor.tensorflow.quantization.config import ( + StaticQuantConfig, + SmoothQuantConfig, + get_default_sq_config, + get_default_static_quant_config, +) diff --git a/neural_compressor/tensorflow/quantization/algorithm_entry.py b/neural_compressor/tensorflow/quantization/algorithm_entry.py new file mode 100644 index 00000000000..b79f1caccfd --- /dev/null +++ b/neural_compressor/tensorflow/quantization/algorithm_entry.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Callable, Dict + +import tensorflow as tf + +from neural_compressor.common.utils import SMOOTH_QUANT, STATIC_QUANT +from neural_compressor.tensorflow.algorithms import KerasAdaptor +from neural_compressor.tensorflow.quantization.config import SmoothQuantConfig, StaticQuantConfig +from neural_compressor.tensorflow.utils import BaseModel, KerasModel, framework_specific_info, register_algo + + +@register_algo(name=STATIC_QUANT) +def static_quantize_entry( + model: BaseModel, + quant_config: StaticQuantConfig, + calib_dataloader: Callable = None, + calib_iteration: int = 100, +): + """The main entry to apply static quantization. + + Args: + model: a fp32 model to be quantized. + quant_config: a quantization configuration. + calib_dataloader: a data loader for calibration. + calib_iteration: the iteration of calibration. + + Returns: + q_model: the quantized model. + """ + keras_adaptor = KerasAdaptor(framework_specific_info) + q_model = keras_adaptor.quantize(quant_config, model, calib_dataloader, calib_iteration) + return q_model + + +@register_algo(name=SMOOTH_QUANT) +def smooth_quant_entry( + model: BaseModel, + smooth_quant_config: SmoothQuantConfig, + calib_dataloader: Callable = None, + calib_iteration: int = 100, +): + assert not isinstance(model, KerasModel), "INC don't support smooth quantization for Keras models now." 
+ + from neural_compressor.tensorflow.algorithms import SmoothQuant + + converter = SmoothQuant(smooth_quant_config, calib_dataloader, calib_iteration) + sq_model = converter(model) + + return sq_model diff --git a/neural_compressor/tensorflow/quantization/config.py b/neural_compressor/tensorflow/quantization/config.py index 02dee41687c..89b46918694 100644 --- a/neural_compressor/tensorflow/quantization/config.py +++ b/neural_compressor/tensorflow/quantization/config.py @@ -18,19 +18,21 @@ from __future__ import annotations from enum import Enum -from typing import Callable, Dict, List, NamedTuple, Optional, Union +from typing import Callable, Dict, List, NamedTuple, Optional, Tuple, Union import tensorflow as tf +from neural_compressor.common import logger from neural_compressor.common.base_config import ( + DEFAULT_WHITE_LIST, + OP_NAME_OR_MODULE_TYPE, BaseConfig, config_registry, register_config, register_supported_configs_for_fwk, ) -from neural_compressor.common.utils import DEFAULT_WHITE_LIST, OP_NAME_OR_MODULE_TYPE, STATIC_QUANT - -FRAMEWORK_NAME = "keras" +from neural_compressor.common.utils import SMOOTH_QUANT, STATIC_QUANT +from neural_compressor.tensorflow.utils import DEFAULT_SQ_ALPHA_ARGS class OperatorConfig(NamedTuple): @@ -39,7 +41,7 @@ class OperatorConfig(NamedTuple): valid_func_list: List[Callable] = [] -@register_config(framework_name=FRAMEWORK_NAME, algo_name=STATIC_QUANT) +@register_config(framework_name="keras", algo_name=STATIC_QUANT) class StaticQuantConfig(BaseConfig): """Config class for keras static quantization.""" @@ -108,30 +110,40 @@ def register_supported_configs(cls) -> List[OperatorConfig]: supported_configs.append(OperatorConfig(config=static_quant_config, operators=operators)) cls.supported_configs = supported_configs + @staticmethod + def get_model_info(model) -> List[Tuple[str, Callable]]: + white_list = [ + "Dense", + "Conv2d", + "DepthwiseConv2D", + "SeparableConv2D", + "AvgPool2D", + "AveragePooling2D", + "MaxPool2D", + "MaxPooling2D", + ] + filter_result = [] + + for layer in model.model.layers: + if layer.__class__.__name__ in white_list: + pair = (layer.name, layer.__class__.__name__) + filter_result.append(pair) + logger.debug(f"Get model info: {filter_result}") + return filter_result + @classmethod - def get_config_set_for_tuning( - cls, - ) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]: # pragma: no cover + def get_config_set_for_tuning(cls) -> Union[None, "StaticQuantConfig", List["StaticQuantConfig"]]: # TODO fwk owner needs to update it. return StaticQuantConfig(weight_sym=[True, False]) -register_supported_configs_for_fwk(fwk_name=FRAMEWORK_NAME) +register_supported_configs_for_fwk(fwk_name="keras") def get_all_registered_configs() -> Dict[str, BaseConfig]: """Get all registered configs for keras framework.""" registered_configs = config_registry.get_cls_configs() - return registered_configs.get(FRAMEWORK_NAME, {}) - - -def parse_config_from_dict(config_dict: Dict) -> BaseConfig: - """Generate a BaseConfig instance from a dict.""" - keras_registered_configs = get_all_registered_configs() - for key, val in config_dict.items(): - if key in keras_registered_configs: - config = keras_registered_configs[key].from_dict(val) - return config + return registered_configs.get("keras", {}) def get_default_static_quant_config() -> StaticQuantConfig: @@ -141,3 +153,94 @@ def get_default_static_quant_config() -> StaticQuantConfig: the default keras config. 
""" return StaticQuantConfig() + + +@register_config(framework_name="tensorflow", algo_name=SMOOTH_QUANT) +class SmoothQuantConfig(BaseConfig): + """Config class for tf smooth quantization.""" + + supported_configs: List[OperatorConfig] = [] + params_list = [ + "alpha", + "folding", + "percentile", + "op_types", + "scales_per_op", + "record_max_info", + "weight_clip", + "auto_alpha_args", + ] + name = SMOOTH_QUANT + + def __init__( + self, + alpha: float = 0.5, + folding: bool = False, + percentile: float = 99.999, + op_types: list = ["MatMul", "Conv2D"], + scales_per_op: bool = True, + record_max_info: bool = False, + weight_clip: bool = True, + auto_alpha_args: Dict = DEFAULT_SQ_ALPHA_ARGS, + white_list: Optional[List[OP_NAME_OR_MODULE_TYPE]] = DEFAULT_WHITE_LIST, + ): + """Init smooth quantization config. + + Args: + alpha (float or str): alpha value to balance the quantization difficulty of activation and weight. + folding (bool): whether fold those foldable Mul which are inserted for smooth quant. + percentile (float): percentile of calibration to remove outliers + op_types (list): the op type to be smooth quantized. + scales_per_op (bool): True, each op will have an individual scale, mainlyfor accuracy. + False, ops with the same input will share a scale, mainly for performance. + record_max_info (bool): whether record the max info in model for alpha tuning. + weight_clip (bool): whether to clip weight when calculating scales; by default it is on. + auto_alpha_args (dict): settings for alpha tuning. + """ + super().__init__() + self.alpha = alpha + self.folding = folding + self.percentile = percentile + self.op_types = op_types + self.scales_per_op = scales_per_op + self.record_max_info = record_max_info + self.weight_clip = weight_clip + self.auto_alpha_args = auto_alpha_args + self.white_list = white_list + self._post_init() + + @classmethod + def register_supported_configs(cls) -> List[OperatorConfig]: + supported_configs = [] + smooth_quant_config = SmoothQuantConfig() + operators = ["MatMul", "Conv2D"] + supported_configs.append(OperatorConfig(config=smooth_quant_config, operators=operators)) + cls.supported_configs = supported_configs + + @staticmethod + def get_model_info(model) -> List[Tuple[str, Callable]]: + white_list = ["MatMul", "Conv2D"] + filter_result = [] + for node in model.graph_def.node: + if node.op in white_list: + pair = (node.name, node.op) + filter_result.append(pair) + logger.debug(f"Get model info: {filter_result}") + return filter_result + + @classmethod + def get_config_set_for_tuning(cls) -> Union[None, "SmoothQuantConfig", List["SmoothQuantConfig"]]: + # TODO fwk owner needs to update it. + return SmoothQuantConfig(alpha=0.5) + + +SmoothQuantConfig.register_supported_configs() + + +def get_default_sq_config() -> SmoothQuantConfig: + """Generate the default rtn config. + + Returns: + the default smooth quant config. + """ + return SmoothQuantConfig() diff --git a/neural_compressor/tensorflow/quantization/quantize.py b/neural_compressor/tensorflow/quantization/quantize.py index 0e20b5fb221..1f375e133e1 100644 --- a/neural_compressor/tensorflow/quantization/quantize.py +++ b/neural_compressor/tensorflow/quantization/quantize.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -12,22 +12,26 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Callable
+from typing import Any, Callable, Dict, Tuple, Union

 import tensorflow as tf

-from neural_compressor.common import Logger
-from neural_compressor.common.base_config import BaseConfig
+from neural_compressor.common import logger
+from neural_compressor.common.base_config import BaseConfig, ComposableConfig, config_registry
 from neural_compressor.common.utils import STATIC_QUANT
-from neural_compressor.tensorflow.quantization.config import parse_config_from_dict
-from neural_compressor.tensorflow.utils import algos_mapping
+from neural_compressor.tensorflow.utils import BaseModel, KerasModel, Model, algos_mapping

-logger = Logger().get_logger()
+
+def need_apply(configs_mapping: Dict[Tuple[str, Callable], BaseConfig], algo_name):
+    """Check whether any operator in the configs mapping was assigned to the given algorithm."""
+    return any(config.name == algo_name for config in configs_mapping.values())


 def quantize_model(
-    model: tf.keras.Model, quant_config: BaseConfig, calib_dataloader: Callable = None, calib_iteration: int = 100
-) -> tf.keras.Model:
+    model: Union[str, tf.keras.Model, BaseModel],
+    quant_config: BaseConfig,
+    calib_dataloader: Callable = None,
+    calib_iteration: int = 100,
+):
     """The main entry to quantize model.

     Args:
@@ -39,20 +43,24 @@ def quantize_model(
     Returns:
         q_model: the quantized model.
     """
+    q_model = Model(model)
+    framework_name = "keras" if isinstance(q_model, KerasModel) else "tensorflow"
+    registered_configs = config_registry.get_cls_configs()
     if isinstance(quant_config, dict):
-        quant_config = parse_config_from_dict(quant_config)
-        logger.info("Parsed dict to construct the quantization config.")
+        quant_config = ComposableConfig.from_dict(quant_config, config_registry=registered_configs[framework_name])
+        logger.info(f"Parsed a config dict to construct the quantization config: {quant_config}.")
     else:
         assert isinstance(
             quant_config, BaseConfig
-        ), "Please pass a dict or config instance as the quantization configuration."
+        ), f"Please pass a dict or config instance as the quantization configuration, but got {type(quant_config)}."
     logger.info(f"Quantize model with config: \n {quant_config.to_json_string()} \n")
-    # select quantization algo according to config
-    # TODO (Yi) support combine more than one algo
-    if quant_config.name == STATIC_QUANT:
-        quant_fn = algos_mapping[quant_config.name]
-    else:
-        raise NotImplementedError("Currently, only the basic algorithm is being ported.")
-    qmodel = quant_fn(model, quant_config, calib_dataloader, calib_iteration)
-    return qmodel
+
+    model_info = quant_config.get_model_info(model=q_model)
+    configs_mapping = quant_config.to_config_mapping(model_info=model_info)
+    logger.debug(configs_mapping)
+    for algo_name, algo_func in algos_mapping.items():
+        if need_apply(configs_mapping, algo_name):
+            logger.info(f"Start to apply {algo_name} on the model.")
+            q_model = algo_func(q_model, configs_mapping, calib_dataloader, calib_iteration)
+    return q_model
diff --git a/neural_compressor/tensorflow/quantization/utils/__init__.py b/neural_compressor/tensorflow/quantization/utils/__init__.py
new file mode 100644
index 00000000000..28f108cb636
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/__init__.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/__init__.py new file mode 100644 index 00000000000..28f108cb636 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/__init__.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/__init__.py new file mode 100644 index 00000000000..7b9dd4d5e8f --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Tensorflow Generic Graph Rewriters.""" diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_add_to_biasadd.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_add_to_biasadd.py new file mode 100644 index 00000000000..0544019335c --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_add_to_biasadd.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
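Stepping back to the rewritten quantize_model earlier in this patch: the loop over algos_mapping applies every registered algorithm whose name appears in the per-op config mapping, feeding each pass the previous pass's output. A schematic, self-contained re-creation of that dispatch pattern (all names are local to this sketch, not the INC API):

```python
from typing import Callable, Dict, Tuple

# Schematic stand-ins for INC's registry: each algorithm registers a callable
# under its name, and a model's config mapping records which algorithm each
# (op_name, op_type) pair was assigned.
algos_mapping: Dict[str, Callable] = {
    "static_quant": lambda model, cfgs: f"static({model})",
    "smooth_quant": lambda model, cfgs: f"smooth({model})",
}
configs_mapping: Dict[Tuple[str, str], str] = {("dense_1", "Dense"): "static_quant"}


def need_apply(configs_mapping, algo_name: str) -> bool:
    # An algorithm runs only if at least one op was mapped to it.
    return any(name == algo_name for name in configs_mapping.values())


model = "fp32_model"
for algo_name, algo_func in algos_mapping.items():
    if need_apply(configs_mapping, algo_name):
        model = algo_func(model, configs_mapping)  # each pass consumes the previous output
print(model)  # -> static(fp32_model)
```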
+"""Convert Add OP to BiasAdd OP Graph Rewriter.""" + +import numpy as np +from tensorflow.python.framework import dtypes, tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import SPR_BASE_VERSIONS, dump_elapsed_time + + +class ConvertAddToBiasAddOptimizer(GraphRewriterBase): + """Convert MatMul/Conv2D + Add(AddV2) to MatMul + BiasAdd.""" + + @dump_elapsed_time("Pass ConvertAddToBiasAddOptimizer") + def do_transformation(self): + """Execute conversion Add to BiasAdd.""" + g = GraphAnalyzer() + g.graph = self.model + graph_info = g.parse_graph() + + import tensorflow as tf + + if tf.version.VERSION not in SPR_BASE_VERSIONS: + target_nodes = g.query_fusion_pattern_nodes([["MatMul", "Conv2D"], ["Add", "AddV2"]]) + else: + target_nodes = g.query_fusion_pattern_nodes([["MatMul"], ["Add", "AddV2"]]) + for i in target_nodes: + successor_node_names = graph_info[i[1]].outputs + matmul_input_name = graph_info[i[0]].node.input[0] + matmul_input_node = graph_info[Helper.node_name_from_input(matmul_input_name)].node + # Fixme below two lines was added due to MatMul kernel limitation for matmul input type + # should be quint8. + if matmul_input_node.op == "Const": + continue + add_second_input_name = graph_info[i[1]].node.input[1] + add_second_const_node = graph_info[add_second_input_name].node + if add_second_const_node.op != "Const": + continue + bias_tensor = tensor_util.MakeNdarray(add_second_const_node.attr["value"].tensor) + + if bias_tensor.ndim > 2: + continue + + new_bias_tensor = np.ravel(bias_tensor) + + g.remove_node(i[1]) + + bias_node_name = i[1] + bias_const_node_name = add_second_const_node.name + "_flattern" + + bias_const_node = Helper.create_constant_node(bias_const_node_name, new_bias_tensor, dtypes.float32) + + bias_node = Helper.create_node("BiasAdd", bias_node_name, [i[0], bias_const_node_name]) + Helper.set_attr_dtype(bias_node, "T", dtypes.float32) + + g.add_node(bias_const_node, None, [bias_node_name]) + g.replace_single_node(bias_node, [i[0]], i[1], successor_node_names, i[1]) + + return g.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_layout.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_layout.py new file mode 100644 index 00000000000..d2a4cfe2c9a --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_layout.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Convert Layout Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.core.protobuf import config_pb2, meta_graph_pb2, rewriter_config_pb2 +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.training import saver as saver_lib + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.utils import dump_elapsed_time, version1_gt_version2 + + +class ConvertLayoutOptimizer(GraphRewriterBase): + """The layout conversion optimizer, convert NCHW to NHWC format. + + It is executed only when NCHW node exists and tensorflow version is 2.4.0 and above. + + Args: model: input graph_def + outputs: output name list + + Return: converted graph_def + """ + + def __init__(self, model, outputs): + """Initialization.""" + super().__init__(model) + self.outputs = outputs + + @dump_elapsed_time("ConvertLayoutOptimizer") + def do_transformation(self): + """Execute converting layout.""" + convert = False + for node in self.model.node: + if "Conv" in node.op and "data_format" in node.attr and node.attr["data_format"].s == b"NCHW": + convert = True + break + if convert and version1_gt_version2(tf.version.VERSION, "2.3.0"): + g = tf.Graph() + with g.as_default(): # pylint: disable=not-context-manager + g = tf.compat.v1.import_graph_def(self.model, name="") + meta_graph = saver_lib.export_meta_graph(graph_def=self.model, graph=g, clear_devices=False) + fetch_collection = meta_graph_pb2.CollectionDef() + for fetch in self.outputs: + fetch_collection.node_list.value.append(fetch) # pylint: disable=no-member + meta_graph.collection_def["train_op"].CopyFrom( # pylint: disable=no-member + fetch_collection + ) # pylint: disable=no-member + + config = config_pb2.ConfigProto() + convert = rewriter_config_pb2.RewriterConfig.NCHW_TO_NHWC # pylint: disable=no-member + config.graph_options.rewrite_options.CopyFrom( # pylint: disable=no-member + rewriter_config_pb2.RewriterConfig( + disable_model_pruning=True, + constant_folding=rewriter_config_pb2.RewriterConfig.OFF, + dependency_optimization=rewriter_config_pb2.RewriterConfig.OFF, + memory_optimization=rewriter_config_pb2.RewriterConfig.NO_MEM_OPT, + arithmetic_optimization=rewriter_config_pb2.RewriterConfig.OFF, + shape_optimization=rewriter_config_pb2.RewriterConfig.OFF, + loop_optimization=rewriter_config_pb2.RewriterConfig.OFF, + function_optimization=rewriter_config_pb2.RewriterConfig.OFF, + remapping=rewriter_config_pb2.RewriterConfig.OFF, + implementation_selector=rewriter_config_pb2.RewriterConfig.OFF, + cpu_layout_conversion=convert, + ) + ) + + optimized_graph = tf_optimizer.OptimizeGraph(config, meta_graph) + return optimized_graph + + return self.model diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_leakyrelu.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_leakyrelu.py new file mode 100644 index 00000000000..f7cb7eec4a5 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_leakyrelu.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Convert LeakyRelu Graph Rewriter.""" + +from tensorflow.python.framework import dtypes, tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class ConvertLeakyReluOptimizer(GraphRewriterBase): + """Convert below subgraph to Node A + LeakyRelu. + + Node A Node A + | x | + | x | + | Mul ---> | + | x | + | x | + Maximum LeakyRelu + Note, the coefficient of Mul should be less than 1 or the conversion is not valid. + """ + + @dump_elapsed_time("Pass ConvertLeakyReluOptimizer") + def do_transformation(self): + """Fuse small ops to LeakyRelu.""" + g = GraphAnalyzer() + g.graph = self.model + graph_info = g.parse_graph() + target_nodes = g.query_fusion_pattern_nodes([["Mul"], ["Maximum"]]) + for i in target_nodes: + successor_node_names = graph_info[i[1]].outputs + + mul_input_names = list(graph_info[i[0]].node.input) + max_input_names = list(graph_info[i[1]].node.input) + common_input = list(set(mul_input_names).intersection(set(max_input_names))) + + if len(common_input) != 1: + continue + mul_coeff_node_name = list(set(mul_input_names).difference(set(max_input_names)))[0] + mul_coeff_node = graph_info[mul_coeff_node_name].node + if mul_coeff_node.op != "Const": + continue + nd = tensor_util.MakeNdarray(mul_coeff_node.attr["value"].tensor).ndim + if nd > 1: + continue + alpha_value = float(tensor_util.MakeNdarray(mul_coeff_node.attr["value"].tensor)) + if alpha_value > 1.0: + continue + + leaky_relu_node_name = i[1] + "_leakyrelu" + leaky_relu_node = Helper.create_node("LeakyRelu", leaky_relu_node_name, common_input) + Helper.set_attr_dtype(leaky_relu_node, "T", dtypes.float32) + Helper.set_attr_float(leaky_relu_node, "alpha", alpha_value) + + g.replace_single_node(leaky_relu_node, common_input, i[1], successor_node_names, i[1]) + g.remove_node(i[1]) + g.remove_node(i[0]) + g.remove_node(mul_coeff_node_name) + + return g.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_nan_to_random.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_nan_to_random.py new file mode 100644 index 00000000000..1ea158578af --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_nan_to_random.py @@ -0,0 +1,52 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert NaN to random Graph Rewriter."""
+
+import numpy as np
+from tensorflow.core.framework import attr_value_pb2
+from tensorflow.python.framework import dtypes, tensor_util
+
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer
+
+from ..graph_base import GraphRewriterBase
+
+
+class ConvertNanToRandom(GraphRewriterBase):
+    """Convert Const nodes whose values contain NaN to random data."""
+
+    def do_transformation(self):
+        """Execute the NaN to random conversion."""
+        cur_graph = GraphAnalyzer()
+        cur_graph.graph = self.model
+
+        graph_info = cur_graph.parse_graph()
+
+        target_nodes = cur_graph.query_fusion_pattern_nodes([["Const"]])
+
+        for i in target_nodes:
+            const_node = graph_info[i[0]].node
+            const_content = tensor_util.MakeNdarray(const_node.attr["value"].tensor)
+            if const_content.dtype == np.float32 and np.any(np.isnan(const_content)):
+                const_node.attr["value"].CopyFrom(
+                    attr_value_pb2.AttrValue(
+                        tensor=tensor_util.make_tensor_proto(
+                            np.random.rand(*const_content.shape), dtypes.float32, const_content.shape
+                        )
+                    )
+                )
+
+        return cur_graph.dump_graph()
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_placeholder_to_const.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_placeholder_to_const.py
new file mode 100644
index 00000000000..1908ed26bdd
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/convert_placeholder_to_const.py
@@ -0,0 +1,99 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Convert placeholder to const Graph Rewriter."""
+
+from tensorflow.core.framework import attr_value_pb2, node_def_pb2
+from tensorflow.python.framework import dtypes, tensor_util
+
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper
+from neural_compressor.tensorflow.utils import dump_elapsed_time
+
+from ..graph_base import GraphRewriterBase
+
+
+class ConvertPlaceholderToConst(GraphRewriterBase):
+    """Convert placeholder to const for removing training nodes."""
+
+    @dump_elapsed_time("Pass ConvertPlaceholderToConst")
+    def do_transformation(self):
+        """Convert PlaceholderWithDefault nodes to Constant nodes.
+
+        In a frozen graph, PlaceholderWithDefault nodes can be converted to
+        Constant op nodes with the same value. This helps simplify the graph.
+
+        Args:
+            input_graph_def: A GraphDef containing a model.
+            nodes_to_convert: A list of PlaceholderWithDefault or Placeholder
+                nodes to be converted to Constants with their new value.
+
+        Returns:
+            modified graph with PlaceholderWithDefault nodes converted to Constant nodes
+        """
+        cur_graph = GraphAnalyzer()
+        cur_graph.graph = self.model
+
+        graph_info = cur_graph.parse_graph()
+
+        target_nodes = cur_graph.query_fusion_pattern_nodes([["PlaceholderWithDefault"]])
+        for i in target_nodes:
+            placeholder_node = graph_info[i[0]].node
+            new_node = node_def_pb2.NodeDef()
+            if dtypes.bool.as_datatype_enum == placeholder_node.attr["dtype"].type:
+                placeholder_input_node = None
+                if placeholder_node.input:
+                    placeholder_input_node = graph_info[Helper.node_name_from_input(placeholder_node.input[0])].node
+
+                if placeholder_input_node and placeholder_input_node.op != "Const":
+                    continue
+                if placeholder_input_node:
+                    new_val_str = placeholder_input_node.attr["value"].tensor.bool_val
+                else:
+                    continue
+
+                new_node.op = "Const"
+                new_node.name = placeholder_node.name + "_const"
+                new_node.attr["dtype"].CopyFrom(placeholder_node.attr["dtype"])
+                new_node.attr["value"].CopyFrom(
+                    attr_value_pb2.AttrValue(
+                        tensor=tensor_util.make_tensor_proto(self.strtobool(new_val_str), dtype=dtypes.bool, shape=[])
+                    )
+                )
+                cur_graph.add_node(new_node, None, graph_info[placeholder_node.name].outputs)
+                for each_output in graph_info[placeholder_node.name].outputs:
+                    for i, input_name in enumerate(graph_info[each_output].node.input):
+                        if input_name == placeholder_node.name:
+                            new_input = (
+                                graph_info[each_output].node.input[:i]
+                                + [new_node.name]
+                                + graph_info[each_output].node.input[i + 1 :]
+                            )
+                            graph_info[each_output].node.ClearField("input")
+                            graph_info[each_output].node.input.extend(new_input)
+                cur_graph.remove_node(placeholder_node.name)
+            else:
+                continue
+
+        return cur_graph.dump_graph()
+
+    def strtobool(self, val_str):
+        """Return the boolean value of its equivalent string representation."""
+        if val_str == [True]:
+            return True
+        if val_str == [False]:
+            return False
+        return False
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/dilated_contraction.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/dilated_contraction.py
new file mode 100644
index 00000000000..ecf4c79c88e
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/dilated_contraction.py
@@ -0,0 +1,89 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
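The DilatedContraction pass defined below leans on a standard identity: SpaceToBatchND -> Conv -> BatchToSpaceND with block size b computes a convolution with dilation rate b, which is why the three ops can be collapsed into a single Conv2D with dilations=[1, b, b, 1]. A quick TF check of that identity (shapes arbitrary; tf.nn.atrous_conv2d is the space-to-batch formulation):

```python
import numpy as np
import tensorflow as tf

x = tf.random.normal([1, 16, 16, 3])
filters = tf.random.normal([3, 3, 3, 8])
rate = 2

# Space-to-batch formulation vs. a single dilated convolution.
via_stb = tf.nn.atrous_conv2d(x, filters, rate=rate, padding="SAME")
direct = tf.nn.conv2d(x, filters, strides=1, padding="SAME", dilations=[1, rate, rate, 1])
print(np.allclose(via_stb.numpy(), direct.numpy(), atol=1e-4))  # True
```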
+"""Dilated Contraction Graph Rewriter.""" + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import dtypes, tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class DilatedContraction(GraphRewriterBase): + """Fuse the SpaceToBatchND + Conv + BatchToSpaceND pattern.""" + + @dump_elapsed_time("Pass DilatedContraction") + def do_transformation(self): + """Dilated Contraction fusion.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + target_nodes = cur_graph.query_fusion_pattern_nodes( + ["SpaceToBatchND", ["Conv2D", "DepthwiseConv2dNative"], "BatchToSpaceND"] + ) + + for node_combination in target_nodes: + stob_node = graph_info[node_combination[0]].node + contraction_node = graph_info[node_combination[1]].node + btos_node = graph_info[node_combination[2]].node + stob_padding_node = graph_info[stob_node.input[2]].node + + block_shape_node = graph_info[btos_node.input[1]].node + crops_node = graph_info[btos_node.input[2]].node + + block_value = [i for i in tensor_util.MakeNdarray(block_shape_node.attr["value"].tensor).flat] + new_dilation = [1, block_value[0], block_value[1], 1] + # if padding input of SpaceToBatchND can't be directly fetched, we continue + if stob_padding_node.op != "Const": + continue + padding_value = [i for i in tensor_util.MakeNdarray(stob_padding_node.attr["value"].tensor).flat] + crops_value = [i for i in tensor_util.MakeNdarray(crops_node.attr["value"].tensor).flat] + + contraction_node.input[0] = stob_node.input[0] + Helper.set_attr_int_list(contraction_node, "dilations", new_dilation) + + real_padding = [padding_value[i] - crops_value[i] for i in range(4)] + explict_padding = [0, 0, 0, 0, 0, 0, 0, 0] + data_format = contraction_node.attr["data_format"].s.decode() + if any(real_padding): + contraction_node.attr["padding"].s = "EXPLICIT".encode() + assert data_format in ("NHWC", "NCHW") + if data_format == "NHWC": + explict_padding[2] = real_padding[0] + explict_padding[3] = real_padding[1] + explict_padding[4] = real_padding[2] + explict_padding[5] = real_padding[3] + else: + explict_padding[4] = real_padding[0] + explict_padding[5] = real_padding[1] + explict_padding[6] = real_padding[2] + explict_padding[7] = real_padding[3] + Helper.set_attr_int_list(contraction_node, "explicit_paddings", explict_padding) + + contraction_node.attr.pop("_output_shapes") + cur_graph.remove_node(stob_node.name) + following_node_name = graph_info[node_combination[2]].outputs[0] + following_node = graph_info[following_node_name].node + + following_node.input[0] = btos_node.input[0] + cur_graph.remove_node(btos_node.name) + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/dummy_biasadd.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/dummy_biasadd.py new file mode 100644 index 00000000000..077b62684a0 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/dummy_biasadd.py @@ -0,0 +1,130 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inject dummy BiasAdd Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.python.framework import dtypes + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + + +class InjectDummyBiasAddOptimizer(GraphRewriterBase): + """Inject dummy BiasAdd for MatMul, Conv2D for pattern fusion.""" + + def __init__(self, model, outputs): + """Initialization.""" + super().__init__(model) + self.outputs = outputs + + @dump_elapsed_time("Pass InjectDummyBiasAddOptimizer") + def do_transformation(self): + """Inject dummy BiasAdd if MatMul, Conv2D missing the valid add ops behind them.""" + g = GraphAnalyzer() + g.graph = self.model + graph_info = g.parse_graph() + g.get_frame_info() + valid_ops = ("BiasAdd", "Add", "AddV2", "AddN") + target_nodes = g.query_fusion_pattern_nodes( + [ + ["MatMul", "Conv2D"], + ] + ) + for i in target_nodes: + # only apply this pass for tensorflow old quantization API, pre_optimize does this check + # use conv+dummy_biasadd+relu because TF do not support conv+relu now. + if i[0] in self.outputs: + continue + next_node_names = graph_info[i[0]].outputs + if ( + next_node_names + and len(next_node_names) == 1 + and graph_info[Helper.node_name_from_input(next_node_names[0])].node.op in valid_ops + ): + continue + bias_node_name = i[0] + "_dummy_biasadd" + bias_const_node_name = i[0] + "_dummy_biasadd_const" + matmul_a_node_name = Helper.node_name_from_input(graph_info[i[0]].node.input[0]) + matmul_a_node = graph_info[matmul_a_node_name].node + matmul_b_node_name = Helper.node_name_from_input(graph_info[i[0]].node.input[1]) + matmul_b_node = graph_info[matmul_b_node_name].node + + if matmul_a_node.op == "Const" or matmul_b_node.op not in ["Const", "Enter"]: + continue + if matmul_b_node.op == "Enter": # pragma: no cover + parent_node = graph_info[Helper.node_name_from_input(matmul_b_node.input[0])].node + if parent_node.op != "Const": + continue + else: + matmul_b_node = parent_node + matmul_b_node_name = matmul_b_node.name + + if graph_info[i[0]].node.op == "MatMul": + t_b_index = 0 if graph_info[i[0]].node.attr["transpose_b"].b else 1 + elif graph_info[i[0]].node.op == "Conv2D" and graph_info[i[0]].node.attr["data_format"].s == b"NHWC": + t_b_index = 3 + elif graph_info[i[0]].node.op == "Conv2D" and graph_info[i[0]].node.attr["data_format"].s == b"NCHW": + t_b_index = 1 + else: + continue + + bias_add_length = matmul_b_node.attr["value"].tensor.tensor_shape.dim[t_b_index].size + + bias_add_content = [0.0] * bias_add_length + + bias_const_node = Helper.create_constant_node( + bias_const_node_name, bias_add_content, dtypes.float32, shape=[bias_add_length] + ) + + if i[0] in g.parent_frame_details and g.parent_frame_details[i[0]]: # pragma: no cover + bias_const_enter_node = Helper.create_node( + "Enter", bias_const_node_name + "_enter", 
[bias_const_node_name] + ) + Helper.set_attr_string( + bias_const_enter_node, "frame_name", g.parent_frame_details[i[0]].attr["frame_name"].s + ) + Helper.set_attr_dtype(bias_const_enter_node, "T", dtypes.float32) + Helper.set_attr_bool(bias_const_enter_node, "is_constant", True) + Helper.set_attr_int( + bias_const_enter_node, + "parallel_iterations", + g.parent_frame_details[i[0]].attr["parallel_iterations"].i, + ) + + bias_node = Helper.create_node( + "BiasAdd", + bias_node_name, + [ + i[0], + ( + bias_const_enter_node.name + if i[0] in g.parent_frame_details and g.parent_frame_details[i[0]] + else bias_const_node_name + ), + ], + ) + Helper.set_attr_dtype(bias_node, "T", dtypes.float32) + g.add_node(bias_node, i[0], next_node_names) + if i[0] in g.parent_frame_details and g.parent_frame_details[i[0]]: # pragma: no cover + g.add_node(bias_const_node, None, [bias_const_enter_node.name]) + g.add_node(bias_const_enter_node, bias_const_node_name, [bias_node_name]) + else: + g.add_node(bias_const_node, None, [bias_node_name]) + + return g.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/expanddims_optimizer.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/expanddims_optimizer.py new file mode 100644 index 00000000000..96ba8cc4514 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/expanddims_optimizer.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""ExpandDims Graph Rewriter.""" + + +import numpy as np +from tensorflow.python.framework import dtypes + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class ExpandDimsOptimizer(GraphRewriterBase): + """Calculate ExpandDims and remove it if its input is weight and next node is Conv2D.""" + + @dump_elapsed_time("Pass ExpandDimsOptimizer") + def do_transformation(self): + """Handle all ExpandDims ops whose input is weight and output is Conv2D. 
+ + Args: + input_graph_def (graphdef): graphdef object + + Returns: + [graphdef]: optimized graph + """ + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + target_nodes = cur_graph.query_fusion_pattern_nodes([["ExpandDims"]]) + + for node_combination in target_nodes: + expanddims_node = graph_info[node_combination[0]].node + dims_node = graph_info[expanddims_node.input[1]].node + next_node = graph_info[graph_info[node_combination[0]].outputs[0]].node + # to solve the case that input 0 of ExpandDims is a tensor, not a node + if expanddims_node.input[0] in graph_info: + weight_node = graph_info[expanddims_node.input[0]].node + else: + continue + + if weight_node.op == "Const" and next_node.op == "Conv2D": + dims = Helper.values_from_const(dims_node) + weight_value = np.array(Helper.values_from_const(weight_node)) + new_weight_value = np.expand_dims(weight_value, axis=dims) + cur_graph.remove_node(weight_node.name) + new_weight_node = Helper.create_constant_node(weight_node.name, new_weight_value, dtypes.float32) + + for output in graph_info[node_combination[0]].outputs: + successor_node = graph_info[output].node + replace_index = None + for index, value in enumerate(successor_node.input): + if value == expanddims_node.name: + replace_index = index + break + # weight->conv2d + cur_graph.add_node(new_weight_node, None, [successor_node.name]) + successor_node.input[replace_index] = new_weight_node.name + # remove ExpandDims and weight_node + cur_graph.remove_node(expanddims_node.name) + else: + continue + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fetch_weight_from_reshape.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fetch_weight_from_reshape.py new file mode 100644 index 00000000000..cc3f6934732 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fetch_weight_from_reshape.py @@ -0,0 +1,89 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fetch Weight from Reshape Graph Rewriter.""" + + +import numpy as np +from tensorflow.python.framework import dtypes + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class FetchWeightFromReshapeOptimizer(GraphRewriterBase): + """Handle the Pack + Reshape + Conv2D fusion pattern.""" + + @dump_elapsed_time("Pass FetchWeightFromReshapeOptimizer") + def do_transformation(self): + """Fetch weight of Conv2D from Pack+Reshape+Conv2D pattern. 
+ + Args: + input_graph_def (graphdef): graphdef object + Returns: + [graphdef]: optimized graph + """ + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + target_nodes = cur_graph.query_fusion_pattern_nodes([["Pack"], ["Reshape"], ["Conv2D"]]) + + for i, node_combination in enumerate(target_nodes): + pack_node = graph_info[node_combination[0]].node + reshape_node = graph_info[node_combination[1]].node + shape_node = graph_info[reshape_node.input[1]].node + conv_node = graph_info[node_combination[2]].node + if not (pack_node.op == "Pack" and reshape_node.op == "Reshape" and conv_node.op == "Conv2D"): + continue + reshape_outputs_length = len(graph_info[node_combination[1]].outputs) + unpack_values = [] + for index in range(pack_node.attr["N"].i): + values_node = graph_info[pack_node.input[index]].node + if values_node.op == "Const": + unpack_values.append(Helper.values_from_const(values_node)) + input_reshape = np.stack(unpack_values, axis=pack_node.attr["axis"].i) + if shape_node.op != "Const": + continue + shape = Helper.values_from_const(shape_node) + weight = np.reshape(input_reshape, shape) + weight_node = Helper.create_constant_node( + reshape_node.name + "/weight" + "_" + str(i), weight, dtypes.float32 + ) + if i > 0: + conv_node_j = graph_info[target_nodes[i - 1][2]].node + graph_info[node_combination[1]].outputs.remove(conv_node_j.name) + for output in graph_info[node_combination[1]].outputs: + successor_node = graph_info[output].node + replace_index = None + for index, value in enumerate(successor_node.input): + if value == reshape_node.name or value == reshape_node.name + "/weight" + "_" + str(i - 1): + replace_index = index + break + # weight->conv2d + cur_graph.add_node(weight_node, None, [successor_node.name]) + successor_node.input[replace_index] = weight_node.name + + if i + 1 == reshape_outputs_length: + cur_graph.remove_node(reshape_node.name) + cur_graph.remove_node(values_node.name) + cur_graph.remove_node(shape_node.name) + cur_graph.remove_node(pack_node.name) + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_batch_norm.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_batch_norm.py new file mode 100644 index 00000000000..bff69330ec5 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_batch_norm.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
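Before the FoldBatchNormNodesOptimizer below, it is worth spelling out the algebra it relies on: a frozen BatchNorm following a convolution is an affine map per output channel, so it folds into a weight scale plus a bias. A small numpy check of the identity (random values, eps arbitrary):

```python
import numpy as np

C = 4
conv_out = np.random.rand(2, C)                     # stand-in for per-channel conv output
gamma, beta = np.random.rand(C), np.random.rand(C)
mean, var, eps = np.random.rand(C), np.random.rand(C), 1e-3

# Frozen BatchNorm: y = gamma * (x - mean) / sqrt(var + eps) + beta
bn = gamma * (conv_out - mean) / np.sqrt(var + eps) + beta

# Folded form used by the pass: scale the weights once, add a bias once.
scale = gamma / np.sqrt(var + eps)
offset = beta - mean * scale
folded = conv_out * scale + offset  # scaling conv output == scaling its weights

assert np.allclose(bn, folded)
```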
+"""Folding BatchNorm Graph Rewriter.""" + +import math + +import numpy as np +from tensorflow.core.framework import attr_value_pb2, node_def_pb2 +from tensorflow.python.framework import tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class FoldBatchNormNodesOptimizer(GraphRewriterBase): + """Folding BatchNorm nodes into Conv.""" + + INPUT_ORDER = { + # Order of inputs for BatchNormWithGlobalNormalization. + "BatchNormWithGlobalNormalization": ["conv_op", "mean_op", "var_op", "beta_op", "gamma_op"], + # Order of inputs for FusedBatchNorm. + "FusedBatchNorm": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"], + "FusedBatchNormV3": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"], + "_FusedBatchNormEx": ["conv_op", "gamma_op", "beta_op", "mean_op", "var_op"], + } + # Name of the attribute epsilon value is stored in. + EPSILON_ATTR = { + "BatchNormWithGlobalNormalization": "variance_epsilon", + "FusedBatchNorm": "epsilon", + "FusedBatchNormV3": "epsilon", + "_FusedBatchNormEx": "epsilon", + } + + def scale_after_normalization(self, node): + """Check the scale_after_normalization attribute if the node is BatchNormWithGlobalNormalization. + + Args: + node (nodedef): input nodedef object + + Returns: + bool: True if the node op is not BatchNormWithGlobalNormalization else it + depends on the BatchNormWithGlobalNormalization attribute value of + `scale_after_normalization`. + """ + if node.op == "BatchNormWithGlobalNormalization": + return node.attr["scale_after_normalization"].b + return True + + @dump_elapsed_time("Pass FoldBatchNormNodesOptimizer") + def do_transformation(self): + """Removes batch normalization ops by folding them into convolutions. + + Batch normalization during training has multiple dynamic parameters that are + updated, but once the graph is finalized these become constants. That means + there's an opportunity to reduce the computations down to a scale and + addition, rather than the more expensive multiple ops, and even bake the + scaling into the convolution weights. This function identifies the typical + pattern of batch normalization subgraphs, and performs the transformation to + fold the computations down into a simpler form. It currently only spots batch + normalization that's performed by the BatchNormWithGlobalNormalization, FusedBatchNorm, + FusedBatchNormV3 and _FusedBatchNormEx ops, and will need to be extended in the future to handle the + newer style. + + Returns: + Modified graph with BN ops removed, and modified weights. + + Raises: + ValueError: If the graph is badly formed with duplicate node names. 
+ """ + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + target_nodes = cur_graph.query_fusion_pattern_nodes( + [ + ["Conv2D", "DepthwiseConv2dNative"], + ("BiasAdd", "Add", "AddV2"), + ["BatchNormWithGlobalNormalization", "FusedBatchNorm", "FusedBatchNormV3", "_FusedBatchNormEx"], + ] + ) + for node_combination in target_nodes: + matched_node = node_combination[:-1] + has_add_op = True if len(node_combination[-1]) == 3 else False + conv_node = graph_info[Helper.node_name_from_input(matched_node[0])].node + weights_node_name = graph_info[Helper.node_name_from_input(matched_node[0])].node.input[1] + weights_node = graph_info[Helper.node_name_from_input(weights_node_name)].node + bn_node = graph_info[Helper.node_name_from_input(matched_node[-1])].node + + # oneDNN enabled _FusedBatchNormEx only supports num_side_inputs == 0 + # and Relu/Identity activations. + if bn_node.op == "_FusedBatchNormEx": + if bn_node.attr["num_side_inputs"].i != 0: + continue + if not ( + bn_node.attr["activation_mode"].s == b"Identity" or bn_node.attr["activation_mode"].s == b"Relu" + ): + continue + + if weights_node.op != "Const": + self.logger.warning( + "Didn't find expected conv Constant input to '%s', " + "found %s instead. Maybe freeze_graph wasn't " + "run first?" % (bn_node.name, weights_node_name) + ) + continue + weights = Helper.values_from_const(weights_node) + + if conv_node.op == "Conv2D": + channel_count = weights.shape[3] + elif conv_node.op == "DepthwiseConv2dNative": + channel_count = weights.shape[2] * weights.shape[3] + + mean_node_name = Helper.node_name_from_input(bn_node.input[self.INPUT_ORDER[bn_node.op].index("mean_op")]) + mean_node = graph_info[mean_node_name].node + + if mean_node.op != "Const": + continue + + mean_value = Helper.values_from_const(mean_node) + + if has_add_op: + bias_node_name = graph_info[Helper.node_name_from_input(matched_node[1])].node.input[1] + bias_node = graph_info[Helper.node_name_from_input(bias_node_name)].node + if bias_node.op != "Const": + continue + + if mean_value.shape != (channel_count,): + continue + + mean_value = mean_value - Helper.values_from_const(bias_node) + cur_graph.remove_node(bias_node.name) + cur_graph.remove_node(matched_node[1]) + + if mean_value.shape != (channel_count,): + self.logger.warning( + "Incorrect shape for mean, found {}, expected {}, " + "for node {}.".format(str(mean_value.shape), str((channel_count,)), conv_node.name) + ) + continue + var_node_name = Helper.node_name_from_input(bn_node.input[self.INPUT_ORDER[bn_node.op].index("var_op")]) + var_node = graph_info[var_node_name].node + if var_node.op != "Const": + continue + var_value = Helper.values_from_const(var_node) + + if var_value.shape != (channel_count,): + continue + + beta_node_name = Helper.node_name_from_input(bn_node.input[self.INPUT_ORDER[bn_node.op].index("beta_op")]) + beta_node = graph_info[beta_node_name].node + if beta_node.op != "Const": + continue + beta_value = Helper.values_from_const(beta_node) + + if beta_value.shape != (channel_count,): + continue + + gamma_node_name = Helper.node_name_from_input(bn_node.input[self.INPUT_ORDER[bn_node.op].index("gamma_op")]) + gamma_node = graph_info[gamma_node_name].node + + if gamma_node.op != "Const": + continue + gamma_value = Helper.values_from_const(gamma_node) + + if gamma_value.shape != (channel_count,): + continue + + variance_epsilon_value = bn_node.attr[self.EPSILON_ATTR[bn_node.op]].f + + if self.scale_after_normalization(bn_node): + scale_value = 
(1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value)) * gamma_value + else: + scale_value = 1.0 / np.vectorize(math.sqrt)(var_value + variance_epsilon_value) + + offset_value = (-mean_value * scale_value) + beta_value + + if conv_node.op == "Conv2D": + original_shape = weights.shape + tmp_shape = (original_shape[-1], int(weights.size / original_shape[-1])) + tmp_order = [weights.ndim - 1] + [i for i in range(weights.ndim - 1)] + scaled_weights = np.copy(weights).transpose(tmp_order).ravel().reshape(tmp_shape) + reshape_scale = np.array(scale_value).reshape(len(scale_value), 1) + scaled_weights = np.multiply(scaled_weights, reshape_scale).transpose().reshape(original_shape) + elif conv_node.op == "DepthwiseConv2dNative": + scaled_weights = np.copy(weights) + it = np.nditer(scaled_weights, flags=["multi_index"], op_flags=["readwrite"]) + channel_multiplier = weights.shape[3] + while not it.finished: + current_scale = scale_value[it.multi_index[2] * channel_multiplier + it.multi_index[3]] + it[0] *= current_scale + it.iternext() + + scaled_weights_node = node_def_pb2.NodeDef() + scaled_weights_node.op = "Const" + scaled_weights_node.name = weights_node_name + "_bn_offset" + scaled_weights_node.attr["dtype"].CopyFrom(weights_node.attr["dtype"]) + scaled_weights_node.attr["value"].CopyFrom( + attr_value_pb2.AttrValue( + tensor=tensor_util.make_tensor_proto(scaled_weights, weights.dtype.type, weights.shape) + ) + ) + cur_graph.replace_const_node(scaled_weights_node, [conv_node.name], weights_node_name) + + offset_node = node_def_pb2.NodeDef() + offset_node.op = "Const" + offset_node.name = conv_node.name + "_bn_offset" + offset_node.attr["dtype"].CopyFrom(mean_node.attr["dtype"]) + offset_node.attr["value"].CopyFrom( + attr_value_pb2.AttrValue( + tensor=tensor_util.make_tensor_proto(offset_value, mean_value.dtype.type, offset_value.shape) + ) + ) + bias_add_node = node_def_pb2.NodeDef() + bias_add_node.op = "BiasAdd" + bias_add_node.name = bn_node.name + bias_add_node.attr["T"].CopyFrom(conv_node.attr["T"]) + bias_add_node.attr["data_format"].CopyFrom(conv_node.attr["data_format"]) + bias_add_node.input.extend([conv_node.name, offset_node.name]) + if bn_node.op == "_FusedBatchNormEx" and bn_node.attr["activation_mode"].s == b"Relu": + # Create Relu op which takes Bias-Add as input. 
+                # Conv2D/Depthwise-Conv2D     Conv2D/Depthwise-Conv2D
+                #          |                            |
+                # Bias-Add (originally _FusedBatchNormEx)  <---->  Bias-Add
+                #          |                            |   \
+                #                                           Relu
+                relu_node = node_def_pb2.NodeDef()
+                relu_node.op = "Relu"
+                relu_node.name = bn_node.name + "_bn_relu"
+                relu_node.attr["T"].CopyFrom(conv_node.attr["T"])
+                relu_node.input.extend([bias_add_node.name])
+
+            if bn_node.op == "_FusedBatchNormEx" and bn_node.attr["activation_mode"].s == b"Relu":
+                matchd_node_outputs = graph_info[Helper.node_name_from_input(matched_node[-1])].outputs
+                cur_graph.add_node(offset_node, [], [bias_add_node.name])
+                cur_graph.add_node(bias_add_node, conv_node.name, [relu_node.name])
+                cur_graph.add_node(relu_node, bias_add_node.name, matchd_node_outputs)
+            else:
+                cur_graph.add_node(offset_node, [], [bias_add_node.name])
+                cur_graph.add_node(
+                    bias_add_node, conv_node.name, graph_info[Helper.node_name_from_input(matched_node[-1])].outputs
+                )
+
+            cur_graph.remove_node(weights_node_name)
+            cur_graph.remove_node(mean_node_name)
+            cur_graph.remove_node(var_node_name)
+            cur_graph.remove_node(beta_node_name)
+            cur_graph.remove_node(gamma_node_name)
+
+        return cur_graph.dump_graph()
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_constant.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_constant.py
new file mode 100644
index 00000000000..924536db696
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fold_constant.py
@@ -0,0 +1,182 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Folding Const Graph Rewriter."""
+
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.python.platform import tf_logging
+
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer, GraphRewriterHelper
+from neural_compressor.tensorflow.utils import dump_elapsed_time
+
+from ..graph_base import GraphRewriterBase
+
+
+class GraphFoldConstantOptimizer(GraphRewriterBase):
+    """Fold all sequences that consist only of Const nodes and ops in self.supported_op_type."""
+
+    supported_op_type = ["Add", "AddV2", "Const", "Mul", "Rsqrt", "Sub"]
+
+    def __init__(self, model=None):
+        """Initialization."""
+        super().__init__(model)
+        self.graph_analyzer = GraphAnalyzer()
+        self.graph_analyzer.graph = self.model
+
+        self.graph_info = self.graph_analyzer.parse_graph()
+
+    def _fold_value(self, end_node_name):
+        """Calculate the value of the end node of a constant node sequence.
+
+        There may be layers whose inputs are all constant in the graph, like:
+          const
+                > add
+          const
+        the value of add can be calculated in advance.
+
+ Args:
+ end_node_name: name of the end node of the sequence, e.g. add in the above example.
+
+ Returns:
+ values of the end node.
+
+ Raises:
+ ValueError: If the graph contains tensors which can't be broadcast.
+ """
+ end_node = self.graph_info[end_node_name].node
+
+ def can_broadcast(s1, s2):
+ if s1.shape and s2.shape:
+ s1a = np.asarray(s1.shape)
+ s2a = np.asarray(s2.shape)
+ return ((s1a == 1) | (s2a == 1) | (s2a == s1a)).all()
+
+ return True
+
+ if self.graph_info[end_node_name].node.input:
+ if end_node.op == "Mul":
+ first_value = self._fold_value(list(end_node.input)[0])
+ first_type = first_value.dtype
+ fold_value = np.array(1.0).astype(first_type)
+ for index, input in enumerate(end_node.input):
+ # broadcast if needed
+ input_value = self._fold_value(input)
+ input_type = input_value.dtype
+ if can_broadcast(fold_value, input_value):
+ fold_value = fold_value * input_value
+ else:
+ raise ValueError("input {} of node {} can't be broadcast".format(input, end_node.name))
+ return fold_value.astype(first_type)
+ elif end_node.op == "Add" or end_node.op == "AddV2":
+ first_value = self._fold_value(list(end_node.input)[0])
+ first_type = first_value.dtype
+ fold_value = np.array(0.0).astype(first_type).reshape(())
+ for index, input in enumerate(end_node.input):
+ # broadcast if needed
+ input_value = self._fold_value(input)
+ if can_broadcast(fold_value, input_value):
+ fold_value = fold_value + input_value
+ else:
+ raise ValueError("input {} of node {} can't be broadcast".format(input, end_node.name))
+ return fold_value.astype(first_type)
+ elif end_node.op == "Rsqrt":
+ return 1 / np.sqrt(self._fold_value(end_node.input[0]))
+ elif end_node.op == "Sub":
+ first_value = self._fold_value(list(end_node.input)[0])
+ first_type = first_value.dtype
+ fold_value = np.array(0.0, dtype=first_type)
+ for index, input in enumerate(end_node.input):
+ # broadcast if needed
+ input_value = self._fold_value(input)
+ if first_type != input_value.dtype:
+ raise ValueError(
+ "inputs of node {} must have the same dtype but got {} and {}".format(
+ end_node.name, first_type, input_value.dtype
+ )
+ )
+ if can_broadcast(fold_value, input_value):
+ fold_value = fold_value + (-1) ** index * input_value
+ else:
+ raise ValueError("input {} of node {} can't be broadcast".format(input, end_node.name))
+ return fold_value.astype(first_type)
+ else:
+ tf_logging.info(
+ "Currently fold-constant only supports the limited ops {} but got {}".format(
+ self.supported_op_type, end_node.op
+ )
+ )
+ else:
+ return GraphRewriterHelper.values_from_const(end_node)
+
+ def check_all_folded(self):
+ """Check whether all foldable nodes in the graph have been folded.
+
+ Returns:
+ bool: True if no foldable nodes remain, else False.
+ """
+ for node_name, _ in self.graph_info.items():
+ if self.check_const_inputs(node_name):
+ return False
+ return True
+
+ def check_const_inputs(self, node_name):
+ """Check whether the node is foldable, i.e. all of its inputs are Const.
+
+ Args:
+ node_name (string): node name
+
+ Returns:
+ bool: True if all inputs of the node are Const, else False.
+ """
+ if node_name not in self.graph_info:
+ return False
+ node_op = self.graph_info[node_name].node.op
+ if node_op == "Placeholder" or node_op == "Const":
+ return False
+ if node_op not in self.supported_op_type:
+ return False
+ constant_flag = True
+ for input_name in self.graph_info[node_name].node.input:
+ input_name = GraphRewriterHelper.node_name_from_input(input_name)
+ input_node = self.graph_info[input_name].node
+ constant_flag &= input_node.op == "Const" and not input_node.input
+ return constant_flag
+
+ @dump_elapsed_time("Pass GraphFoldConstantOptimizer")
+ def do_transformation(self):
+ """Fold all sequences that consist only of Const nodes and the supported op types.
+
+ Returns:
+ graphdef: the optimized graph
+ """
+ while not self.check_all_folded():
+ for node_name, _ in self.graph_info.copy().items():
+ if self.check_const_inputs(node_name):
+ fold_value = self._fold_value(node_name)
+ fold_type = tf.as_dtype(fold_value.dtype)
+ new_constant_node = GraphRewriterHelper.create_constant_node(
+ node_name + "_const", fold_value, fold_type
+ )
+ self.graph_analyzer.replace_constant_graph_with_constant_node(new_constant_node, node_name)
+
+ output_graph_def = self.graph_analyzer.dump_graph()
+
+ return output_graph_def
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_biasadd_add.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_biasadd_add.py
new file mode 100644
index 00000000000..e391d334cf5
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_biasadd_add.py
@@ -0,0 +1,75 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
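A minimal usage sketch for the GraphFoldConstantOptimizer pass added above; graph_def stands for a hypothetical tf.compat.v1.GraphDef obtained elsewhere, and only the class and its import path come from this diff:

    from neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fold_constant import (
        GraphFoldConstantOptimizer,
    )

    # Repeatedly folds constant-only subgraphs until none remain.
    folded_graph_def = GraphFoldConstantOptimizer(graph_def).do_transformation()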
+"""Fuse BiasAdd and Add Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.python.framework import dtypes, tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper + +from ..graph_base import GraphRewriterBase + + +class FuseBiasAddAndAddOptimizer(GraphRewriterBase): + """Fuse Biasadd + Add into BiasAdd when the second input of Add is const node.""" + + def do_transformation(self): + """Fuse Biasadd + Add into BiasAdd for pattern fusion.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + + target_nodes = cur_graph.query_fusion_pattern_nodes( + [["Conv2D", "Conv3D"], "BiasAdd", ["Add", "AddV2"], ["Relu", "Relu6", "swish_f32"], ["Mul"], ["Mul"]] + ) + + for i in target_nodes: + biasadd_const_name = graph_info[i[1]].node.input[1] + biasadd_const_node = graph_info[biasadd_const_name].node + + if len(graph_info[i[1]].outputs) > 1: + continue + + another_node_index = None + for index, value in enumerate(graph_info[i[2]].node.input): + if value != i[1]: + another_node_index = index + break + add_node_const_name = graph_info[i[2]].node.input[another_node_index] + + add_const_node = graph_info[add_node_const_name].node + + if add_const_node.op != "Const": + continue + value = tensor_util.MakeNdarray(biasadd_const_node.attr["value"].tensor) + add_value = tensor_util.MakeNdarray(add_const_node.attr["value"].tensor) + + new_bias_tensor = value + add_value + fused_const_node = Helper.create_constant_node(i[2] + "_fused", new_bias_tensor, dtypes.float32) + cur_graph.remove_node(graph_info[i[1]].node.input[1]) + + graph_info[i[1]].node.input[1] = i[2] + "_fused" + + cur_graph.remove_node(add_node_const_name) + + cur_graph.remove_node(i[2]) + graph_info[i[3]].node.input[0] = i[1] + cur_graph.add_node(fused_const_node, None, [i[1]]) + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_column_wise_mul.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_column_wise_mul.py new file mode 100644 index 00000000000..2fc7dbee586 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_column_wise_mul.py @@ -0,0 +1,86 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fuse Columnwise Mul Graph Rewriter.""" + +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import dtypes, tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class FuseColumnWiseMulOptimizer(GraphRewriterBase): + """Fuse Mul op into Conv2D/DepthwiseConv2dNative/MatMul.""" + + @dump_elapsed_time("Pass FuseColumnWiseMulOptimizer") + def do_transformation(self): + """Fuse Mul + Conv2D/DepthwiseConv2dNative/MatMul --> Conv2D/DepthwiseConv2dNative/MatMul.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + target_nodes = cur_graph.query_fusion_pattern_nodes([["Conv2D", "DepthwiseConv2dNative", "MatMul"], "Mul"]) + + for node_combination in target_nodes: + upper_node = graph_info[node_combination[0]].node + mul_node = graph_info[node_combination[1]].node + if graph_info[Helper.node_name_from_input(mul_node.input[1])].node.op != "Const": + continue + weights_node = graph_info[graph_info[node_combination[0]].node.input[1]].node + mul_value_node = graph_info[graph_info[node_combination[1]].node.input[1]].node + upper_node_type = upper_node.op + + if upper_node_type == "Conv2D": + weights_col = weights_node.attr["value"].tensor.tensor_shape.dim[3].size + elif upper_node_type == "DepthwiseConv2dNative": + weights_col = ( + weights_node.attr["value"].tensor.tensor_shape.dim[2].size + * weights_node.attr["value"].tensor.tensor_shape.dim[3].size + ) + else: + weights_col = weights_node.attr["value"].tensor.tensor_shape.dim[1].size + + mul_value_node_tensor = mul_value_node.attr["value"].tensor + weights_node_tensor = weights_node.attr["value"].tensor + if ( + len(mul_value_node_tensor.tensor_shape.dim) != 1 + or mul_value_node_tensor.tensor_shape.dim[0].size != weights_col + ): + self.logger.warning("Invalid Mul OP fusion.") + return self.model + + mul_value_node_list = [i for i in tensor_util.MakeNdarray(mul_value_node_tensor).flat] + new_weights = [] + for index, i in enumerate(tensor_util.MakeNdarray(weights_node_tensor).flat): + new_weights_value = i * mul_value_node_list[index % len(mul_value_node_list)] + new_weights.append(new_weights_value) + + weights_node.attr["value"].CopyFrom( + attr_value_pb2.AttrValue( + tensor=tensor_util.make_tensor_proto( + new_weights, dtypes.float32, tensor_util.MakeNdarray(weights_node_tensor).shape + ) + ) + ) + + cur_graph.remove_node_with_single_input_output(mul_node.name) + cur_graph.remove_node(mul_node.input[1]) + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_conv_with_math.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_conv_with_math.py new file mode 100644 index 00000000000..173a6df2525 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_conv_with_math.py @@ -0,0 +1,111 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fuse Conv with Math Graph Rewriter.""" + +import numpy as np +from tensorflow.python.framework import dtypes, tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class FuseConvWithMathOptimizer(GraphRewriterBase): + """Convert below subgraph to Conv2D + BiasAdd by eliminating math ops. + + Conv2D Conv2D + | | + Sub | + | ----> | + RealDiv | + | | + Mul | + | | + BiasAdd BiasAdd + """ + + @dump_elapsed_time("Pass FuseConvWithMathOptimizer") + def do_transformation(self): + """Fuse Conv + Sub + RealDiv + Mul + BiasAdd to Conv + BiasAdd.""" + g = GraphAnalyzer() + g.graph = self.model + graph_info = g.parse_graph() + pattern_definition = [["Conv2D"], ["Sub"], ["RealDiv"], ["Mul"], ["BiasAdd"]] + target_nodes = g.query_fusion_pattern_nodes(pattern_definition) + for i in target_nodes: + weights_node_name = graph_info[i[0]].node.input[1] + weights_node = graph_info[weights_node_name].node + + sub_input_names = list(graph_info[i[1]].node.input) + sub_content_node_name = list(set(sub_input_names).difference([i[0]]))[0] + sub_content_node = graph_info[sub_content_node_name].node + sub_tensor = tensor_util.MakeNdarray(sub_content_node.attr["value"].tensor) + + real_div_input_names = list(graph_info[i[2]].node.input) + real_div_content_node_name = list(set(real_div_input_names).difference([i[1]]))[0] + real_div_node = graph_info[real_div_content_node_name].node + real_div_tensor = tensor_util.MakeNdarray(real_div_node.attr["value"].tensor) + + mul_input_names = list(graph_info[i[3]].node.input) + mul_content_node_name = list(set(mul_input_names).difference([i[2]]))[0] + mul_content_node = graph_info[mul_content_node_name].node + mul_tensor = tensor_util.MakeNdarray(mul_content_node.attr["value"].tensor) + + bias_input_names = list(graph_info[i[4]].node.input) + bias_content_node_name = list(set(bias_input_names).difference([i[3]]))[0] + bias_content_node = graph_info[bias_content_node_name].node + bias_tensor = tensor_util.MakeNdarray(bias_content_node.attr["value"].tensor) + + bias_offset_value = bias_tensor - sub_tensor * mul_tensor / real_div_tensor + weights_offset = mul_tensor / real_div_tensor + + weights = Helper.values_from_const(weights_node) + original_shape = weights.shape + tmp_shape = (original_shape[-1], int(weights.size / original_shape[-1])) + tmp_order = [weights.ndim - 1] + [i for i in range(weights.ndim - 1)] + + scaled_weights = np.copy(weights).transpose(tmp_order).ravel().reshape(tmp_shape) + reshape_scale = np.array(weights_offset).reshape(len(weights_offset), 1) + scaled_weights = np.multiply(scaled_weights, reshape_scale).transpose().reshape(original_shape) + scaled_weight_name = weights_node_name + "_conv_math_offset" + scaled_weights_node = Helper.create_constant_node( + scaled_weight_name, scaled_weights, dtypes.float32, shape=weights.shape + ) + + g.add_node(scaled_weights_node, None, [i[0]]) + 
g.replace_const_node(scaled_weights_node, [i[0]], weights_node_name)
+
+ offset_node = Helper.create_constant_node(i[0] + "_biasadd_math_offset", bias_offset_value, dtypes.float32)
+ g.add_node(offset_node, None, [i[4]])
+ graph_info[i[4]].node.input[0] = i[0]
+
+ graph_info[i[4]].node.input[1] = offset_node.name
+
+ g.remove_node(i[1])
+ g.remove_node(sub_content_node_name)
+
+ g.remove_node(i[2])
+ g.remove_node(real_div_content_node_name)
+
+ g.remove_node(i[3])
+ g.remove_node(mul_content_node_name)
+
+ g.remove_node(bias_content_node_name)
+
+ return g.dump_graph()
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_decomposed_bn.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_decomposed_bn.py
new file mode 100644
index 00000000000..c6347af00a6
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_decomposed_bn.py
@@ -0,0 +1,400 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fuse Decomposed BatchNorm Graph Rewriter."""
+
+import collections
+import math
+import re
+
+import numpy as np
+from tensorflow.compat.v1 import graph_util
+from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2
+from tensorflow.python.framework import dtypes, tensor_util
+from tensorflow.python.platform import flags as flags_lib
+from tensorflow.python.platform import tf_logging
+from tensorflow.python.tools import strip_unused_lib
+
+from neural_compressor.tensorflow.utils import dump_elapsed_time
+
+
+class FuseDecomposedBNOptimizer:
+ """Fuse decomposed small ops to BatchNormalization."""
+
+ def __init__(self, input_graph_def):
+ """Initialization."""
+ self.input_graph_def = input_graph_def
+
+ @dump_elapsed_time("Pass FuseDecomposedBNOptimizer")
+ def do_transformation(self):
+ """Fuse individual ops in batch normalization to FusedBatchNorm.
+
+ In some models, batch normalization is performed via a group of individual
+ ops instead of a single FusedBatchNorm op. This function identifies a
+ pattern of batch normalization subgraph which is made of multiple ops and
+ transforms the graph by replacing those individual ops with a FusedBatchNorm op.
+ This provides the opportunity to further fold the FusedBatchNorm with
+ convolution ops to reduce the computation steps during inference.
+ This function currently recognizes the batch normalization patterns described
+ below; this could be extended if newer patterns are seen. Also, the fusion
+ is only attempted if the input graph is in NHWC format or has no format set.
+ Computation function: + (X * multiplier) + (Beta - Mean * multiplier) + where multiplier = rsqrt (Variance + Epsilon) * Gamma + OR = rsqrt (Variance + Epsilon) when Gamma is 1 + Subgraph: + {"Add" + {{"Mul" // mul_0 + {{"*"}, // input to apply batchnorm + {"Mul" // mul_1, same op is used inside the Sub block + {{"Rsqrt" + {"Add" + {{"Const" | "Reshape(Const)"}, // Variance + {"Const"} // Epsilon + } + } + }, // end - Rsqrt + {"Const" | "Reshape(Const)"} // Gamma + } + } // end - mul_1 + } + }, // end - mul_0 + {"Sub" + {{"Const" | "Reshape(Const)"}, // Beta + {"Mul" // mul_3 + {{"Const" | "Reshape(Const)"}, // Mean + {"Mul" // same mul_1 op as in previous block + {{"Rsqrt" + {"Add" + {{"Const" | "Reshape(Const)"}, // Variance + {"Const"} // Epsilon + } + } + }, // end - Rsqrt + {"Const" | "Reshape(Const)"} // Gamma + } + } // end - mul_1 + } + } // end - mul_3 + } + } // end - Sub + } + } // end - Add + Subgraph pattern when gamma value is 1 and the gamma scaling Mul is skipped + {"Add" + {{"Mul" // mul_0 + {{"*"}, // input to apply batchnorma + {"Rsqrt" // same Rsqrt op used in Sub block + {"Add" + {{"Const" | "Reshape(Const)"}, // Variance + {"Const"} // Epsilon + } + } + } // end - Rsqrt + } + }, // end - mul_0 + {"Sub" + {{"Const" | "Reshape(Const)"}, // Beta + {"Mul" // mul_1 + {{"Const" | "Reshape(Const)"}, // Mean + {"Rsqrt" // same Rsqrt op as in previous mul_0 block + {"Add" + {{"Const" | "Reshape(Const)"}, // Variance + {"Const"} // Epsilon + } + } + } // end - Rsqrt + } + } // end - mul_1 + } + } // end - Sub + } + } // end - Add + Args: + input_graph_def: A GraphDef containing a model. + + Returns: + Modified graph with individual ops that made up of batch normalization + fused to FusedBatchNorm. + + Raises: + ValueError: If the graph is badly formed with duplicate node names. + """ + input_node_map = {} + for node in self.input_graph_def.node: + if node.name not in input_node_map: + input_node_map[node.name] = node + else: + raise ValueError("Duplicate node names detected for ", node.name) + + # Check format and only proceed if graph is in NHWC or has no format set. + data_format = None + for node in self.input_graph_def.node: + if "data_format" in node.attr.keys(): + data_format = node.attr["data_format"] + if data_format is not None and data_format.s != b"NHWC": + tf_logging.warn("%s in %s format, not candidate for batchnorm fusion." % (node.name, data_format.s)) + return self.input_graph_def + else: + continue + + nodes_to_skip = {} + new_ops = [] + for node in self.input_graph_def.node: + if node.op != "Add": + continue + + # Add (Mul, Sub) or Add (Sub, Mul) + input0_op = node_from_map(input_node_map, node.input[0]) + input1_op = node_from_map(input_node_map, node.input[1]) + + if input0_op.op == "Mul" and input1_op.op == "Sub": + data_scale_mul_op = input0_op + bias_mean_sub_op = input1_op + elif input0_op.op == "Sub" and input1_op.op == "Mul": + bias_mean_sub_op = input0_op + data_scale_mul_op = input1_op + else: + continue + + # Mul (input, Mul) + input_data_op = node_from_map(input_node_map, data_scale_mul_op.input[0]) + # Workaround for model ava-person-vehicle-detection-stage2-2_0_0 + # FusedBatchNorm requires a 4D Tensor for input data, + # but the MatMul before FusedBatchNorm only support 2D output. + # Don't fuse the small ops to FusedBatchNorm when the upstream has MatMul. 
+ if input_data_op.op == "MatMul": + continue + + # Workaround for DIEN_Deep-Interest-Evolution-Network + if input_data_op.op == "ConcatV2" and input_data_op.name == "concat_8": + continue + + if input_data_op.input: + ancestor_input_data_op = node_from_map(input_node_map, input_data_op.input[0]) + if ancestor_input_data_op.op == "MatMul": + continue + + scale_op = node_from_map(input_node_map, data_scale_mul_op.input[1]) + + if scale_op.op == "Rsqrt": + gamma_op = None + gamma_reshape_op = None + rsqrt_op = scale_op + elif scale_op.op == "Mul": + # Mul (Rsqrt, Constant_gamma) + rsqrt_op = node_from_map(input_node_map, scale_op.input[0]) + gamma_op, gamma_reshape_op = bypass_reshape(input_node_map, scale_op.input[1]) + if rsqrt_op.op != "Rsqrt": + continue + if gamma_op.op != "Const" or get_const_dim_count(gamma_op) != 1: + continue + else: + continue + + # Sub (Constant_beta, Mul) + beta_op, beta_reshape_op = bypass_reshape(input_node_map, bias_mean_sub_op.input[0]) + mean_scale_mul_op = node_from_map(input_node_map, bias_mean_sub_op.input[1]) + if mean_scale_mul_op.op != "Mul": + continue + if beta_op.op != "Const" or get_const_dim_count(beta_op) != 1: + continue + + # Common scale applies to both input and running mean + if scale_op != node_from_map(input_node_map, mean_scale_mul_op.input[1]): + continue + + mean_op, mean_reshape_op = bypass_reshape(input_node_map, mean_scale_mul_op.input[0]) + if mean_op.op != "Const" or get_const_dim_count(mean_op) != 1: + continue + + # Add (Constant_variance, Constant_epsilon) + variance_epsilon_add_op = node_from_map(input_node_map, rsqrt_op.input[0]) + if variance_epsilon_add_op.op != "Add": + continue + + variance_op, variance_reshape_op = bypass_reshape(input_node_map, variance_epsilon_add_op.input[0]) + epsilon_op = node_from_map(input_node_map, variance_epsilon_add_op.input[1]) + if epsilon_op.op != "Const" or get_const_dim_count(epsilon_op) != 0: + continue + if variance_op.op != "Const" or get_const_dim_count(variance_op) != 1: + continue + + epsilon = values_from_const(epsilon_op) + + nodes_to_skip[node.name] = True + nodes_to_skip[data_scale_mul_op.name] = True + nodes_to_skip[bias_mean_sub_op.name] = True + nodes_to_skip[mean_scale_mul_op.name] = True + nodes_to_skip[scale_op.name] = True + if scale_op.op != "Rsqrt": + nodes_to_skip[rsqrt_op.name] = True + nodes_to_skip[variance_epsilon_add_op.name] = True + if gamma_reshape_op is not None: + nodes_to_skip[gamma_reshape_op.name] = True + if beta_reshape_op is not None: + nodes_to_skip[beta_reshape_op.name] = True + if mean_reshape_op is not None: + nodes_to_skip[mean_reshape_op.name] = True + if variance_reshape_op is not None: + nodes_to_skip[variance_reshape_op.name] = True + + if gamma_op is None: + gamma_op = node_def_pb2.NodeDef() + gamma_op.op = "Const" + # Assign name with same root of Rsqrt op's name plus "gamma" + m = re.search(r"(.*)/(.*)", scale_op.name) + if m: + gamma_op.name = m.group(1) + "/gamma" + else: + gamma_op.name = scale_op.name + "/gamma" + gamma_op.attr["dtype"].CopyFrom(beta_op.attr["dtype"]) + beta_value = values_from_const(beta_op) + gamma_op.attr["value"].CopyFrom( + attr_value_pb2.AttrValue( + tensor=tensor_util.make_tensor_proto( + 1, beta_value.dtype.type, beta_value.shape, allow_broadcast=True + ) + ) + ) + new_ops.append(gamma_op) + + new_fused_batchnorm_op = node_def_pb2.NodeDef() + new_fused_batchnorm_op.op = "FusedBatchNorm" + new_fused_batchnorm_op.name = node.name + new_fused_batchnorm_op.attr["T"].CopyFrom(node.attr["T"]) + 
new_fused_batchnorm_op.attr["is_training"].CopyFrom(attr_value_pb2.AttrValue(b=False)) + new_fused_batchnorm_op.attr["epsilon"].CopyFrom(attr_value_pb2.AttrValue(f=epsilon.tolist())) + if data_format is not None: + new_fused_batchnorm_op.attr["data_format"].CopyFrom(data_format) + new_fused_batchnorm_op.input.extend( + [input_data_op.name, gamma_op.name, beta_op.name, mean_op.name, variance_op.name] + ) + + new_ops.append(new_fused_batchnorm_op) + + result_graph_def = graph_pb2.GraphDef() + for node in self.input_graph_def.node: + if node.name in nodes_to_skip: + continue + new_node = node_def_pb2.NodeDef() + new_node.CopyFrom(node) + retained_input = [] + for input_node in new_node.input: + if not input_node.startswith("^") or input_node[1:] not in nodes_to_skip: + retained_input.append(input_node) + new_node.input[:] = retained_input + result_graph_def.node.append(new_node) + + result_graph_def.node.extend(new_ops) + result_graph_def.versions.CopyFrom(self.input_graph_def.versions) + return result_graph_def + + +def node_name_from_input(node_name): + """Strips off ports and other decorations to get the underlying node name.""" + if node_name.startswith("^"): + node_name = node_name[1:] + m = re.search(r"(.*):\d+$", node_name) + if m: + node_name = m.group(1) + return node_name + + +def node_from_map(node_map, name): + """Pulls a node def from a dictionary for a given name. + + Args: + node_map: Dictionary containing an entry indexed by name for every node. + name: Identifies the node we want to find. + + Returns: + NodeDef of the node with the given name. + + Raises: + ValueError: If the node isn't present in the dictionary. + """ + stripped_name = node_name_from_input(name) + if stripped_name not in node_map: + raise ValueError("No node named '%s' found in map." % name) + return node_map[stripped_name] + + +def values_from_const(node_def): + """Extracts the values from a const NodeDef as a numpy ndarray. + + Args: + node_def: Const NodeDef that has the values we want to access. + + Returns: + Numpy ndarray containing the values. + + Raises: + ValueError: If the node isn't a Const. + """ + if node_def.op != "Const": + raise ValueError("Can not extract constant value from a node that is not Const. 
Got:\n" f"{node_def}")
+ input_tensor = node_def.attr["value"].tensor
+ tensor_value = tensor_util.MakeNdarray(input_tensor)
+ return tensor_value
+
+
+def valid_reshape_inputs(reshape_in0_ndef, reshape_in1_ndef):
+ """Check if the inputs of the Reshape are valid."""
+ if reshape_in0_ndef.op != "Const" or reshape_in1_ndef.op != "Const" or get_const_dim_count(reshape_in0_ndef) != 1:
+ return False
+ input0_vec_size = values_from_const(reshape_in0_ndef).shape[0]
+ const_value = values_from_const(reshape_in1_ndef)
+ shape_ndims = const_value.ndim
+ if shape_ndims != 1:
+ raise ValueError("Num of dims of the shape must be 1, got {}.".format(shape_ndims))
+ for value in const_value.tolist()[:-1]:
+ if value != 1:
+ return False
+ if const_value.tolist()[-1] != input0_vec_size:
+ return False
+ return True
+
+
+def bypass_reshape(input_node_map, input_name):
+ """Get Reshape input nodes."""
+ reshape_ndef = None
+ maybe_reshape_ndef = node_from_map(input_node_map, input_name)
+ input_ndef = maybe_reshape_ndef
+ if maybe_reshape_ndef.op == "Reshape":
+ reshape_input0_ndef = node_from_map(input_node_map, maybe_reshape_ndef.input[0])
+ reshape_input1_ndef = node_from_map(input_node_map, maybe_reshape_ndef.input[1])
+ if (
+ reshape_input0_ndef.op == "Const"
+ and reshape_input1_ndef.op == "Const"
+ and valid_reshape_inputs(reshape_input0_ndef, reshape_input1_ndef)
+ ):
+ input_ndef = reshape_input0_ndef
+ reshape_ndef = maybe_reshape_ndef
+ return input_ndef, reshape_ndef
+
+
+def get_const_dim_count(node_def):
+ """Get the number of dimensions for a Const node.
+
+ Args:
+ node_def: Const NodeDef.
+
+ Returns:
+ Number of dimensions for the Const node.
+ """
+ const_value = values_from_const(node_def)
+ return const_value.ndim
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_decomposed_in.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_decomposed_in.py
new file mode 100644
index 00000000000..be15a745b5b
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_decomposed_in.py
@@ -0,0 +1,369 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
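The fusion above is driven by the computation function quoted in its docstring: (X * multiplier) + (Beta - Mean * multiplier), where multiplier = rsqrt(Variance + Epsilon) * Gamma. A small numpy check that this decomposed form matches the canonical batch-norm formula (all tensors hypothetical):

    import numpy as np

    x = np.random.rand(2, 4, 4, 8).astype(np.float32)  # NHWC input (shape is made up)
    gamma = np.random.rand(8).astype(np.float32)
    beta = np.random.rand(8).astype(np.float32)
    mean = np.random.rand(8).astype(np.float32)
    variance = (np.random.rand(8) + 0.5).astype(np.float32)
    epsilon = 1e-3

    multiplier = 1.0 / np.sqrt(variance + epsilon) * gamma
    decomposed = x * multiplier + (beta - mean * multiplier)
    canonical = gamma * (x - mean) / np.sqrt(variance + epsilon) + beta
    assert np.allclose(decomposed, canonical, atol=1e-5)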
+"""Fuse Decomposed InstanceNorm Graph Rewriter.""" + +import re + +from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2 +from tensorflow.python.framework import dtypes, tensor_util +from tensorflow.python.platform import tf_logging + +from neural_compressor.tensorflow.quantization.utils.quantize_graph_common import QuantizeGraphHelper as helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + + +class FuseDecomposedINOptimizer: # pragma: no cover + """Fuse decomposed small ops into InstanceNorm.""" + + def __init__(self, input_graph_def): + """Initialization.""" + self.input_graph_def = input_graph_def + + @dump_elapsed_time("Pass FuseDecomposedINOptimizer") + def do_transformation(self): + """Find a group of ops that make up an instance/layer normalization pattern for fusion. + + In some models, the instance normalizatin is performed via a group of individual + ops instead of using single InstanceNormalization op. This function identifies a + pattern of instance normalization subgraph which is made of multiple ops and + transforms the graph by replacing those individual ops with InstanceNormalization op. + This will provide the opportunity to further fold the InstanceNormalization with + convolution ops to reduce the computation steps during inference. + This function currently recognizes instance normalization patterns described + below, this could be extended if newer patterns are seen. Also, the fusion + is only attempted if the input graph is in NHWC format or has no format set. + + The following pattern will be searched in the graph with additional + constraints. Here * means any type of op. + clang-format off + Subgraph for fusion + ------------------- + *(input) + x x x____________ + x x x + x x Mean1 FusedOp + x x x x ------- + x x x x *(input) Const Const + x x x x x (gamma) (beta) + x x x x x x x + x x x x _MklFusedInstanceNorm/_MklLayerNorm + x x x x + x SquaredDiff Const x + x x x x + x x x x + x Mean0 Const x + x x x x + x AddV2|Add x + x x Const x + x Rsqrt (gamma) x + x x x x + x Mul1 x + x x x x + x x x x + x x x x + x x Constx x + x x (beta)Mul2 + x x x x + Mul0 Sub + x x + AddV2|Add(output) + Args: + input_graph_def: A GraphDef containing a model. + + Returns: + Modified graph with individual ops that made up of instance normalization + fused to InstanceNormalization. + + Raises: + ValueError: If the graph is badly formed with duplicate node names. 
+ """ + input_node_map = {} + for node in self.input_graph_def.node: + if node.name not in input_node_map: + input_node_map[node.name] = node + else: + raise ValueError("Duplicate node names detected for ", node.name) + + nodes_to_skip = {} + new_ops = [] + for node in self.input_graph_def.node: + if node.op != "Add" and node.op != "AddV2": + continue + + # Add (Mul0, Sub) or Add (Sub, Mul0) + input0_op = node_from_map(input_node_map, node.input[0]) + input1_op = node_from_map(input_node_map, node.input[1]) + + if input0_op.op == "Mul" and input1_op.op == "Sub": + data_scale_mul_op = input0_op + bias_mean_sub_op = input1_op + elif input0_op.op == "Sub" and input1_op.op == "Mul": + bias_mean_sub_op = input0_op + data_scale_mul_op = input1_op + else: + continue + + # Mul0 (*input, Mul1) + input_data_op = node_from_map(input_node_map, data_scale_mul_op.input[0]) + scale_op = node_from_map(input_node_map, data_scale_mul_op.input[1]) + + # Mul1 (Rsqrt, Constant_gamma) + if scale_op.op == "Mul": + rsqrt_op = node_from_map(input_node_map, scale_op.input[0]) + gamma_op, gamma_reshape_op = bypass_reshape(input_node_map, scale_op.input[1]) + if rsqrt_op.op != "Rsqrt": + continue + if gamma_op.op != "Const": + continue + else: + continue + + # Sub (Constant_beta, Mul2) + beta_op, beta_reshape_op = bypass_reshape(input_node_map, bias_mean_sub_op.input[0]) + mean_scale_mul_op = node_from_map(input_node_map, bias_mean_sub_op.input[1]) + if mean_scale_mul_op.op != "Mul": + continue + if beta_op.op != "Const": + continue + + # Common scale applies to both input and running mean + if scale_op != node_from_map(input_node_map, mean_scale_mul_op.input[1]): + continue + + mean_op, mean_reshape_op = bypass_reshape(input_node_map, mean_scale_mul_op.input[0]) + if mean_op.op != "Mean": + continue + + # Add (variance-mean0, Constant_epsilon) + variance_epsilon_add_op = node_from_map(input_node_map, rsqrt_op.input[0]) + if variance_epsilon_add_op.op != "Add" and variance_epsilon_add_op.op != "AddV2": + continue + + variance_op, variance_reshape_op = bypass_reshape(input_node_map, variance_epsilon_add_op.input[0]) + epsilon_op = node_from_map(input_node_map, variance_epsilon_add_op.input[1]) + if epsilon_op.op != "Const": + continue + if variance_op.op != "Mean": + continue + + epsilon = values_from_const(epsilon_op) + + # Mean (SquaredDifference, Constant_r_indices0) + squared_diff_op, squared_reshape_op = bypass_reshape(input_node_map, variance_op.input[0]) + r_indices0_op = node_from_map(input_node_map, variance_op.input[1]) + if squared_diff_op.op != "SquaredDifference": + continue + if r_indices0_op.op != "Const": + continue + + if input_data_op != node_from_map(input_node_map, squared_diff_op.input[0]): + continue + + if mean_op != node_from_map(input_node_map, squared_diff_op.input[1]): + continue + + if input_data_op != node_from_map(input_node_map, mean_op.input[0]): + continue + + r_indices1_op = node_from_map(input_node_map, mean_op.input[1]) + if r_indices1_op.op != "Const": + continue + + r_indices1 = values_from_const(r_indices1_op) + if ( + r_indices1.tolist() != [1, 2] + and r_indices1.tolist() != [2, 3] + and r_indices1.tolist() != [1, 2, 3] + and r_indices1.tolist() != [2, 3, 4] + ): + continue + + nodes_to_skip[node.name] = True + nodes_to_skip[data_scale_mul_op.name] = True + nodes_to_skip[bias_mean_sub_op.name] = True + nodes_to_skip[mean_scale_mul_op.name] = True + nodes_to_skip[scale_op.name] = True + nodes_to_skip[rsqrt_op.name] = True + nodes_to_skip[variance_epsilon_add_op.name] = True + 
nodes_to_skip[squared_diff_op.name] = True + nodes_to_skip[mean_op.name] = True + nodes_to_skip[variance_op.name] = True + if gamma_reshape_op is not None: + nodes_to_skip[gamma_reshape_op.name] = True + if beta_reshape_op is not None: + nodes_to_skip[beta_reshape_op.name] = True + if mean_reshape_op is not None: + nodes_to_skip[mean_reshape_op.name] = True + if variance_reshape_op is not None: + nodes_to_skip[variance_reshape_op.name] = True + + if gamma_op is None: + gamma_op = node_def_pb2.NodeDef() + gamma_op.op = "Const" + # Assign name with same root of Rsqrt op's name plus "gamma" + m = re.search(r"(.*)/(.*)", scale_op.name) + if m: + gamma_op.name = m.group(1) + "/gamma" + else: + gamma_op.name = scale_op.name + "/gamma" + gamma_op.attr["dtype"].CopyFrom(beta_op.attr["dtype"]) + beta_value = values_from_const(beta_op) + gamma_op.attr["value"].CopyFrom( + attr_value_pb2.AttrValue( + tensor=tensor_util.make_tensor_proto( + 1, beta_value.dtype.type, beta_value.shape, allow_broadcast=True + ) + ) + ) + new_ops.append(gamma_op) + + new_fused_instancenorm_op = node_def_pb2.NodeDef() + new_fused_instancenorm_op.op = "_MklFusedInstanceNorm" + new_fused_instancenorm_op.name = node.name + new_fused_instancenorm_op.attr["T"].CopyFrom(node.attr["T"]) + new_fused_instancenorm_op.attr["epsilon"].CopyFrom(attr_value_pb2.AttrValue(f=epsilon.tolist())) + list_value = attr_value_pb2.AttrValue.ListValue(i=r_indices1.flatten()) + new_fused_instancenorm_op.attr["reduction_axes"].CopyFrom(attr_value_pb2.AttrValue(list=list_value)) + + # Mean and variance values will be computed at runtime for fp32 & bf16 input. + # Pass a "dummy" node for mean and variance. + mean_variance_dim = tensor_util.MakeNdarray(gamma_op.attr["value"].tensor).shape[-1] + dummy_mean_node = helper.create_constant_node( + node.name + "_dummy_mean", [0.0] * mean_variance_dim, dtypes.float32 + ) + dummy_variance_node = helper.create_constant_node( + node.name + "_dummy_variance", [1.0] * mean_variance_dim, dtypes.float32 + ) + new_fused_instancenorm_op.input.extend( + [input_data_op.name, gamma_op.name, beta_op.name, dummy_mean_node.name, dummy_variance_node.name] + ) + new_ops.append(dummy_mean_node) + new_ops.append(dummy_variance_node) + new_ops.append(new_fused_instancenorm_op) + + result_graph_def = graph_pb2.GraphDef() + for node in self.input_graph_def.node: + if node.name in nodes_to_skip: + continue + new_node = node_def_pb2.NodeDef() + new_node.CopyFrom(node) + retained_input = [] + for input_node in new_node.input: + if not input_node.startswith("^") or input_node[1:] not in nodes_to_skip: + retained_input.append(input_node) + new_node.input[:] = retained_input + result_graph_def.node.append(new_node) + + result_graph_def.node.extend(new_ops) + result_graph_def.versions.CopyFrom(self.input_graph_def.versions) + return result_graph_def + + +def node_name_from_input(node_name): + """Strips off ports and other decorations to get the underlying node name.""" + if node_name.startswith("^"): + node_name = node_name[1:] + m = re.search(r"(.*):\d+$", node_name) + if m: + node_name = m.group(1) + return node_name + + +def node_from_map(node_map, name): + """Pulls a node def from a dictionary for a given name. + + Args: + node_map: Dictionary containing an entry indexed by name for every node. + name: Identifies the node we want to find. + + Returns: + NodeDef of the node with the given name. + + Raises: + ValueError: If the node isn't present in the dictionary. 
+ """ + stripped_name = node_name_from_input(name) + if stripped_name not in node_map: + raise ValueError("No node named '%s' found in map." % name) + return node_map[stripped_name] + + +def values_from_const(node_def): + """Extracts the values from a const NodeDef as a numpy ndarray. + + Args: + node_def: Const NodeDef that has the values we want to access. + + Returns: + Numpy ndarray containing the values. + + Raises: + ValueError: If the node isn't a Const. + """ + if node_def.op != "Const": + raise ValueError("Can not extract constant value from a node that is not Const. Got:\n" f"{node_def}") + input_tensor = node_def.attr["value"].tensor + tensor_value = tensor_util.MakeNdarray(input_tensor) + return tensor_value + + +def valid_reshape_inputs(reshape_in0_ndef, reshape_in1_ndef): + """Check if the inputs of the Reshape are valid.""" + if reshape_in0_ndef.op != "Const" or reshape_in1_ndef.op != "Const" or get_const_dim_count(reshape_in0_ndef) != 1: + return False + input0_vec_size = values_from_const(reshape_in0_ndef).shape[0] + const_value = values_from_const(reshape_in1_ndef) + shape_ndims = const_value.ndim + if shape_ndims != 1: + raise ValueError("Num of dims of the shape must be 1, got {}.".format(shape_ndims)) + for value in const_value.tolist()[:-1]: + if value != 1: + return False + if const_value.tolist()[-1] != input0_vec_size: + return False + return True + + +def bypass_reshape(input_node_map, input_name): + """Get Reshape input nodes.""" + reshape_ndef = None + maybe_reshape_ndef = node_from_map(input_node_map, input_name) + input_ndef = maybe_reshape_ndef + if maybe_reshape_ndef.op == "Reshape": + reshpae_input0_ndef = node_from_map(input_node_map, maybe_reshape_ndef.input[0]) + reshpae_input1_ndef = node_from_map(input_node_map, maybe_reshape_ndef.input[1]) + if ( + reshpae_input0_ndef.op == "Const" + and reshpae_input1_ndef.op == "Const" + and valid_reshape_inputs(reshpae_input0_ndef, reshpae_input1_ndef) + ): + input_ndef = reshpae_input0_ndef + reshape_ndef = maybe_reshape_ndef + return input_ndef, reshape_ndef + + +def get_const_dim_count(node_def): + """Get the number of dimensions for a Const node. + + Args: + node_def: Const NodeDef. + + Returns: + Number of dimensions for the Const node. + """ + const_value = values_from_const(node_def) + return const_value.ndim diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_gelu.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_gelu.py new file mode 100644 index 00000000000..4c1984138ab --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_gelu.py @@ -0,0 +1,228 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Fuse small ops to Gelu Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.python.framework import dtypes + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import SPR_BASE_VERSIONS + + +class FuseGeluOptimizer(GraphRewriterBase): # pragma: no cover + """Fuse Sqrt + RealDiv + Erf + AddV2 + Mul + Mul into Gelu op.""" + + def do_transformation(self): + """Execute the fusion from small ops to Gelu.""" + if not (tf.version.VERSION in ("1.15.0-up2", "1.15.0-up3") or tf.version.VERSION in SPR_BASE_VERSIONS): + return self.model + + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + # Below code is relative to expression on + # https://github.com/IntelAI/models/blob/master/models/language_modeling/tensorflow/ + # bert_large/inference/generic_ops.py#L105 + target_nodes = cur_graph.query_fusion_pattern_nodes( + [["Pow"], ["Mul"], ["AddV2"], ["Mul"], ["Tanh"], ["AddV2"], ["Mul"], ["Mul"]] + ) + + if not target_nodes: + target_nodes = cur_graph.query_fusion_pattern_nodes( + [["Pow"], ["Mul"], ["AddV2"], ["Mul"], ["Tanh"], ["AddV2"], ["Mul"]] + ) + + for node_combination in target_nodes: + match_node_length = len(node_combination) + pow_node = graph_info[node_combination[0]].node + mul_1_node = graph_info[node_combination[1]].node + addv2_1_node = graph_info[node_combination[2]].node + mul_2_node = graph_info[node_combination[3]].node + tanh_node = graph_info[node_combination[4]].node + addv2_2_node = graph_info[node_combination[5]].node + if match_node_length == 8: + mul_3_node = graph_info[node_combination[6]].node + else: + mul_3_node = graph_info[node_combination[7]].node + + gelu_input_name = None + pow_const_node_name = None + pow_value = None + + for i in pow_node.input: + node_name = Helper.node_name_from_input(i) + if graph_info[node_name].node.op != "Const": + gelu_input_name = i + + if graph_info[node_name].node.op == "Const": + pow_const_node_name = i + pow_value = graph_info[node_name].node.attr["value"].tensor.float_val[0] + break + + if pow_value != 3: + continue + mul_1_value = None + mul_1_const_node_name = None + for i in mul_1_node.input: + i = Helper.node_name_from_input(i) + if i != pow_node.name and graph_info[i].node.op == "Const": + mul_1_const_node_name = i + mul_1_value = graph_info[i].node.attr["value"].tensor.float_val[0] + break + if mul_1_value != 0.044714998453855515: + continue + + mul_2_value = None + mul_2_const_node_name = None + for i in mul_2_node.input: + i = Helper.node_name_from_input(i) + if i != addv2_1_node.name and graph_info[i].node.op == "Const": + mul_2_const_node_name = i + mul_2_value = graph_info[i].node.attr["value"].tensor.float_val[0] + break + if mul_2_value != 0.7978845834732056: + continue + + addv2_2_value = None + addv2_2_const_node_name = None + for i in addv2_2_node.input: + i = Helper.node_name_from_input(i) + if i != tanh_node.name and graph_info[i].node.op == "Const": + addv2_2_const_node_name = i + addv2_2_value = graph_info[i].node.attr["value"].tensor.float_val[0] + break + if addv2_2_value != 1: + continue + + rest_mul_node = None + if match_node_length == 8: + for i in mul_3_node.input: + i = Helper.node_name_from_input(i) + if i != addv2_2_node.name: + rest_mul_node = graph_info[i].node + 
break + + if not rest_mul_node or rest_mul_node.op != "Mul": + continue + else: + rest_mul_node = graph_info[node_combination[6]].node + + rest_mul_valid = False + rest_mul_const_node_name = None + for i in rest_mul_node.input: + i = Helper.node_name_from_input(i) + if graph_info[i].node.op == "Const" and graph_info[i].node.attr["value"].tensor.float_val[0] == 0.5: + rest_mul_const_node_name = i + rest_mul_valid = True + break + + if not rest_mul_valid: + continue + + cur_graph.remove_node(pow_const_node_name) + cur_graph.remove_node(pow_node.name) + cur_graph.remove_node(mul_1_node.name) + cur_graph.remove_node(mul_1_const_node_name) + cur_graph.remove_node(addv2_1_node.name) + cur_graph.remove_node(mul_2_node.name) + cur_graph.remove_node(mul_2_const_node_name) + cur_graph.remove_node(tanh_node.name) + cur_graph.remove_node(addv2_2_node.name) + cur_graph.remove_node(addv2_2_const_node_name) + cur_graph.remove_node(rest_mul_node.name) + cur_graph.remove_node(rest_mul_const_node_name) + + original_last = graph_info[mul_3_node.name].outputs + cur_graph.remove_node(mul_3_node.name) + gelu_node = Helper.create_node("Gelu", mul_3_node.name, [gelu_input_name]) + Helper.set_attr_bool(gelu_node, "approximate", True) + Helper.set_attr_dtype(gelu_node, "T", dtypes.float32) + + cur_graph.add_node(gelu_node, gelu_input_name, original_last) + + target_nodes = cur_graph.query_fusion_pattern_nodes( + [["Sqrt"], ["RealDiv"], ["Erf"], ["AddV2"], ["Mul"], ["Mul"]] + ) + + for node_combination in target_nodes: + sqrt_node = graph_info[node_combination[0]].node + realdiv_node = graph_info[node_combination[1]].node + erf_node = graph_info[node_combination[2]].node + addv2_node = graph_info[node_combination[3]].node + mul1_node = graph_info[node_combination[4]].node + mul2_node = graph_info[node_combination[5]].node + + sqrt_input_name = Helper.node_name_from_input(sqrt_node.input[0]) + sqrt_value = graph_info[sqrt_input_name].node.attr["value"].tensor.float_val[0] + + if sqrt_value != 2: + continue + + addv2_value = None + mul1_value = None + gelu_input_name = None + + for i in realdiv_node.input: + i = Helper.node_name_from_input(i) + if i != sqrt_node.name: + gelu_input_name = i + break + + addv2_const_name = None + for i in addv2_node.input: + i = Helper.node_name_from_input(i) + if i != erf_node.name: + addv2_value = graph_info[i].node.attr["value"].tensor.float_val[0] + addv2_const_name = i + break + + if addv2_value != 1: + continue + + mul1_const_name = None + for i in mul1_node.input: + i = Helper.node_name_from_input(i) + if i != addv2_node.name: + mul1_value = graph_info[i].node.attr["value"].tensor.float_val[0] + mul1_const_name = i + break + + if mul1_value != 0.5: + continue + + cur_graph.remove_node(sqrt_node.input[0]) + cur_graph.remove_node(sqrt_node.name) + cur_graph.remove_node(realdiv_node.name) + cur_graph.remove_node(erf_node.name) + cur_graph.remove_node(addv2_node.name) + cur_graph.remove_node(mul1_node.name) + cur_graph.remove_node(addv2_const_name) + cur_graph.remove_node(sqrt_input_name) + cur_graph.remove_node(mul1_const_name) + + original_last = graph_info[mul2_node.name].outputs + cur_graph.remove_node(mul2_node.name) + gelu_node = Helper.create_node("Gelu", mul2_node.name, [gelu_input_name]) + Helper.set_attr_bool(gelu_node, "approximate", False) + Helper.set_attr_dtype(gelu_node, "T", dtypes.float32) + + cur_graph.add_node(gelu_node, gelu_input_name, original_last) + + return cur_graph.dump_graph() diff --git 
a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_layer_norm.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_layer_norm.py
new file mode 100644
index 00000000000..4ad5cece7c8
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_layer_norm.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2022 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Fuse small ops to LayerNorm Graph Rewriter."""
+
+import re
+
+from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2
+from tensorflow.python.framework import tensor_util
+
+from neural_compressor.tensorflow.utils import dump_elapsed_time
+
+
+class FuseLayerNormOptimizer: # pragma: no cover
+ """Remap smaller ops into fused LayerNorm.
+
+ The current fusion only covers the case when LayerNormalization uses FusedBatchNormV3,
+ and it is further restricted to 2D or 3D tensor inputs to the Keras LayerNormalization API.
+ """
+
+ def __init__(self, input_graph_def):
+ """Constructor."""
+ self.input_graph_def = input_graph_def
+
+ @dump_elapsed_time("Pass FuseLayerNormOptimizer")
+ def do_transformation(self):
+ """The following pattern will be searched in the graph with additional constraints.
+
+ Here * means any type of op.
+ Subgraph:
+ *(input) * * Const * Const FusedOp
+ x | x | | x Const -------
+ x | x | | x Const x
+ Reshape Fill Fill x x *(input) *(gamma) *(beta)
+ x x x x x x | x
+ x x x x x x | x
+ F u s e d B a t c h N o r m V 3 _MklLayerNorm
+ x
+ x *
+ x x
+ Reshape
+ x *(gamma)
+ x x
+ Mul
+ *(beta) x
+ x x
+ AddV2(output)
+ Args:
+ input_graph_def: A GraphDef containing a model.
+
+ Returns:
+ Modified graph with the individual ops that make up layer normalization
+ fused into LayerNorm.
+
+ Raises:
+ ValueError: If the graph is badly formed with duplicate node names.
+ """ + input_node_map = {} + for node in self.input_graph_def.node: + if node.name not in input_node_map: + input_node_map[node.name] = node + else: + raise ValueError("Duplicate node names detected for ", node.name) + + nodes_to_skip = {} + new_ops = [] + for node in self.input_graph_def.node: + if node.op != "AddV2": + continue + + # AddV2 (Mul, beta) or AddV2 (beta, Mul) + input0_op = node_from_map(input_node_map, node.input[0]) + input1_op = node_from_map(input_node_map, node.input[1]) + + if input0_op.op == "Mul": + data_scale_mul_op = input0_op + beta_op = input1_op + elif input1_op.op == "Mul": + beta_op = input0_op + data_scale_mul_op = input1_op + else: + continue + + # Mul (Reshape, *gamma) + input0_op = node_from_map(input_node_map, data_scale_mul_op.input[0]) + input1_op = node_from_map(input_node_map, data_scale_mul_op.input[1]) + if input0_op.op == "Reshape": + post_reshape_op = input0_op + gamma_op = input1_op + elif input1_op.op == "Reshape": + post_reshape_op = input1_op + gamma_op = input0_op + else: + continue + + # Reshape (FusedBatchNormV3, *post_shape) + input0_op = node_from_map(input_node_map, post_reshape_op.input[0]) + input1_op = node_from_map(input_node_map, post_reshape_op.input[1]) + if input0_op.op == "FusedBatchNormV3": + fused_batch_norm_op = input0_op + post_shape_op = input1_op + elif input1_op.op == "FusedBatchNormV3": + fused_batch_norm_op = input1_op + post_shape_op = input0_op + else: + continue + + # LayerNorm uses FusedBatchNorm in training mode. + if fused_batch_norm_op.attr["is_training"] == attr_value_pb2.AttrValue(b=False): + continue + + # FusedBatchNormV3(Reshape, Fill, Fill, Mean, Variance) + pre_reshape_op = node_from_map(input_node_map, fused_batch_norm_op.input[0]) + if pre_reshape_op.op != "Reshape": + continue + fill_scale_op = node_from_map(input_node_map, fused_batch_norm_op.input[1]) + if fill_scale_op.op != "Fill": + continue + fill_offset_op = node_from_map(input_node_map, fused_batch_norm_op.input[2]) + if fill_offset_op.op != "Fill": + continue + + # FusedBatchNorm node should have mean/variance as empty constant + mean_op = node_from_map(input_node_map, fused_batch_norm_op.input[3]) + if mean_op.op != "Const": + continue + variance_op = node_from_map(input_node_map, fused_batch_norm_op.input[4]) + if variance_op.op != "Const": + continue + mean_value = values_from_const(mean_op) + if mean_value.any(): + continue + variance_value = values_from_const(variance_op) + if variance_value.any(): + continue + + # Reshape (*input, *pre_shape) + input_op = node_from_map(input_node_map, pre_reshape_op.input[0]) + pre_shape_op = node_from_map(input_node_map, pre_reshape_op.input[1]) + + # Fill Scale(*dims_fill_scale, unit_gamma) + dims_fill_scale_op = node_from_map(input_node_map, fill_scale_op.input[0]) + unit_gamma_op = node_from_map(input_node_map, fill_scale_op.input[1]) + if unit_gamma_op.op != "Const": + continue + + # Fill Offset(*dims_fill_scale, unit_gamma) + dims_fill_offset_op = node_from_map(input_node_map, fill_offset_op.input[0]) + zero_beta_op = node_from_map(input_node_map, fill_offset_op.input[1]) + if zero_beta_op.op != "Const": + continue + + nodes_to_skip[node.name] = True + nodes_to_skip[data_scale_mul_op.name] = True + nodes_to_skip[post_reshape_op.name] = True + nodes_to_skip[fused_batch_norm_op.name] = True + nodes_to_skip[fill_scale_op.name] = True + nodes_to_skip[fill_offset_op.name] = True + + new_fused_layernorm_op = node_def_pb2.NodeDef() + new_fused_layernorm_op.op = "_MklLayerNorm" + new_fused_layernorm_op.name 
= node.name + new_fused_layernorm_op.attr["T"].CopyFrom(node.attr["T"]) + new_fused_layernorm_op.input.extend([input_op.name, gamma_op.name, beta_op.name]) + + new_ops.append(new_fused_layernorm_op) + + result_graph_def = graph_pb2.GraphDef() + for node in self.input_graph_def.node: + if node.name in nodes_to_skip: + continue + new_node = node_def_pb2.NodeDef() + new_node.CopyFrom(node) + retained_input = [] + for input_node in new_node.input: + if not input_node.startswith("^") or input_node[1:] not in nodes_to_skip: + retained_input.append(input_node) + new_node.input[:] = retained_input + result_graph_def.node.append(new_node) + + result_graph_def.node.extend(new_ops) + result_graph_def.versions.CopyFrom(self.input_graph_def.versions) + return result_graph_def + + +def node_name_from_input(node_name): # pragma: no cover + """Strips off ports and other decorations to get the underlying node name.""" + if node_name.startswith("^"): + node_name = node_name[1:] + m = re.search(r"(.*):\d+$", node_name) + if m: + node_name = m.group(1) + return node_name + + +def node_from_map(node_map, name): # pragma: no cover + """Pulls a node def from a dictionary for a given name. + + Args: + node_map: Dictionary containing an entry indexed by name for every node. + name: Identifies the node we want to find. + + Returns: + NodeDef of the node with the given name. + + Raises: + ValueError: If the node isn't present in the dictionary. + """ + stripped_name = node_name_from_input(name) + if stripped_name not in node_map: + raise ValueError("No node named '%s' found in map." % name) + return node_map[stripped_name] + + +def values_from_const(node_def): # pragma: no cover + """Extracts the values from a const NodeDef as a numpy ndarray. + + Args: + node_def: Const NodeDef that has the values we want to access. + + Returns: + Numpy ndarray containing the values. + + Raises: + ValueError: If the node isn't a Const. + """ + if node_def.op != "Const": + raise ValueError("Can not extract constant value from a node that is not Const. Got:\n" f"{node_def}") + input_tensor = node_def.attr["value"].tensor + tensor_value = tensor_util.MakeNdarray(input_tensor) + return tensor_value diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_pad_with_conv.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_pad_with_conv.py new file mode 100644 index 00000000000..e894ee8d3cb --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_pad_with_conv.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
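The LayerNorm fusion above targets the graph layout that tf.keras.layers.LayerNormalization can emit via FusedBatchNormV3. A sketch of obtaining a candidate GraphDef and running the pass; this is illustrative only, since whether the FusedBatchNormV3-based layout appears depends on the TF version and layer configuration:

    import tensorflow as tf

    from neural_compressor.tensorflow.quantization.utils.graph_rewriter.generic.fuse_layer_norm import (
        FuseLayerNormOptimizer,
    )

    layer = tf.keras.layers.LayerNormalization()

    @tf.function
    def ln(x):
        return layer(x)

    # Trace once to get a GraphDef containing the layer-norm subgraph.
    concrete = ln.get_concrete_function(tf.TensorSpec([8, 128], tf.float32))
    graph_def = concrete.graph.as_graph_def()
    fused_graph_def = FuseLayerNormOptimizer(graph_def).do_transformation()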
+"""Fuse Pad into Conv Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.python.framework import tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import version1_gt_version2 + + +class FusePadWithConv2DOptimizer(GraphRewriterBase): + """Fuse Pad op into Conv2D/DepthwiseConv2dNative/Conv3D.""" + + def __init__(self, model, excluded_op_names, inputs, cfg, new_api, itex_qdq_mode=False): + """Initialization.""" + super().__init__(model) + self.excluded_conv = excluded_op_names + self.inputs = inputs + self.cfg = cfg + self.new_api = new_api + self.itex_qdq_mode = itex_qdq_mode + + def do_transformation(self): + """Fuse Pad + Conv2D/DepthwiseConv2dNative/Conv3D --> Conv2D/DepthwiseConv2dNative/Conv3D.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + + target_nodes = cur_graph.query_fusion_pattern_nodes( + [["Pad"], ["Conv2D", "Conv3D", "DepthwiseConv2dNative"], ("BiasAdd", "Add", "AddV2")] + ) + + padding_tensor_dict = {} + for node_combination in target_nodes: + conv_name = node_combination[1] + + pattern = node_combination[-1] + + if conv_name not in self.cfg: + continue + + is_perchannel = self.cfg[conv_name][0] + + # Line 55 to line 65 should be removed once the TFDO enabling the single quantized + # conv2D supporting. + if len(pattern) == 2: + # TODO we need to enable single quantizedconv2d with s8 input. + if not is_perchannel and not cur_graph.has_positive_input(conv_name): + continue + # TFDO has the limitation that the single QuantizedConv2DPerchannel doesn't + # support padding_list filed. 
+ if is_perchannel: + continue + + if conv_name in self.excluded_conv: + continue + + padding_tensor = None + pad_node = None + if node_combination[0] not in padding_tensor_dict: + pad_node = graph_info[node_combination[0]].node + if graph_info[pad_node.input[1]].node.op != "Const": + input_node = graph_info[pad_node.input[1]].node + if input_node.op == "DataFormatVecPermute": + parent_input_node = graph_info[input_node.input[0]].node + if parent_input_node.op == "Const": + padding_tensor = tensor_util.MakeNdarray(parent_input_node.attr["value"].tensor).flatten() + else: + continue + else: + continue + else: + padding_tensor = tensor_util.MakeNdarray( + graph_info[pad_node.input[1]].node.attr["value"].tensor + ).flatten() + padding_tensor_dict[node_combination[0]] = padding_tensor + else: + padding_tensor = padding_tensor_dict[node_combination[0]] + + if self.itex_qdq_mode: + enabled_pad_conv2d = bool( + tf.version.VERSION == "1.15.0-up3" or version1_gt_version2(tf.version.VERSION, "2.7") + ) + else: + enabled_pad_conv2d = bool(tf.version.VERSION == "1.15.0-up3" or self.new_api) + + if any(padding_tensor) and not enabled_pad_conv2d: # pragma: no cover + continue + + if pad_node: + if graph_info[pad_node.input[1]].node.op != "Const": + cur_graph.node_name_details[pad_node.name].node.input.remove(pad_node.input[1]) + cur_graph.remove_node_with_single_input_output(pad_node.name) + else: + cur_graph.remove_node_with_single_input_output(pad_node.name) + cur_graph.remove_node(pad_node.input[1]) + conv_node = graph_info[node_combination[1]].node + if self.itex_qdq_mode: + if any(padding_tensor) and enabled_pad_conv2d: # pragma: no cover + Helper.set_attr_string(conv_node, "padding", b"EXPLICIT") + Helper.set_attr_int_list(conv_node, "explicit_paddings", padding_tensor) + else: + Helper.set_attr_int_list(conv_node, "padding_list", padding_tensor) + if any(padding_tensor) and enabled_pad_conv2d: # pragma: no cover + Helper.set_attr_string(conv_node, "padding", b"EXPLICIT") + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_pad_with_fp32_conv.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_pad_with_fp32_conv.py new file mode 100644 index 00000000000..c3cf5a4b62c --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_pad_with_fp32_conv.py @@ -0,0 +1,128 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
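+# Illustrative note (not part of the pass): this FP32 variant mirrors
+# FusePadWithConv2DOptimizer but records the folded paddings via
+# padding="EXPLICIT"/"explicit_paddings". For a 4-D NHWC input, TensorFlow expects
+# explicit_paddings flattened per dimension as
+#
+#     [0, 0, pad_top, pad_bottom, pad_left, pad_right, 0, 0]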
+"""Fuse Pad into Conv Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.python.framework import tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import version1_gt_version2 + + +class FusePadWithFP32Conv2DOptimizer(GraphRewriterBase): + """Fuse Pad op into Conv.""" + + def __init__(self, model, excluded_op_names, inputs, cfg, new_api, itex_qdq_mode=False): + """Initialization.""" + super().__init__(model) + self.excluded_conv = excluded_op_names + self.inputs = inputs + self.cfg = cfg + self.new_api = new_api + self.itex_qdq_mode = itex_qdq_mode + + def do_transformation(self): + """Fuse Pad op into Conv2D/DepthwiseConv2dNative/Conv3D.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + + target_nodes = cur_graph.query_fusion_pattern_nodes( + [["Pad"], ["Conv2D", "DepthwiseConv2dNative"], ("BiasAdd", "Add", "AddV2")] + ) + + padding_tensor_dict = {} + for node_combination in target_nodes: + conv_name = node_combination[1] + + pattern = node_combination[-1] + + if conv_name not in self.cfg: + continue + + is_perchannel = self.cfg[conv_name][0] + + # Line 55 to line 65 should be removed once the TFDO enabling the single quantized + # conv2D supporting. + if len(pattern) == 2: + # TODO we need to enable single quantizedconv2d with s8 input. + if not is_perchannel and not cur_graph.has_positive_input(conv_name): + continue + # TFDO has the limitation that the single QuantizedConv2DPerchannel doesn't + # support padding_list filed. 
+ if is_perchannel: + continue + + if conv_name in self.excluded_conv: + continue + + padding_tensor = None + pad_node = None + if node_combination[0] not in padding_tensor_dict: + pad_node = graph_info[node_combination[0]].node + if graph_info[pad_node.input[1]].node.op != "Const": + input_node = graph_info[pad_node.input[1]].node + if input_node.op == "DataFormatVecPermute": + parent_input_node = graph_info[input_node.input[0]].node + if parent_input_node.op == "Const": + padding_tensor = tensor_util.MakeNdarray(parent_input_node.attr["value"].tensor).flatten() + else: + continue + else: + continue + else: + padding_tensor = tensor_util.MakeNdarray( + graph_info[pad_node.input[1]].node.attr["value"].tensor + ).flatten() + padding_tensor_dict[node_combination[0]] = padding_tensor + else: + padding_tensor = padding_tensor_dict[node_combination[0]] + + if self.itex_qdq_mode: + enabled_pad_conv2d = bool( + tf.version.VERSION == "1.15.0-up3" or version1_gt_version2(tf.version.VERSION, "2.7") + ) + else: + enabled_pad_conv2d = bool(tf.version.VERSION == "1.15.0-up3" or self.new_api) + + if any(padding_tensor) and not enabled_pad_conv2d: # pragma: no cover + continue + + if pad_node: + if graph_info[pad_node.input[1]].node.op != "Const": + cur_graph.node_name_details[pad_node.name].node.input.remove(pad_node.input[1]) + cur_graph.remove_node_with_single_input_output(pad_node.name) + else: + cur_graph.remove_node_with_single_input_output(pad_node.name) + cur_graph.remove_node(pad_node.input[1]) + conv_node = graph_info[node_combination[1]].node + # Helper.set_attr_int_list(conv_node, "padding_list", padding_tensor) + # only when padding attr is explicit, the explicit_paddings is not empty + + if self.itex_qdq_mode: + if any(padding_tensor) and enabled_pad_conv2d: # pragma: no cover + Helper.set_attr_string(conv_node, "padding", b"EXPLICIT") + Helper.set_attr_int_list(conv_node, "explicit_paddings", padding_tensor) + else: + if any(padding_tensor) and enabled_pad_conv2d: # pragma: no cover + Helper.set_attr_string(conv_node, "padding", b"EXPLICIT") + Helper.set_attr_int_list(conv_node, "explicit_paddings", padding_tensor) + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_reshape_transpose.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_reshape_transpose.py new file mode 100644 index 00000000000..85ef97f8628 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/fuse_reshape_transpose.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
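+# Illustrative note (an assumption-based sketch, not part of the pass): when the
+# Transpose/Reshape inputs are constants, the chain is folded at rewrite time with
+# numpy, e.g.
+#
+#     import numpy as np
+#     w = np.arange(24, dtype=np.float32).reshape(2, 3, 4)   # hypothetical weight
+#     folded = w.transpose([2, 0, 1]).reshape(4, 6)          # what gets baked in
+#
+# so the MatMul/Conv consumer reads a single precomputed Const instead of running a
+# Transpose + Reshape chain at inference time.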
+"""Fuse Transpose and Reshape Graph Rewriter.""" + +from tensorflow.python.framework import dtypes, tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class FuseTransposeReshapeOptimizer(GraphRewriterBase): + """Fuse Transpose + Reshape + MatMul/Conv ==> MatMul/Conv.""" + + @dump_elapsed_time("Pass FuseTransposeReshapeOptimizer") + def do_transformation(self): + """Execute Transpose + Reshape + MatMul/Conv fusion.""" + g = GraphAnalyzer() + g.graph = self.model + graph_info = g.parse_graph() + + patterns = [ + ["Transpose"], + ["Reshape"], + ["MatMul", "DepthwiseConv2dNative", "Conv2D", "BatchMatMul", "BatchMatMulV2"], + ] + + matched_nodes = g.query_fusion_pattern_nodes(patterns) + + valid_match = [] + for i in matched_nodes: + transpose_input_node_name = graph_info[i[0]].node.input[0] + transpose_input_node = graph_info[transpose_input_node_name].node + if transpose_input_node.op == "Const": + valid_match.append(i) + elif transpose_input_node.op == "Enter": + if transpose_input_node.input: + enter_input_node_name = transpose_input_node.input[0] + enter_input_node = graph_info[enter_input_node_name].node + if enter_input_node.op == "Const": + valid_match.append(i) + else: + continue + + for i in valid_match: + transpose_node = graph_info[i[0]].node + transpose_input_node = graph_info[transpose_node.input[0]].node + transpose_input_perm = graph_info[transpose_node.input[1]].node + reshape_node = graph_info[i[1]].node + reshape_shape_node = graph_info[reshape_node.input[1]].node + if transpose_input_node.op == "Const": + transpose_input_node_content = tensor_util.MakeNdarray(transpose_input_node.attr["value"].tensor) + elif transpose_input_node.op == "Enter": + enter_input_node_name = transpose_input_node.input[0] + enter_input_node = graph_info[enter_input_node_name].node + transpose_input_node_content = tensor_util.MakeNdarray(enter_input_node.attr["value"].tensor) + else: + continue + + if transpose_input_perm.op == "Const": + transpose_perm_node_content = tensor_util.MakeNdarray(transpose_input_perm.attr["value"].tensor) + elif transpose_input_perm.op == "Enter": + enter_transpose_input_perm = transpose_input_perm.input[0] + enter_transpose_input_perm_node = graph_info[enter_transpose_input_perm].node + transpose_perm_node_content = tensor_util.MakeNdarray( + enter_transpose_input_perm_node.attr["value"].tensor + ) + else: + continue + + if reshape_shape_node.op == "Const": + reshape_shape_node_content = tensor_util.MakeNdarray(reshape_shape_node.attr["value"].tensor) + elif reshape_shape_node.op == "Enter": + enter_reshape_shape = reshape_shape_node.input[0] + enter_reshape_shape_node = graph_info[enter_reshape_shape].node + reshape_shape_node_content = tensor_util.MakeNdarray(enter_reshape_shape_node.attr["value"].tensor) + else: + continue + + converted_node = transpose_input_node_content.transpose(transpose_perm_node_content).reshape( + reshape_shape_node_content + ) + g.remove_node(i[0]) + if transpose_input_node.op == "Const": + g.remove_node(transpose_input_node.name) + g.remove_node(transpose_input_perm.name) + new_node_name = transpose_input_node.name + "_converted" + new_node = Helper.create_constant_node( + new_node_name, converted_node, dtype=dtypes.float32, shape=converted_node.shape + ) + 
g.replace_const_node(new_node, [i[2]], i[1])
+ g.remove_node(i[1])
+ g.remove_node(reshape_shape_node.name)
+ else:
+ g.remove_node(enter_input_node.name)
+ g.remove_node(transpose_input_perm.name)
+ new_node_name = enter_input_node.name + "_converted"
+
+ new_node = Helper.create_constant_node(
+ new_node_name, converted_node, dtype=dtypes.float32, shape=converted_node.shape
+ )
+ g.add_node(new_node, [], [transpose_input_node.name])
+ transpose_input_node.input[0] = new_node.name
+ for index, node_name in enumerate(graph_info[i[2]].node.input):
+ if node_name == i[1]:
+ graph_info[i[2]].node.input[index] = transpose_input_node.name
+ g.remove_node(i[1])
+ g.remove_node(reshape_shape_node.name)
+
+ return g.dump_graph()
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/graph_cse_optimizer.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/graph_cse_optimizer.py
new file mode 100644
index 00000000000..b706e5cce5d
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/graph_cse_optimizer.py
@@ -0,0 +1,145 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""CSE Graph Rewriter."""
+
+from tensorflow.core.framework import graph_pb2
+
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper
+from neural_compressor.tensorflow.utils import dump_elapsed_time
+
+from ..graph_base import GraphRewriterBase
+
+
+class GraphCseOptimizer(GraphRewriterBase):
+ """The CSE optimizer eliminates sibling nodes that apply an identical op type.
+
+ Case 1. Node A has three output nodes (B, C, D) and those child nodes have their own
+ outputs (B1, C1/C2, D1).
+ Node A
+ x x x
+ x x x
+ NODE B NODE C NODE D
+ x x x x
+ B1 C1 C2 D1
+ If Nodes B/C/D apply the identical memory-bound op, like Relu/Relu6, the graph will be
+ converted as below.
+ We remove Node C and Node D, and update Node B as the input of C1/C2/D1.
+ Node A
+ x
+ Node B
+ x x x x
+ x x x x
+ x x x x
+ B1 C1 C2 D1
+ Case 2. Node A has three output nodes (B, C, D) and those child nodes have their own
+ outputs (B1, C1/C2, D1).
+ Node A
+ x x x
+ x x x
+ NODE B NODE C NODE D
+ x x x x
+ B1 C1 C2 D1
+ If only Node B and Node C apply the identical memory-bound op, like Relu/Relu6, the
+ graph will be converted as below.
+ We remove Node C, and update Node B as the input of C1/C2.
+ Node A
+ x x
+ Node B Node D
+ x x x x
+ x | x x
+ x | x x
+ B1 C1 C2 D1
+
+ Returns:
+ [graphdef]: An optimized GraphDef object.
+ """
+
+ computational_op_type = ("Conv2D", "Conv3D", "DepthwiseConv2dNative", "MatMul")
+
+ @dump_elapsed_time("Pass GraphCseOptimizer")
+ def do_transformation(self):
+ """Optimize a graph that contains multi-output nodes.
+
+ If sibling nodes apply identical ops, the redundant ones are eliminated.
+ Currently, only memory-bound ops are supported.
+
+ Args:
+ input_graph_def (graphdef): graphdef object
+
+ Returns:
+ [graphdef]: optimized graph
+ """
+ GraphAnalyzer().graph = self.model
+
+ graph_info = GraphAnalyzer().parse_graph()
+
+ need_to_update_node = []
+ # TODO: enhance the code snippet below with a recursive method.
+ for _, i in graph_info.items():
+ candidate_node = [
+ graph_info[child_name].node
+ for child_name in i.outputs
+ if graph_info[child_name].node.op not in self.computational_op_type
+ ]
+ candidate_node_unique_type = set([i.op for i in candidate_node])
+ if len(candidate_node_unique_type) == len(candidate_node):
+ # Each child node has a unique op type, so there is nothing to merge.
+ continue
+ node_type_name_mapping = {}
+ # Build a dict mapping each op type to the names of the nodes with that op type.
+ for each_node in candidate_node:
+ node_type = each_node.op
+ node_name = each_node.name
+ if node_type not in node_type_name_mapping:
+ node_type_name_mapping[node_type] = [node_name]
+ else:
+ node_type_name_mapping[node_type].append(node_name)
+
+ for _, node_names in node_type_name_mapping.items():
+ # Ignore unique op types and nodes with multiple outputs.
+ if len(node_names) == 1 or len(graph_info[node_names[0]].outputs) > 1:
+ continue
+ # TODO: enhance the algorithm below before finalizing.
+ filter_node = [node_names[0]]
+ for sub_node_name in node_names[1:]:
+ if not Helper.compare_node_attr(graph_info[node_names[0]].node, graph_info[sub_node_name].node):
+ continue
+ filter_node.append(sub_node_name)
+
+ need_to_update_node.append({i.node.name: filter_node})
+
+ for node_pair in need_to_update_node:
+ for upper_node_name, lower_node_name in node_pair.items():
+ keep_sub_node_name = lower_node_name[0]
+ for removeable_node_name in lower_node_name[1:]:
+ graph_info[upper_node_name].outputs.remove(removeable_node_name)
+ for grand_child_node_name in graph_info[removeable_node_name].outputs:
+ filter_input_name = [
+ Helper.node_name_from_input(i) for i in graph_info[grand_child_node_name].node.input
+ ]
+ replace_index = filter_input_name.index(removeable_node_name)
+ graph_info[grand_child_node_name].node.input[replace_index] = keep_sub_node_name
+ graph_info.pop(removeable_node_name)
+
+ output_graph_def = graph_pb2.GraphDef()
+
+ for _, node_info in graph_info.items():
+ output_graph_def.node.extend([node_info.node])
+
+ return output_graph_def
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/grappler_pass.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/grappler_pass.py
new file mode 100644
index 00000000000..aee3790f0ba
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/grappler_pass.py
@@ -0,0 +1,69 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
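+# Illustrative usage sketch (assumes `graph_def` is a frozen GraphDef whose graph
+# contains nodes named "input" and "output"):
+#
+#     opt_cfg = {"pruning": True, "shape": True, "dependency": True,
+#                "debug_stripper": True, "loop": True}
+#     optimized = GrapplerOptimizer(graph_def, ["input", "output"], opt_cfg).do_transformation()
+#
+# Each enabled key is appended to RewriterConfig.optimizers before Grappler runs.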
+"""Tensorflow Grappler Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.core.protobuf import config_pb2, meta_graph_pb2 +from tensorflow.python.grappler import tf_optimizer +from tensorflow.python.training import saver + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.utils import dump_elapsed_time, version1_gt_version2 + + +class GrapplerOptimizer(GraphRewriterBase): + """A python wrapper that leverages the built-in tensorflow grappler API to optimize the graph.""" + + def __init__(self, model, input_output_names, opt_cfg): + """Initialization.""" + super().__init__(model) + self.input_output_names = input_output_names + self.opt_cfg = opt_cfg + self.generic_optimizer = ("pruning", "shape", "dependency", "debug_stripper", "loop") + self.tf_2_optimizer = ("constfold", "arithmetic", "min_graph_nodes") + + @dump_elapsed_time("Pass GrapplerOptimizer") + def do_transformation(self): + """Apply tensorflow Grappler optimization.""" + try: + g = tf.Graph() + with g.as_default(): + g = tf.compat.v1.import_graph_def(self.model, name="") + meta_graph = saver.export_meta_graph(graph_def=self.model, graph=g, clear_devices=True) + fetch_collection = meta_graph_pb2.CollectionDef() + for fetch in self.input_output_names: + fetch_collection.node_list.value.append(fetch) + meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) + config = config_pb2.ConfigProto() + rewriter_config = config.graph_options.rewrite_options + for optimizer in self.generic_optimizer: + if optimizer in self.opt_cfg and self.opt_cfg[optimizer]: + rewriter_config.optimizers.append(optimizer) + + if version1_gt_version2(tf.version.VERSION, "2.2.0"): + for optimizer in self.tf_2_optimizer: + if optimizer in self.opt_cfg and self.opt_cfg[optimizer]: + rewriter_config.optimizers.append(optimizer) + + rewriter_config.min_graph_nodes = -1 + + optimized_graph = tf_optimizer.OptimizeGraph(config, meta_graph) + + return optimized_graph + except Exception as e: + self.logger.warning("Fail to run grappler pass due to {}.".format(str(e))) + return self.model diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/insert_print_node.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/insert_print_node.py new file mode 100644 index 00000000000..a60fbb701e3 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/insert_print_node.py @@ -0,0 +1,223 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Insert print node Graph Rewriter.""" + +import tensorflow as tf +from tensorflow.core.framework import attr_value_pb2 +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import tensor_util as tu + +from neural_compressor.tensorflow.quantization.utils.graph_rewriter.graph_base import GraphRewriterBase +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import version1_gt_version2 + + +class InsertPrintMinMaxNode(GraphRewriterBase): + """InsertPrintMinMaxNode Pass for tensorflow sampling.""" + + def __init__(self, model, pre_node_name, post_node_name, new_api): + """Initialization.""" + super().__init__(model) + self.pre_node_name = pre_node_name + self.post_node_name = post_node_name + self.signature = pre_node_name + post_node_name + self.new_api = new_api + + def do_transformation(self): + """Insert print node in the graph to do the calibration.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + insert_node_pairs = [] + top_node = graph_info[self.pre_node_name].node + if top_node.op == "ConcatV2": + for i in range(top_node.attr["N"].i): + insert_node_pairs.append([top_node.input[i], self.post_node_name]) + elif top_node.op in ("BatchMatMul", "BatchMatMulV2"): + insert_node_pairs.append([top_node.input[0], self.post_node_name]) + if graph_info[top_node.input[1]].node.op != "Const": + insert_node_pairs.append([top_node.input[1], self.post_node_name]) + elif top_node.op in ("Conv2DBackpropInput", "Conv3DBackpropInputV2"): + insert_node_pairs.append([top_node.input[2], self.post_node_name]) + else: + refresh_pre_node_name = graph_info[self.pre_node_name].node.input[0] + # Check the Conv2D could be fused with previous Pad or not. + # If so, we need to update the pre-node name correspondingly. 
+ refresh_pre_node = graph_info[Helper.node_name_from_input(refresh_pre_node_name)].node
+ if refresh_pre_node.op == "Pad" and top_node.op in ("Conv2D", "Conv3D"):
+ pad_const_node_name = refresh_pre_node.input[1]
+ pad_const_node = graph_info[pad_const_node_name].node
+ padding_tensor = None
+ if graph_info[pad_const_node_name].node.op != "Const":
+ if pad_const_node.op == "DataFormatVecPermute":
+ parent_input_node = graph_info[pad_const_node.input[0]].node
+ if parent_input_node.op == "Const":
+ padding_tensor = tu.MakeNdarray(parent_input_node.attr["value"].tensor).flatten()
+ else:
+ padding_tensor = tu.MakeNdarray(pad_const_node.attr["value"].tensor).flatten()
+ # Guard against padding values that could not be resolved to a constant.
+ if padding_tensor is not None and (
+ not any(padding_tensor) or (any(padding_tensor) and (tf.version.VERSION == "1.15.0-up3" or self.new_api))
+ ):
+ insert_node_pairs.append([refresh_pre_node_name, self.post_node_name])
+ refresh_pre_node_name = refresh_pre_node.input[0]
+
+ insert_node_pairs.append([refresh_pre_node_name, self.post_node_name])
+
+ output_names = []
+ for node_pair_names in insert_node_pairs:
+ for index, each_node_name in enumerate(node_pair_names):
+ name_with_sig = each_node_name + self.signature
+ node_name_prefix = name_with_sig.replace(":", "__port__").replace("^", "__hat__")
+ reshape_dims_name = node_name_prefix + "_reshape_dims"
+ reduction_dims_name = node_name_prefix + "_reduction_dims"
+
+ reshape_dims_node = Helper.create_constant_node(reshape_dims_name, -1, dtypes.int32, [1])
+
+ reduction_dims_node = Helper.create_constant_node(reduction_dims_name, 0, dtypes.int32, [1])
+
+ # The training input QueueDequeueManyV2 has an issue with implicit dependency,
+ # so skip the input node of the show_and_tell model.
+ if not (
+ Helper.node_name_from_input(each_node_name) == "batch_and_pad"
+ and graph_info[Helper.node_name_from_input(each_node_name)].node.op == "QueueDequeueManyV2"
+ ):
+ reshape_dims_node.input.append("^" + Helper.node_name_from_input(each_node_name))
+ reduction_dims_node.input.append("^" + Helper.node_name_from_input(each_node_name))
+
+ reshape_input_name = node_name_prefix + "_reshape_"
+
+ reshape_input_node = Helper.create_node(
+ "Reshape", reshape_input_name, [each_node_name, reshape_dims_name]
+ )
+
+ min_input_name = node_name_prefix + "_min"
+ min_input_node = Helper.create_node("Min", min_input_name, [reshape_input_name, reduction_dims_name])
+ Helper.set_attr_dtype(min_input_node, "Tidx", dtypes.int32)
+ Helper.set_attr_bool(min_input_node, "keep_dims", False)
+
+ max_input_name = node_name_prefix + "_max"
+ max_input_node = Helper.create_node("Max", max_input_name, [reshape_input_name, reduction_dims_name])
+ Helper.set_attr_dtype(max_input_node, "Tidx", dtypes.int32)
+ Helper.set_attr_bool(max_input_node, "keep_dims", False)
+
+ max_print_node = Helper.create_node(
+ "Print",
+ node_name_prefix + "_print_max__{}".format(index),
+ [max_input_name + ":0", max_input_name + ":0"],
+ )
+ min_print_node = Helper.create_node(
+ "Print",
+ node_name_prefix + "_print_min__{}".format(index),
+ [min_input_name + ":0", min_input_name + ":0"],
+ )
+
+ if index == 0:
+ max_msg = ";{}_eightbit_max_{}__print__;__max:".format(self.pre_node_name, each_node_name)
+ min_msg = ";{}_eightbit_min_{}__print__;__min:".format(self.pre_node_name, each_node_name)
+ # Workaround for swish_f32: attribute T is not in the op definition.
+ if "swish_f32" in graph_info[self.pre_node_name].node.name:
+ src_dt = attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum)
+ else:
+ src_dt = graph_info[self.pre_node_name].node.attr["T"]
+ else:
max_msg = ";{}_eightbit_requant_range__print__;__requant_max:".format(self.pre_node_name) + min_msg = ";{}_eightbit_requant_range__print__;__requant_min:".format(self.pre_node_name) + # workaround for swish_f32, attribute T is not in the op definition + if "swish_f32" in graph_info[each_node_name].node.op: + src_dt = attr_value_pb2.AttrValue(type=dtypes.float32.as_datatype_enum) + else: + src_dt = graph_info[each_node_name].node.attr["T"] + + reshape_input_node.attr["T"].CopyFrom(src_dt) + min_input_node.attr["T"].CopyFrom(src_dt) + min_print_node.attr["T"].CopyFrom(src_dt) + max_input_node.attr["T"].CopyFrom(src_dt) + max_print_node.attr["T"].CopyFrom(src_dt) + + min_print_node.attr["message"].s = min_msg.encode() + min_print_node.attr["first_n"].i = -1 + min_print_node.attr["summarize"].i = 1024 + + max_print_node.attr["message"].s = max_msg.encode() + max_print_node.attr["first_n"].i = -1 + max_print_node.attr["summarize"].i = 1024 + + attr_u = [dtypes.as_dtype(src_dt.type).as_datatype_enum] + min_print_node.attr["U"].list.CopyFrom(attr_value_pb2.AttrValue.ListValue(type=attr_u)) + max_print_node.attr["U"].list.CopyFrom(attr_value_pb2.AttrValue.ListValue(type=attr_u)) + post_node_names = graph_info[Helper.node_name_from_input(each_node_name)].outputs + if post_node_names: + for post_node_name in post_node_names: + post_node = graph_info[post_node_name].node + if each_node_name not in post_node.input: + continue + if ( + post_node.op == "FusedBatchNormV3" + and "_print_identity" + not in graph_info[Helper.node_name_from_input(post_node.name)].node.input[0] + ): + identity_node = Helper.create_node( + "Identity", + post_node.name + "_print_identity", + [graph_info[Helper.node_name_from_input(post_node.name)].node.input[0]], + ) + identity_node.attr["T"].CopyFrom(src_dt) + cur_graph.add_node( + identity_node, + graph_info[Helper.node_name_from_input(post_node.name)].node.input[0], + [post_node.name], + ) + identity_node.input.append("^" + min_print_node.name) + identity_node.input.append("^" + max_print_node.name) + else: + post_node.input.append("^" + min_print_node.name) + post_node.input.append("^" + max_print_node.name) + + cur_graph.add_node(reshape_dims_node, None, [reshape_input_name]) + cur_graph.add_node(reduction_dims_node, None, [max_input_name, min_input_name]) + cur_graph.add_node(reshape_input_node, each_node_name, [max_input_name, min_input_name]) + cur_graph.add_node(max_input_node, reshape_input_name, [max_print_node.name]) + cur_graph.add_node(min_input_node, reshape_input_name, [min_print_node.name]) + + cur_graph.add_node(min_print_node, min_input_name, []) + cur_graph.add_node(max_print_node, max_input_name, []) + else: + identity_node0 = Helper.create_node( + "Identity", min_print_node.name + "_identity", [min_print_node.name] + ) + identity_node0.attr["T"].CopyFrom(src_dt) + identity_node1 = Helper.create_node( + "Identity", max_print_node.name + "_identity", [max_print_node.name] + ) + identity_node1.attr["T"].CopyFrom(src_dt) + + cur_graph.add_node(reshape_dims_node, None, [reshape_input_name]) + cur_graph.add_node(reduction_dims_node, None, [max_input_name, min_input_name]) + cur_graph.add_node(reshape_input_node, each_node_name, [max_input_name, min_input_name]) + cur_graph.add_node(max_input_node, reshape_input_name, [max_print_node.name]) + cur_graph.add_node(min_input_node, reshape_input_name, [min_print_node.name]) + cur_graph.add_node(min_print_node, min_input_name, [identity_node0.name]) + cur_graph.add_node(max_print_node, max_input_name, 
[identity_node1.name])
+ cur_graph.add_node(identity_node0, min_print_node.name, [])
+ cur_graph.add_node(identity_node1, max_print_node.name, [])
+ # identity_node0.input.append("^" + min_print_node.name)
+ # identity_node1.input.append("^" + max_print_node.name)
+ output_names.append(identity_node0.name)
+ output_names.append(identity_node1.name)
+ return cur_graph.dump_graph(), output_names
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/move_squeeze_after_relu.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/move_squeeze_after_relu.py
new file mode 100644
index 00000000000..6b072617202
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/move_squeeze_after_relu.py
@@ -0,0 +1,123 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Move Squeeze after Relu Graph Rewriter."""
+
+import copy
+
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer
+from neural_compressor.tensorflow.utils import dump_elapsed_time
+
+from ..graph_base import GraphRewriterBase
+
+
+class MoveSqueezeAfterReluOptimizer(GraphRewriterBase):
+ """Move the Squeeze op after the Relu op to match fusion patterns."""
+
+ def __init__(self, model):
+ """Initialization."""
+ super().__init__(model)
+ self.op_list = ["Relu", "Sigmoid", "Relu6", "LeakyRelu", "Elu"]
+
+ @dump_elapsed_time("Pass MoveSqueezeAfterReluOptimizer")
+ def do_transformation(self):
+ """Move Squeeze/Reshape after Relu."""
+ g = GraphAnalyzer()
+ g.graph = self.model
+ graph_info = g.parse_graph()
+ # For pattern Conv + Squeeze + BiasAdd + Relu(Sigmoid, Relu6, LeakyRelu, Elu)
+ for node in self.model.node:
+ if (
+ node.op in self.op_list
+ and node.input[0] in graph_info
+ and graph_info[node.input[0]].node.op == "BiasAdd"
+ ):
+ biasadd_node = graph_info[node.input[0]].node
+ biasadd_input = graph_info[biasadd_node.name].node.input[0]
+ squeeze_node = graph_info[biasadd_input].node
+ relu_output = graph_info[node.name].outputs
+ if squeeze_node.op == "Squeeze":
+ # Rewire BiasAdd to consume the Squeeze's input directly.
+ for i, input in enumerate(biasadd_node.input):
+ if input == biasadd_input:
+ new_input = biasadd_node.input[:i] + [squeeze_node.input[0]] + biasadd_node.input[i + 1 :]
+ graph_info[biasadd_node.name].node.ClearField("input")
+ graph_info[biasadd_node.name].node.input.extend(new_input)
+ graph_info[squeeze_node.name].outputs.remove(biasadd_node.name)
+ # Update the Conv node's outputs.
+ conv = squeeze_node.input[0]
+ conv_outputs = graph_info[conv].outputs
+ for i, output in enumerate(conv_outputs):
+ if output == squeeze_node.name:
+ graph_info[conv].outputs.remove(squeeze_node.name)
+ graph_info[conv].outputs.append(biasadd_node.name)
+ # The Squeeze node now consumes the Relu output.
+ squeeze_node.ClearField("input")
+ squeeze_node.input.extend([node.name])
+ # Rewire Relu's consumers to read from the Squeeze node.
+ for output in relu_output:
+ for i, input in enumerate(graph_info[output].node.input):
+ if input == node.name:
+ new_input = (
+ graph_info[output].node.input[:i] + + [squeeze_node.name] + + graph_info[output].node.input[i + 1 :] + ) + graph_info[squeeze_node.name].outputs.append(output) + graph_info[output].node.ClearField("input") + graph_info[output].node.input.extend(new_input) + + # For pattern x + Reshape + Relu(Sigmoid, Relu6, LeakyRelu, Elu) + if ( + node.op in self.op_list + and node.input[0] in graph_info + and graph_info[node.input[0]].node.op == "Reshape" + ): + reshape_node = graph_info[node.input[0]].node + reshape_input = graph_info[reshape_node.name].node.input[0] + x_node = graph_info[reshape_input].node + relu_output = copy.deepcopy(graph_info[node.name].outputs) + + if len(graph_info[x_node.name].outputs) != 1: + continue + if len(graph_info[reshape_node.name].outputs) > 1: + continue + # relu---->reshape + for i, input in enumerate(reshape_node.input): + if input == reshape_input: + new_input = reshape_node.input[:i] + [node.name] + reshape_node.input[i + 1 :] + graph_info[reshape_node.name].node.ClearField("input") + graph_info[reshape_node.name].node.input.extend(new_input) + graph_info[x_node.name].outputs.remove(reshape_node.name) + graph_info[x_node.name].outputs.append(node.name) + # x----->relu + node.ClearField("input") + node.input.extend([reshape_input]) + # expand input,squeeze output + for output in relu_output: + for i, input in enumerate(graph_info[output].node.input): + if input == node.name: + new_input = ( + graph_info[output].node.input[:i] + + [reshape_node.name] + + graph_info[output].node.input[i + 1 :] + ) + graph_info[reshape_node.name].outputs.append(output) + graph_info[output].node.ClearField("input") + graph_info[output].node.input.extend(new_input) + graph_info[node.name].outputs.remove(output) + graph_info[node.name].outputs.append(reshape_node.name) + return g.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/pre_optimize.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/pre_optimize.py new file mode 100644 index 00000000000..44e20f20cc3 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/pre_optimize.py @@ -0,0 +1,310 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
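+# Illustrative usage sketch (assumes `fp32_model` is a neural_compressor TensorFlow
+# Model wrapper with its input/output tensor names already set):
+#
+#     pre_opt = PreOptimization(fp32_model, new_api=True, device="cpu")
+#     optimized_model = pre_opt.get_optimized_model(itex_mode=False)
+#     matched = pre_opt.get_matched_nodes([[["MatMul"], ("BiasAdd",)]])
+#
+# get_optimized_model() chains the generic rewriters defined in this package in a
+# fixed order; see the pass sequence inside the method.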
+"""Pre Optimization Entrance.""" + +import copy +import logging + +import tensorflow as tf + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.utils import ( + dump_elapsed_time, + version1_eq_version2, + version1_gte_version2, + version1_lt_version2, +) + +from .convert_add_to_biasadd import ConvertAddToBiasAddOptimizer +from .convert_layout import ConvertLayoutOptimizer +from .convert_leakyrelu import ConvertLeakyReluOptimizer +from .convert_nan_to_random import ConvertNanToRandom +from .convert_placeholder_to_const import ConvertPlaceholderToConst +from .dilated_contraction import DilatedContraction +from .dummy_biasadd import InjectDummyBiasAddOptimizer +from .expanddims_optimizer import ExpandDimsOptimizer +from .fetch_weight_from_reshape import FetchWeightFromReshapeOptimizer +from .fold_batch_norm import FoldBatchNormNodesOptimizer +from .fold_constant import GraphFoldConstantOptimizer +from .fuse_biasadd_add import FuseBiasAddAndAddOptimizer +from .fuse_column_wise_mul import FuseColumnWiseMulOptimizer +from .fuse_conv_with_math import FuseConvWithMathOptimizer +from .fuse_decomposed_bn import FuseDecomposedBNOptimizer +from .fuse_decomposed_in import FuseDecomposedINOptimizer +from .fuse_gelu import FuseGeluOptimizer +from .fuse_layer_norm import FuseLayerNormOptimizer +from .fuse_reshape_transpose import FuseTransposeReshapeOptimizer +from .graph_cse_optimizer import GraphCseOptimizer +from .grappler_pass import GrapplerOptimizer +from .move_squeeze_after_relu import MoveSqueezeAfterReluOptimizer +from .remove_training_nodes import RemoveTrainingNodesOptimizer +from .rename_batch_norm import RenameBatchNormOptimizer +from .split_shared_input import SplitSharedInputOptimizer +from .strip_equivalent_nodes import StripEquivalentNodesOptimizer +from .strip_unused_nodes import StripUnusedNodesOptimizer +from .switch_optimizer import SwitchOptimizer + + +class PreOptimization: + """Pre optimization for the FP32 models.""" + + def __init__(self, model, new_api, device): + """Initialization.""" + self.model = model + if version1_gte_version2(tf.version.VERSION, "2.1.0") or version1_eq_version2(tf.version.VERSION, "1.15.0-up3"): + self.optimization = { + "pruning": True, + "shape": True, + "constfold": False, + "arithmetic": False, + "dependency": True, + "debug_stripper": True, + "loop": True, + } + else: + self.optimization = { + "pruning": True, + "shape": True, + "dependency": True, + "debug_stripper": True, + "loop": True, + } + # Table initialization should disable grappler dependency and pruning pass + node_names = [node.name for node in model.graph_def.node] + if "init_all_tables" in node_names: + self.optimization["dependency"] = False + self.optimization["pruning"] = False + self.new_api = new_api + self.device = device + self.analyzer = GraphAnalyzer() + self.analyzer.graph = model.graph_def + self.analyzer.parse_graph() + self._tmp_graph_def = None + self._excluded_node_names = [] + + def get_excluded_node_names(self): + """Get the excluded node name. + + Returns: + string list: the excluded ops' name + """ + return self._excluded_node_names + + @dump_elapsed_time("Pass Pre Optimization") + def get_optimized_model(self, itex_mode=False): + """Executed the non-precision dependent graph optimization. + + The input graph will be optimized with following passes: + 1. Remove the training nodes like Identity Op. + 2. Split the shared nodes like weights node for multi-Conv2d. + 3. 
Fold Constant Nodes as less as possible. + 4. Fuse the Mul node into the previous Conv2D/MatMul if possible. + 5. Strip the useless nodes. + 6. Do the Common sequence elimation optimization on the graph. + 7. Fold the BN node into the previous Conv2D if possible. + + Returns: + [graphdef]: the optimized graphdef object. + """ + from neural_compressor.tensorflow.utils import Model + + origin_model = Model(self.model._model, **self.model.kwargs, backend="itex" if itex_mode else "default") + origin_model.name = self.model.name + origin_model.model_type = self.model.model_type + origin_model.output_tensor_names = self.model.output_tensor_names + origin_model.input_tensor_names = self.model.input_tensor_names + origin_model.workspace_path = self.model.workspace_path + + output_node_names = self.model.output_node_names + input_node_names = self.model.input_node_names + input_output_names = output_node_names + input_node_names + + # Add device info before convert layout + # Google in layout optimizer where all nodes in the graph are expected to have their device + # information set (earlier version < 2.10.0 this was not needed). + if version1_gte_version2(tf.version.VERSION, "2.10.0"): + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model.graph_def + graph_info = cur_graph.parse_graph() + + if self.device == "cpu": + cpus = tf.config.list_physical_devices("CPU") + node_device = cpus[0].name.replace("physical_device:", "") + else: + gpus = tf.config.list_physical_devices("GPU") + if len(gpus) == 0: + xpus = tf.config.list_physical_devices("XPU") + if len(xpus) == 0: + cpus = tf.config.list_physical_devices("CPU") + node_device = cpus[0].name.replace("physical_device:", "") + else: + node_device = xpus[0].name.replace("physical_device:", "") + else: + node_device = gpus[0].name.replace("physical_device:", "") + for node_name in list(graph_info.keys()): + node = graph_info[node_name].node + node.device = node_device + self._tmp_graph_def = cur_graph.dump_graph() + + self._tmp_graph_def = ConvertLayoutOptimizer(self._tmp_graph_def, output_node_names).do_transformation() + else: + self._tmp_graph_def = ConvertLayoutOptimizer(self.model.graph_def, output_node_names).do_transformation() + + self._tmp_graph_def = ConvertPlaceholderToConst(self._tmp_graph_def).do_transformation() + + self._tmp_graph_def = SwitchOptimizer(self._tmp_graph_def).do_transformation() + + self._tmp_graph_def = GrapplerOptimizer( + self._tmp_graph_def, input_output_names, self.optimization + ).do_transformation() + + self._tmp_graph_def = StripUnusedNodesOptimizer( + self._tmp_graph_def, input_node_names, output_node_names + ).do_transformation() + + self._tmp_graph_def = RemoveTrainingNodesOptimizer( + self._tmp_graph_def, protected_nodes=input_output_names + ).do_transformation() + + self._tmp_graph_def = SplitSharedInputOptimizer(self._tmp_graph_def).do_transformation() + + # Put FuseDecomposedBNOptimizer before GraphFoldConstantOptimizer + # The 'Sub' op in the small decomposed ops of BN will be converted to const by GraphFoldConstantOptimizer. + # Then the FuseDecomposedBNOptimizer can't fuse the small decomposed ops to BN. 
+ if self.new_api:
+ self._tmp_graph_def = FuseDecomposedBNOptimizer(self._tmp_graph_def).do_transformation()
+ self._tmp_graph_def = FuseDecomposedINOptimizer(self._tmp_graph_def).do_transformation()
+ self._tmp_graph_def = FuseLayerNormOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = GraphFoldConstantOptimizer(self._tmp_graph_def).do_transformation()
+
+ if not self.new_api:
+ self._tmp_graph_def = FuseDecomposedBNOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = FuseColumnWiseMulOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = StripUnusedNodesOptimizer(
+ self._tmp_graph_def, input_node_names, output_node_names
+ ).do_transformation()
+
+ self._tmp_graph_def = FuseGeluOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = GraphCseOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = FoldBatchNormNodesOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = RenameBatchNormOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = ConvertLeakyReluOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = ConvertAddToBiasAddOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = FuseTransposeReshapeOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = FuseConvWithMathOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = ExpandDimsOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = FetchWeightFromReshapeOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = MoveSqueezeAfterReluOptimizer(self._tmp_graph_def).do_transformation()
+
+ if not self.new_api and not itex_mode:
+ # TODO: remove the optimizer below once TF enables single MatMul op quantization.
+ self._tmp_graph_def = InjectDummyBiasAddOptimizer(
+ self._tmp_graph_def, output_node_names
+ ).do_transformation()
+
+ self._tmp_graph_def = FuseBiasAddAndAddOptimizer(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = ConvertNanToRandom(self._tmp_graph_def).do_transformation()
+
+ self._tmp_graph_def = StripEquivalentNodesOptimizer(self._tmp_graph_def, output_node_names).do_transformation()
+
+ if self.new_api or itex_mode:
+ self._tmp_graph_def = DilatedContraction(self._tmp_graph_def).do_transformation()
+
+ # Node device info was removed by GrapplerOptimizer, so insert it again.
+ if version1_lt_version2(tf.version.VERSION, "2.0.0"): # pragma: no cover
+ from tensorflow._api.v1.config import experimental
+
+ list_physical_devices = experimental.list_physical_devices
+ else:
+ list_physical_devices = tf.config.list_physical_devices
+ cur_graph = GraphAnalyzer()
+ cur_graph.graph = self._tmp_graph_def
+ graph_info = cur_graph.parse_graph()
+
+ if self.device == "cpu":
+ cpus = list_physical_devices("CPU")
+ node_device = cpus[0].name.replace("physical_device:", "")
+ else:
+ gpus = list_physical_devices("GPU")
+ if len(gpus) == 0:
+ xpus = list_physical_devices("XPU")
+ if len(xpus) == 0:
+ cpus = list_physical_devices("CPU")
+ node_device = cpus[0].name.replace("physical_device:", "")
+ else:
+ node_device = xpus[0].name.replace("physical_device:", "")
+ else:
+ node_device = gpus[0].name.replace("physical_device:", "")
+ for node_name in list(graph_info.keys()):
+ node = graph_info[node_name].node
+ node.device = node_device
+ self._tmp_graph_def = cur_graph.dump_graph()
+
+ self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)
+
+ for function_def in self.model.graph_def.library.function:
+ if function_def.signature.name == "swish_f32":
+ self._tmp_graph_def.library.function.extend([copy.deepcopy(function_def)])
+
+ origin_model.graph_def = self._tmp_graph_def
+ return origin_model
+
+ def get_matched_nodes(self, patterns):
+ """Search the matched nodes with the specified patterns.
+
+ Args:
+ patterns ([string list]): The patterns should be illustrated as below.
+ [["MatMul"], ("BiasAdd",), ("Relu",)]
+
+ Returns:
+ [string list]: The list that contains the matched node names and the matched
+ pattern, e.g. ['matched_node_a_name', 'matched_node_b_name', ['MatMul', 'BiasAdd']]
+ """
+ self.analyzer.graph = self._tmp_graph_def
+ self.analyzer.parse_graph()
+ res = []
+
+ for sub_pattern in patterns:
+ res.extend([i for i in self.analyzer.query_fusion_pattern_nodes(sub_pattern) if i not in res])
+ return res
+
+ def has_positive_input(self, node_name):
+ """Check whether the specified node has positive input data.
+
+ Args:
+ node_name ([string]): node's name
+
+ Returns:
+ [bool]: True if the node's input data is positive, False otherwise.
+ """
+ return self.analyzer.has_positive_input(node_name)
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/remove_training_nodes.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/remove_training_nodes.py
new file mode 100644
index 00000000000..76d83bea651
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/remove_training_nodes.py
@@ -0,0 +1,66 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
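+# Illustrative note (not part of the pass): splicing an Identity node rewires
+#
+#     conv -> Identity -> relu        into        conv -> relu
+#
+# Nodes that have control edge inputs, or that are themselves used as control edge
+# inputs ("^name"), are kept to avoid breaking implicit dependencies.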
+"""Remove training nodes Graph Rewriter.""" + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class RemoveTrainingNodesOptimizer(GraphRewriterBase): + """Remove training nodes optimizer.""" + + def __init__(self, model, protected_nodes=[], types_to_splice=["Identity", "CheckNumerics", "StopGradient"]): + """Initilizaiton.""" + super().__init__(model) + self.protected_nodes = protected_nodes + self.types_to_splice = types_to_splice + + @dump_elapsed_time("Pass RemoveTrainingNodesOptimizer") + def do_transformation(self): + """Remove tranining nodes which has no control edge inputs.""" + graph_handle = GraphAnalyzer() + graph_handle.graph = self.model + + graph_info = graph_handle.parse_graph() + # input_nodes = input_graph.node + + control_input_names = set() + node_names_with_control_input = set() + names_to_splice = {} + + for node_name, v in graph_info.items(): + for node_input in v.node.input: + if "^" in node_input: + control_input_names.add(node_input.replace("^", "")) + node_names_with_control_input.add(node_name) + + for node_name, v in graph_info.items(): + if v.node.op in self.types_to_splice and v.node.name not in self.protected_nodes: + # We don't want to remove nodes that have control edge inputs, because + # they might be involved in subtle dependency issues that removing them + # will jeopardize. + if node_name not in node_names_with_control_input: + names_to_splice[node_name] = v.node.input[0] + + # We also don't want to remove nodes which are used as control edge inputs. + names_to_splice = {name: value for name, value in names_to_splice.items() if name not in control_input_names} + for k, _ in names_to_splice.items(): + graph_handle.remove_node_with_single_input_output(k) + + return graph_handle.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/rename_batch_norm.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/rename_batch_norm.py new file mode 100644 index 00000000000..8e7cbc076df --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/rename_batch_norm.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Rename FusedBatchNorm op to FusedBatchNormV2 Graph Rewriter.""" + +import math + +import numpy as np +from tensorflow.core.framework import attr_value_pb2, node_def_pb2 +from tensorflow.python.framework import tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class RenameBatchNormOptimizer(GraphRewriterBase): + """Rename FusedBatchNorm op to FusedBatchNormV2.""" + + @dump_elapsed_time("Pass RenameBatchNormOptimizer") + def do_transformation(self): + """Rename FusedBatchNorm op to FusedBatchNormV2. + + This pass is needed for bf16 conversion. Due to TensorFlow historical reason, + FusedBatchNorm is not a bf16 op but FusedBatchNormV2 is. As the latter is compatible + with the former, changing FusedBatchNorm op to FusedBatchNormV2 op will be able to + convert to bf16 op on the platforms supporting VNNI_BF16 and AMX instructions. + + Returns: + Modified graph with BN ops renamed. + + Raises: + ValueError: If the graph is badly formed. + """ + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + graph_details = cur_graph.parse_graph() + + for _, v in graph_details.items(): + # for node in cur_graph.graph.node: + if v.node.op == "FusedBatchNorm" or v.node.op == "FusedBatchNormV2": + v.node.op = "FusedBatchNormV3" + v.node.attr["U"].CopyFrom(v.node.attr["T"]) + + return cur_graph.dump_graph() diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/split_shared_input.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/split_shared_input.py new file mode 100644 index 00000000000..0386a3a8514 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/split_shared_input.py @@ -0,0 +1,60 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Split shared input Graph Rewriter.""" + +from tensorflow.core.framework import node_def_pb2 + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphRewriterHelper as Helper +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class SplitSharedInputOptimizer(GraphRewriterBase): + """Split the shared input if the input node is shared and const.""" + + @dump_elapsed_time("Pass SplitSharedInputOptimizer") + def do_transformation(self): + """Execute splitting the shared input.""" + cur_graph = GraphAnalyzer() + cur_graph.graph = self.model + + graph_info = cur_graph.parse_graph() + + is_shared_input = False + # map of: input_name - op_name + input_map = {} + for node_name in list(graph_info.keys()): + node = graph_info[node_name].node + for _, input_node_name in enumerate(node.input): + if input_node_name.startswith("^"): + continue + if graph_info[Helper.node_name_from_input(input_node_name)].node.op == "Const": + # is shared and current node is not the first one + # sharing the input + if input_node_name in input_map: + is_shared_input = True + input_map[input_node_name].append(node.name) + new_input_node = node_def_pb2.NodeDef() + new_input_node.CopyFrom(graph_info[input_node_name].node) + new_input_node.name = input_node_name + "_nc_share_" + str(len(input_map[input_node_name])) + cur_graph.replace_const_node(new_input_node, [node.name], input_node_name, False) + else: + input_map[input_node_name] = [node.name] + + return cur_graph.dump_graph() if is_shared_input else self.model diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/strip_equivalent_nodes.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/strip_equivalent_nodes.py new file mode 100644 index 00000000000..53090ae55f3 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/strip_equivalent_nodes.py @@ -0,0 +1,50 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Strip Equivalent Nodes Graph Rewriter.""" + +from neural_compressor.common import logger +from neural_compressor.tensorflow.quantization.utils.utility import fix_ref_type_of_graph_def, strip_equivalent_nodes +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class StripEquivalentNodesOptimizer(GraphRewriterBase): + """Remove the equivalent nodes which have the same inputs and attributes.""" + + def __init__(self, model, output_node_names): + """Initialization.""" + super().__init__(model) + self.output_node_names = output_node_names + + @dump_elapsed_time("Pass StripEquivalentNodesOptimizer") + def do_transformation(self): + """Strip the equivalent nodes in the graph.""" + self.model = fix_ref_type_of_graph_def(self.model) + iter_num = 0 + replaced_nodes_type = True + all_replaced_nodes_type = {} + while replaced_nodes_type: + self.model, replaced_nodes_type = strip_equivalent_nodes(self.model, self.output_node_names) + for k, v in replaced_nodes_type.items(): + all_replaced_nodes_type[k] = all_replaced_nodes_type.get(k, 0) + v + iter_num += 1 + logger.debug( + f"StripEquivalentNodes[Iter-{iter_num}]-Replaced equivalent node types are {replaced_nodes_type}" + ) + logger.warning("All replaced equivalent node types are {}".format(all_replaced_nodes_type)) + return self.model diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/strip_unused_nodes.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/strip_unused_nodes.py new file mode 100644 index 00000000000..e59a81a4b75 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/strip_unused_nodes.py @@ -0,0 +1,42 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Strip unused nodes Graph Rewriter.""" + +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class StripUnusedNodesOptimizer(GraphRewriterBase): + """Remove the unused nodes in the graph.""" + + def __init__(self, model, input_node_names, output_node_names): + """Initialization.""" + super().__init__(model) + self.input_node_names = input_node_names + self.output_node_names = output_node_names + + @dump_elapsed_time("Pass StripUnusedNodesOptimizer") + def do_transformation(self): + """Execute stripping unused nodes.""" + from neural_compressor.tensorflow.quantization.utils.utility import ( + fix_ref_type_of_graph_def, + strip_unused_nodes, + ) + + self.model = fix_ref_type_of_graph_def(self.model) + return strip_unused_nodes(self.model, self.input_node_names, self.output_node_names) diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/switch_optimizer.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/switch_optimizer.py new file mode 100644 index 00000000000..c4b94369551 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/switch_optimizer.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Switch Graph Rewriter.""" + + +from tensorflow.python.framework import tensor_util + +from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer +from neural_compressor.tensorflow.utils import dump_elapsed_time + +from ..graph_base import GraphRewriterBase + + +class SwitchOptimizer(GraphRewriterBase): + """Remove switch op if the input condition is true.""" + + @dump_elapsed_time("Pass SwitchOptimizer") + def do_transformation(self): + """Replace all enter ops whose output is matmul with const. 
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/switch_optimizer.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/switch_optimizer.py
new file mode 100644
index 00000000000..c4b94369551
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/generic/switch_optimizer.py
@@ -0,0 +1,85 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Switch Graph Rewriter."""
+
+
+from tensorflow.python.framework import tensor_util
+
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer
+from neural_compressor.tensorflow.utils import dump_elapsed_time
+
+from ..graph_base import GraphRewriterBase
+
+
+class SwitchOptimizer(GraphRewriterBase):
+    """Remove the Switch op if the input condition is true."""
+
+    @dump_elapsed_time("Pass SwitchOptimizer")
+    def do_transformation(self):
+        """Remove Switch ops whose pred input is a constant true.
+
+        Returns:
+            [graphdef]: optimized graph
+        """
+        cur_graph = GraphAnalyzer()
+        cur_graph.graph = self.model
+
+        graph_info = cur_graph.parse_graph()
+        target_nodes = cur_graph.query_fusion_pattern_nodes([["Switch"]])
+
+        for node_combination in target_nodes:
+            switch_node = graph_info[node_combination[0]].node
+            pred_node = graph_info[switch_node.input[1]].node
+            if (
+                pred_node.op == "Const"
+                and tensor_util.MakeNdarray(graph_info[pred_node.name].node.attr["value"].tensor)
+            ) or (
+                pred_node.op == "PlaceholderWithDefault"
+                and tensor_util.MakeNdarray(graph_info[pred_node.input[0]].node.attr["value"].tensor)
+            ):
+                # Only rewrite when every successor consumes the true branch (":1").
+                condition = []
+                for output in graph_info[node_combination[0]].outputs:
+                    successor_node = graph_info[output].node
+                    for value in successor_node.input:
+                        if value == node_combination[0] + ":1":
+                            condition.append(True)
+                        elif value == node_combination[0] + ":0":
+                            condition.append(False)
+
+                if not all(condition):
+                    continue
+
+                for output in graph_info[node_combination[0]].outputs:
+                    successor_node = graph_info[output].node
+                    replace_index = None
+                    for index, value in enumerate(successor_node.input):
+                        if value == node_combination[0] + ":1":
+                            replace_index = index
+                            break
+                    # Index 0 is a valid position, so test for None explicitly.
+                    if replace_index is None:
+                        break
+                    successor_node.input[replace_index] = switch_node.input[0]
+                    switch_node_outputs = list(graph_info[node_combination[0]].outputs)
+                    if switch_node_outputs.index(output) == len(switch_node_outputs) - 1:
+                        cur_graph.remove_node_with_single_input_output(node_combination[0])
+            else:
+                continue
+
+        return cur_graph.dump_graph()
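In other words, when a Switch's pred is a constant True, every consumer of its true output (":1") is rewired straight to the data input and the Switch disappears. A hypothetical sketch:

    # Before: data -> Switch(pred=Const(True)) -:1-> Relu
    # After:  data -> Relu
    optimized = SwitchOptimizer(graph_def).do_transformation()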
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_rewriter/graph_base.py b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/graph_base.py
new file mode 100644
index 00000000000..72de4f351c1
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_rewriter/graph_base.py
@@ -0,0 +1,40 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Graph Rewrite Base Class."""
+
+import logging
+from abc import abstractmethod
+
+
+class GraphRewriterBase:
+    """Graph Rewrite Base class.
+
+    We abstract this base class and define the interface only.
+
+    Args:
+        model (object): the input model to be converted.
+    """
+
+    def __init__(self, model):
+        """Initialization."""
+        self.model = model
+        self.logger = logging.getLogger("neural_compressor")
+
+    @abstractmethod
+    def do_transformation(self):
+        """Base interface that needs to be implemented by each subclass."""
+        raise NotImplementedError
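Every rewriter above follows the same two-step contract: construct with a model, then call do_transformation. A minimal custom pass, as a sketch, only has to subclass and implement that one method:

    class DropNothingOptimizer(GraphRewriterBase):
        """Illustrative no-op pass: returns the graph unchanged."""

        def do_transformation(self):
            return self.model

    graph_def = DropNothingOptimizer(graph_def).do_transformation()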
diff --git a/neural_compressor/tensorflow/quantization/utils/graph_util.py b/neural_compressor/tensorflow/quantization/utils/graph_util.py
new file mode 100644
index 00000000000..3eb99baeccf
--- /dev/null
+++ b/neural_compressor/tensorflow/quantization/utils/graph_util.py
@@ -0,0 +1,1115 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+#
+# Copyright (c) 2021 Intel Corporation
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tensorflow Graph Utils Helper Classes."""
+
+import copy
+import logging
+import re
+from collections import namedtuple
+
+import numpy as np
+from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2
+from tensorflow.python.framework import tensor_util
+
+from neural_compressor.tensorflow.utils import singleton
+
+logger = logging.getLogger("neural_compressor")
+
+
+@singleton
+class GraphAnalyzer:
+    """Tensorflow Graph Analyzer class implemented as a singleton.
+
+    This class provides the following API:
+    * Analyze the graph
+    * Analyze the input/output node names of the specified graph
+    """
+
+    # TODO add the positive input flag
+    node_details = namedtuple("node_details", ["node", "outputs"])
+
+    def __init__(self, extend_engine=None):
+        """Initialization.
+
+        Args:
+            extend_engine: extended engine, reserved for future extension APIs.
+        """
+        self._graph = None
+        self.extend_engine = extend_engine
+
+    @property
+    def graph(self):
+        """Getter of the _graph object.
+
+        Returns:
+            graph: current graphdef object
+        """
+        return self._graph
+
+    @graph.setter
+    def graph(self, new_graph):
+        """Update the internal graph value.
+
+        Args:
+            new_graph (graphdef object): new model object
+        """
+        self._graph = new_graph
+
+    def _has_positive_input(self, start_node):
+        """Check whether the start_node has positive input data."""
+        op_type = start_node.op
+        if op_type in ("Relu", "Relu6") or op_type.find("AndRelu") != -1:
+            return True
+        elif op_type.startswith("Quantized") and not op_type.endswith("AndRelu"):
+            return False
+        elif op_type in ("Concat", "Add", "AddV2", "AddN"):
+            for each_input in start_node.input:
+                has_relu = self._has_positive_input(
+                    self.node_name_details[GraphRewriterHelper.node_name_from_input(each_input)].node
+                )
+                if not has_relu:
+                    return False
+            return True
+        elif op_type in (
+            "Conv3D",
+            "Conv2D",
+            "DepthwiseConv2D",
+            "QuantizeV2",
+            "DepthwiseConv2dNative",
+            "MaxPool",
+            "MaxPool3D",
+            "Requantize",
+            "AvgPool",
+            "Pad",
+            "CropAndResize",
+            "Dequantize",
+            "Mean",
+            "MatMul",
+            "FusedBatchNormV3",
+            "_MklFusedInstanceNorm",
+        ):
+            return self._has_positive_input(
+                self.node_name_details[GraphRewriterHelper.node_name_from_input(start_node.input[0])].node
+            )
+        else:
+            return False
+
+    def has_positive_input(self, node_name):
+        """Check whether the specified node has positive input data.
+
+        Args:
+            node_name (string): node name
+
+        Returns:
+            bool: True if the node has positive input data,
+                False if it has negative input data.
+        """
+        return self._has_positive_input(self.node_name_details[node_name].node)
+
+    def get_graph_input_output(self):
+        """Get the graphdef input/output node names.
+
+        Sometimes the configuration doesn't specify the input/output names of the graph,
+        but TensorFlow needs to know them to run it. This function provides functionality
+        similar to Google's summarize_graph.py.
+
+        Returns:
+            tuple: (inputs' name list, outputs' name list)
+        """
+        input_node_names = []
+        output_node_names = []
+        unlikely_output_types = [
+            "Const",
+            "HostConst",
+            "Assign",
+            "NoOp",
+            "Parameter",
+            "Assert",
+            "save",
+            "global_step",
+            "read",
+            "switch",
+            "cond",
+            "train",
+            "init_ops",
+            "[A-Za-z]+Dataset",
+        ]
+        unlikely_input_types = [
+            "FIFOQueueV2",
+            "QueueDequeueV2",
+            "QueueDequeueUpToV2",
+            "OneShotIterator",
+            "IteratorGetNext",
+            "IteratorV2",
+        ]
+        exclude_input_names = []
+        extra_input_names = []
+
+        for _, i in self.node_name_details.items():
+            for exclude_input_name in exclude_input_names:
+                if exclude_input_name == i.node.name:
+                    if i.node.op in unlikely_input_types:
+                        exclude_input_names += i.outputs
+                    else:
+                        extra_input_names.append(i.node.name)
+            if i.node.op in ["Const", "HostConst", "Variable", "VariableV2"]:
+                continue
+            if not i.node.input and not i.outputs:
+                logger.debug("Skip isolated node {}.".format(i.node.name))
+            elif i.node.op == "Placeholder":
+                input_node_names.append(i.node.name)
+            elif not i.node.input:
+                if i.node.op not in unlikely_input_types:
+                    input_node_names.append(i.node.name)
+                else:
+                    exclude_input_names += i.outputs
+            elif (
+                not i.outputs
+                and i.node.op not in unlikely_output_types
+                and not re.match(unlikely_output_types[-1], i.node.op)
+            ):
+                output_node_names.append(i.node.name)
+            else:
+                pass
+
+        if len(input_node_names) == 0 and len(extra_input_names) != 0:
+            for extra_input_name in extra_input_names:
+                input_node_names.append(extra_input_name)
+
+        logger.warning(
+            "Found possible input node names: {}, output node names: {}.".format(input_node_names, output_node_names)
+        )
+
+        return (input_node_names, output_node_names)
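A sketch of the typical call sequence against the analyzer above (graph_def is assumed to be a loaded GraphDef; parse_graph must run before the name detection):

    analyzer = GraphAnalyzer()
    analyzer.graph = graph_def
    analyzer.parse_graph()
    inputs, outputs = analyzer.get_graph_input_output()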
+
+    def query_fusion_pattern_nodes(self, patterns=None):
+        """Public interface for querying the nodes aggregation status.
+
+        Args:
+            patterns (string list): Please check the _search_patterns definition.
+
+        Returns:
+            [string list]: The matched node names saved as a string list.
+        """
+        if self.extend_engine:
+            # TODO: keep this for future extension API
+            pass
+        else:
+            return self._search_patterns(patterns)
+
+    def _search_patterns(self, input_pattern):
+        """Search user-specified patterns on the internal graph structure.
+
+        Args:
+            input_pattern (list): The element of the pattern list could be string/list/tuple.
+                string or list means the specified types are mandatory while tuple stands for optional.
+                e.g:
+                ['Conv2D', ['BiasAdd'], ("Add", "AddN"), ["Relu","Relu6"]] is equivalent to the patterns below:
+                Conv2D + BiasAdd + Add + Relu
+                Conv2D + BiasAdd + AddN + Relu
+                Conv2D + BiasAdd + Add + Relu6
+                Conv2D + BiasAdd + AddN + Relu6
+                Conv2D + BiasAdd + Relu
+                Conv2D + BiasAdd + Relu6
+
+        Return: [string list]. Each matched pattern is composed of the matched node names,
+            with the list of matched op types appended as the last element.
+            e.g
+            [
+                ['resnet_model/conv2d_4/Conv2D',
+                 'resnet_model/batch_normalization_4/FusedBatchNorm',
+                 'resnet_model/add',
+                 'resnet_model/Relu_3',
+                 ['Conv2D', 'BiasAdd', 'Add', 'Relu']],
+                ['resnet_model/conv2d_7/Conv2D',
+                 'resnet_model/batch_normalization_7/FusedBatchNorm',
+                 'resnet_model/add_1',
+                 'resnet_model/Relu_6',
+                 ['Conv2D', 'BiasAdd', 'AddN', 'Relu6']]
+            ]
+        """
+
+        def _validate_input(data, criteria):
+            if isinstance(criteria, str) and data == criteria:
+                return True
+
+            if isinstance(criteria, (list, tuple)) and data in criteria:
+                return True
+
+            return False
+
+        def _compare_list(list_a, list_b):
+            """Check list a is a subset of list b.
+ + e.g, list a is ['a', 'b', 'c'] while list b is ['a', 'b', 'c', 'd'], + then list a is subset of list b. + + Args: + list_a ([Any]): list A + list_b ([Any]): list B + + Returns: + [bool]: list a is a subset of list b or not. + """ + assert isinstance(list_a, list) + assert isinstance(list_b, list) + is_subset = True + + for index, value in enumerate(list_a): + is_subset &= value == list_b[index] + + return is_subset + + def _dfs(op_names, op_types, graph_info, node, pattern): + if pattern == []: + return + start_index = 0 + end_index = len(pattern) - 1 + matched_flag = False + while start_index <= end_index: + matched_flag = _validate_input(node.op, pattern[end_index]) + + if not matched_flag and isinstance(pattern[end_index], tuple): + end_index -= 1 + continue + + if matched_flag: + op_names.append(node.name) + op_types.append(node.op) + break + + return + + if start_index == end_index: + if matched_flag: + matched_res = copy.deepcopy(op_names) + matched_res.reverse() + op_types_copy = copy.deepcopy(op_types) + op_types_copy.reverse() + matched_res.append(op_types_copy) + if matched_res not in output_result: + output_result.append(matched_res) + + op_names.pop() + op_types.pop() + return + + for index, value in enumerate(node.input): + cur_node = graph_info[GraphRewriterHelper.node_name_from_input(value)].node + _dfs(op_names, op_types, graph_info, cur_node, pattern[:end_index]) + if index == len(node.input) - 1: + op_names.pop() + op_types.pop() + + output_result = [] + + for _, v in self.node_name_details.items(): + start_index = len(input_pattern) - 1 + while start_index >= 0: + find_first_match = _validate_input(v.node.op, input_pattern[start_index]) + if find_first_match: + break + + if isinstance(input_pattern[start_index], tuple): + start_index -= 1 + continue + + start_index = -2 + + if start_index < 0: + continue + + visited_op_name = [] + visited_op_types = [] + + _dfs(visited_op_name, visited_op_types, self.node_name_details, v.node, input_pattern) + + sorted_output = sorted(output_result, key=lambda i: i[-1]) + + useless_match_list = [] + for index, value in enumerate(sorted_output): + if index == len(sorted_output) - 1: + break + + next_matched_op_names = sorted_output[index + 1][:-1] + if len(value[:-1]) < len(next_matched_op_names) and _compare_list(value[:-1], next_matched_op_names): + useless_match_list.append(value) + + for i in useless_match_list: + sorted_output.remove(i) + + longest_match = {} + final_output = [] + for i in sorted_output: + key = i[0] + if key not in longest_match: + longest_match[key] = i[-1] + continue + + if len(longest_match[key]) < len(i[-1]): + longest_match[key] = i[-1] + + for i in sorted_output: + if i[0] in longest_match and i[-1] == longest_match[i[0]]: + final_output.append(i) + + return final_output + + def remove_node_with_single_input_output(self, node_name): + """Remove node with one input and rebuild internal graph data structure. + + Args: + node_name (string): node name + + Returns: + [bool]: True if remove the node without exception, + False if failed to remove it. 
+        """
+        if node_name not in self.node_name_details:
+            logger.debug("The {} is not a valid node name.".format(node_name))
+            return False
+
+        non_const_node_count = len(
+            [
+                GraphRewriterHelper.node_name_from_input(i)
+                for i in self.node_name_details[node_name].node.input
+                if self.node_name_details[GraphRewriterHelper.node_name_from_input(i)].node.op != "Const"
+            ]
+        )
+
+        if non_const_node_count > 1:
+            logger.debug("The target node {} has more than one input.".format(node_name))
+            return False
+
+        try:
+            top_node_name = GraphRewriterHelper.node_name_from_input(self.node_name_details[node_name].node.input[0])
+
+            for bottom_node_name in self.node_name_details[node_name].outputs:
+                update_output_name = [
+                    bottom_node_name if i == node_name else i for i in self.node_name_details[top_node_name].outputs
+                ]
+                # namedtuple._replace returns a new instance rather than mutating
+                # in place, so the result has to be assigned back.
+                self.node_name_details[top_node_name] = self.node_name_details[top_node_name]._replace(
+                    outputs=update_output_name
+                )
+
+                update_input_name = [
+                    self.node_name_details[node_name].node.input[0] if i == node_name else i
+                    for i in self.node_name_details[bottom_node_name].node.input
+                ]
+
+                if self.node_name_details[bottom_node_name].node.input:
+                    self.node_name_details[bottom_node_name].node.ClearField("input")
+                self.node_name_details[bottom_node_name].node.input.extend(update_input_name)
+
+        except Exception as e:
+            logger.debug("Fail to remove node {} due to {}.".format(node_name, str(e)))
+            return False
+        else:
+            return self.remove_node(node_name)
+
+    def remove_node(self, node_name):
+        """Remove the user specified node by its name.
+
+        Args:
+            node_name (string): node name string.
+
+        Returns:
+            [bool]: True if the node is removed without exception,
+                    False if it fails to be removed.
+        """
+        if node_name not in self.node_name_details:
+            logger.debug("The {} is not a valid node name.".format(node_name))
+            return False
+        try:
+            self.node_name_details.pop(node_name)
+        except Exception as e:
+            logger.info("Fail to remove {} due to {}.".format(node_name, str(e)))
+            return False
+        else:
+            logger.debug("{} has been removed.".format(node_name))
+            return True
+
+    def replace_const_node(self, new_const_node, target_node, old_constant_node_name, replace_all=True):
+        """Replace the specified const node with another one.
+
+        Args:
+            new_const_node (NodeDef): the new const node to insert.
+            target_node (list): the names of the nodes whose const input needs
+                                to be replaced.
+            old_constant_node_name (string): the outdated const node name.
+            replace_all (bool): whether to replace the const input for all target nodes
+                                or only the first one.
+ """ + new_const_node_name = new_const_node.name + + self.node_name_details[new_const_node_name] = self.node_details(node=new_const_node, outputs=target_node) + + for sub_node in target_node: + if sub_node not in self.node_name_details: + continue + for index, each_node_name in enumerate(self.node_name_details[sub_node].node.input): + if each_node_name + ":0" == old_constant_node_name or each_node_name == old_constant_node_name: + new_input_name = ( + self.node_name_details[sub_node].node.input[:index] + + [new_const_node_name] + + self.node_name_details[sub_node].node.input[index + 1 :] + ) + self.node_name_details[sub_node].node.ClearField("input") + self.node_name_details[sub_node].node.input.extend(new_input_name) + if old_constant_node_name in self.node_name_details: + self.node_name_details[old_constant_node_name].outputs.remove(sub_node) + if len(self.node_name_details[old_constant_node_name].outputs) == 0: + self.remove_node(old_constant_node_name) + if not replace_all: + break + + def replace_constant_graph_with_constant_node(self, new_node, old_end_node_name): + """Remove sub-graph with a const node. + + Args: + new_node (nodedef): the constant node + old_end_node_name (string): the sub-graph end node which will be updated by new node + + Returns: + [bool]: True if remove the node without exception. + False if failed to remove it. + """ + new_node_name = new_node.name + + if new_node.op != "Const": + logger.warning("The input of replace_with_constant_node must be a constant node.") + return False + try: + inputs = self.node_name_details[old_end_node_name].node.input + inputs = [GraphRewriterHelper.node_name_from_input(i) for i in inputs] + for input_name in inputs: + if self.node_name_details[input_name].node.op != "Const": + logger.warning("The subgraph replaces must be constant.") + return False + elif len(self.node_name_details[input_name].outputs) == 1: + self.node_name_details.pop(input_name) + output_node_name = self.node_name_details[old_end_node_name].outputs + self.replace_node(new_node, old_end_node_name, output_node_name) + self.node_name_details[new_node_name].node.ClearField("input") + except Exception as e: + logger.info("Fail to replace {} due to {}.".format(old_end_node_name, str(e))) + return False + else: + return True + + def replace_single_node( + self, new_node, old_output_node_names, old_output_name, old_input_node_names, old_input_name + ): + """Insert one node into the graph. + + Args: + new_node (nodedef): new nodedef object + old_output_node_names (string list):the node names that would be the top node of new + node. + old_output_name (string list): the names that need to be updated with new node name + old_input_node_names (string list): the node names that would be the bottom node of new + node. 
+ old_input_name (string list): the names that need to be updated with new node name + """ + new_node_name = new_node.name + for i in old_output_node_names: + while old_output_name in self.node_name_details[i].outputs: + self.node_name_details[i].outputs.remove(old_output_name) + self.node_name_details[i].outputs.append(new_node_name) + + self.node_name_details[new_node_name] = self.node_details(node=new_node, outputs=old_input_node_names) + + for each_input_node_name in old_input_node_names: + for index, each_node_name in enumerate(self.node_name_details[each_input_node_name].node.input): + if self.node_name_details[each_input_node_name].node.input and (each_node_name) == old_input_name: + new_input_name = ( + self.node_name_details[each_input_node_name].node.input[:index] + + [new_node_name] + + self.node_name_details[each_input_node_name].node.input[index + 1 :] + ) + self.node_name_details[each_input_node_name].node.ClearField("input") + self.node_name_details[each_input_node_name].node.input.extend(new_input_name) + + def replace_node(self, new_node, old_node_name, output_nodes_name): + """Replace the node into the internal data structure node_name_details. + + Args: + new_node (nodedef): the nodedef object. + old_node_name (string): the parent node of input node. + output_nodes_name (string list): output node names list + """ + new_node_name = new_node.name + self.node_name_details[new_node_name] = self.node_details(node=new_node, outputs=output_nodes_name) + old_node = self.node_name_details[old_node_name].node + for input_node_name in old_node.input: + if input_node_name in self.node_name_details: + self.node_name_details[input_node_name].outputs.remove(old_node_name) + self.node_name_details[input_node_name].outputs.append(new_node_name) + + for node_name in output_nodes_name: + for index, each_node_name in enumerate(self.node_name_details[node_name].node.input): + if ( + self.node_name_details[node_name].node.input + and GraphRewriterHelper.node_name_from_input(each_node_name) == old_node_name + ): + new_input_name = ( + self.node_name_details[node_name].node.input[:index] + + [new_node_name] + + self.node_name_details[node_name].node.input[index + 1 :] + ) + self.node_name_details[node_name].node.ClearField("input") + self.node_name_details[node_name].node.input.extend(new_input_name) + self.remove_node(old_node_name) + + def add_node(self, new_node, start_node_name, end_node_names): + """Add the node into the internal data structure node_name_details. + + Args: + new_node (nodedef): the nodedef object. + start_node_name (string): the parent node of input node. 
+            end_node_names (string list): output node names list
+        """
+        new_node_name = new_node.name
+
+        if new_node_name in self.node_name_details:
+            logger.debug("Remove the existing node {} from internal data structure.".format(new_node_name))
+            self.node_name_details.pop(new_node_name)
+
+        self.node_name_details[new_node_name] = self.node_details(node=new_node, outputs=end_node_names)
+
+        for end_node_name in end_node_names:
+            # Update start node's output info
+            if end_node_name not in self.node_name_details:
+                continue
+            if (
+                start_node_name
+                and end_node_name
+                in self.node_name_details[GraphRewriterHelper.node_name_from_input(start_node_name)].outputs
+            ):
+                self.node_name_details[GraphRewriterHelper.node_name_from_input(start_node_name)].outputs.remove(
+                    end_node_name
+                )
+
+            # reset output node's input
+            for index, each_node_name in enumerate(self.node_name_details[end_node_name].node.input):
+                if each_node_name == start_node_name:
+                    new_input_name = (
+                        self.node_name_details[end_node_name].node.input[:index]
+                        + [new_node_name]
+                        + self.node_name_details[end_node_name].node.input[index + 1 :]
+                    )
+                    self.node_name_details[end_node_name].node.ClearField("input")
+                    self.node_name_details[end_node_name].node.input.extend(new_input_name)
+
+        # add the inserted node into the start node's output.
+        if start_node_name:
+            self.node_name_details[GraphRewriterHelper.node_name_from_input(start_node_name)].outputs.append(
+                new_node_name
+            )
+
+    def dump_graph(self):
+        """Dump the current model's graphdef.
+
+        Returns:
+            [graphdef]: A graphdef object
+        """
+        output_graph_def = graph_pb2.GraphDef()
+        for _, v in self.node_name_details.items():
+            output_graph_def.node.extend([v.node])
+
+        return output_graph_def
+
+    def get_frame_info(self):
+        """Get the frame info of the model.
+
+        Returns:
+            [parent_frame_details]: OrderedDict frame info of the graph nodes.
+        """
+        from collections import OrderedDict
+
+        self.parent_frame_details = OrderedDict()
+        input_node_names, _ = self.get_graph_input_output()
+
+        traverse_list = copy.deepcopy(input_node_names)
+        visited = []
+
+        while traverse_list:
+            node_name = traverse_list.pop(0)
+            node_details = self.node_name_details[node_name]
+
+            if node_details.node.name in visited:
+                continue
+
+            for output in node_details.outputs:
+                traverse_list.append(output)
+
+                inputs = node_details.node.input
+                if not inputs:
+                    self.parent_frame_details[node_details.node.name] = None
+                if self.node_name_details[output].node.op == "Enter":
+                    self.parent_frame_details[output] = self.node_name_details[output].node
+                elif self.node_name_details[output].node.op == "Exit":
+                    self.parent_frame_details[output] = None
+                else:
+                    if output in self.parent_frame_details and self.parent_frame_details[output]:
+                        if (
+                            node_details.node.name in self.parent_frame_details
+                            and self.parent_frame_details[node_details.node.name]
+                        ):
+                            assert (
+                                self.parent_frame_details[output].attr["frame_name"]
+                                == self.parent_frame_details[node_details.node.name].attr["frame_name"]
+                            )
+                    else:
+                        if node_details.node.name in self.parent_frame_details:
+                            self.parent_frame_details[output] = self.parent_frame_details[node_details.node.name]
+
+            visited.append(node_details.node.name)
+        return self.parent_frame_details
+
+    def parse_graph(self, input_graph_def=None):
+        """Analyze the input graphdef and return a dict containing each node's input/output node names.
+
+        Args:
+            input_graph_def ([graphdef]): graphdef object
+
+        Returns:
+            [dict]: a dict mapping node names to their inputs/outputs info.
+        """
+        if not input_graph_def:
+            input_graph_def = self._graph
+
+        self.node_name_details = {}
+
+        for node in input_graph_def.node:
+            node_name = GraphRewriterHelper.node_name_from_input(node.name)
+
+            each_node = self.node_details(node=node, outputs=[])
+
+            if node_name not in self.node_name_details:
+                self.node_name_details[node_name] = each_node
+
+        for node_name, node_details in self.node_name_details.items():
+            # update the upper node's output information.
+            for each_input in node_details.node.input:
+                self.node_name_details[GraphRewriterHelper.node_name_from_input(each_input)].outputs.append(node_name)
+
+        return self.node_name_details
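Continuing the earlier sketch, each entry of the dict returned by parse_graph pairs a NodeDef with the names of its consumers:

    details = analyzer.parse_graph()
    for name, info in details.items():
        print(name, info.node.op, "->", info.outputs)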
+
+
+class GraphRewriterHelper:
+    """Encapsulates the graph operation into one class."""
+
+    node_name_cache = {}
+    node_name_port_cache = {}
+
+    @staticmethod
+    def compare_node_attr(node_a, node_b):
+        """Compare whether two nodes have identical attributes.
+
+        Args:
+            node_a (nodedef): Input node.
+            node_b (nodedef): Another node to be compared.
+
+        Returns:
+            [bool]: True if the two nodes have identical attributes.
+        """
+        if len(node_a.input) > 1:
+            return False
+
+        if node_a.input != node_b.input:
+            return False
+
+        if node_a.op != node_b.op:
+            return False
+
+        if len(node_a.attr) != len(node_b.attr):
+            return False
+
+        node_a_attr = sorted(list(node_a.attr))
+        node_b_attr = sorted(list(node_b.attr))
+
+        if node_a_attr != node_b_attr:
+            return False
+
+        for attr_name in node_a_attr:
+            if node_a.attr[attr_name] != node_b.attr[attr_name]:
+                return False
+
+        return True
+
+    @staticmethod
+    def create_node(op, name, inputs):
+        """Create a nodedef object.
+
+        Args:
+            op (string): op type
+            name (string): op name
+            inputs (string list): op's inputs name
+
+        Returns:
+            nodedef: the created nodedef object
+        """
+        new_node = node_def_pb2.NodeDef()
+        new_node.op = op
+        new_node.name = name
+        for input_name in inputs:
+            new_node.input.extend([input_name])
+        return new_node
+
+    @staticmethod
+    def create_constant_node(name, value, dtype, shape=None, device="cpu"):
+        """Create constant node.
+
+        Args:
+            name (string): op name
+            value (np.array): input data
+            dtype (datatype): data type of the input value
+            shape (int list, optional): the value's shape. Defaults to None.
+            device (str, optional): the device type, it may be the 'cpu' or 'gpu'.
+                                    Defaults to 'cpu'.
+
+        Returns:
+            nodedef: the created const nodedef object
+        """
+        node = GraphRewriterHelper.create_node("Const" if device == "cpu" else "HostConst", name, [])
+        GraphRewriterHelper.set_attr_dtype(node, "dtype", dtype)
+        GraphRewriterHelper.set_attr_tensor(node, "value", value, dtype, shape)
+        return node
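For instance, building a scalar const node with the helper above (illustrative name and value):

    from tensorflow.python.framework import dtypes

    bias_min = GraphRewriterHelper.create_constant_node("conv1_bias_min", -0.5, dtypes.float32)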
+
+    @staticmethod
+    def set_attr_dtype(node, key, value):
+        """Set the attribute data type."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(type=value.as_datatype_enum))
+
+    @staticmethod
+    def set_attr_tensor(node, key, value, dtype, shape=None):
+        """Set the tensor value to specified attribute field.
+
+        Args:
+            node (nodedef): the target nodedef object
+            key (string): attribute name
+            value (np.array): the content
+            dtype (dtypes): data type
+            shape (int list, optional): the input tensor's shape. Defaults to None.
+        """
+        node.attr[key].CopyFrom(
+            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
+        )
+
+    @staticmethod
+    def set_attr_type_list(node, key, value):
+        """Set the node's attr which data type is type list."""
+        list_value = attr_value_pb2.AttrValue.ListValue(type=value)
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
+
+    @staticmethod
+    def set_attr_string_list(node, key, value):
+        """Set the node's attr which data type is string list."""
+        list_value = attr_value_pb2.AttrValue.ListValue(s=value)
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
+
+    @staticmethod
+    def set_attr_string(node, key, value):
+        """Set the node's attr which data type is string."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(s=value))
+
+    @staticmethod
+    def set_attr_int_list(node, key, value):
+        """Set the node's attr which data type is int list."""
+        list_value = attr_value_pb2.AttrValue.ListValue(i=value)
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
+
+    @staticmethod
+    def set_attr_int(node, key, value):
+        """Set the node's attr which data type is int."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(i=value))
+
+    @staticmethod
+    def set_attr_float(node, key, value):
+        """Set the node's attr which data type is float."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(f=value))
+
+    @staticmethod
+    def set_attr_bool(node, key, value):
+        """Set the node's attr which data type is bool."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(b=value))
+
+    @staticmethod
+    def node_name_from_input(node_name):
+        """Static method that gets the valid node name from an input name.
+
+        Args:
+            node_name (string): node name defined in the input field.
+
+        Returns:
+            string: node's name
+        """
+        if node_name not in GraphRewriterHelper.node_name_cache:
+            key = node_name
+            if node_name.startswith("^"):
+                node_name = node_name[1:]
+            m = re.search(r"(.*):\d+$", node_name)
+            if m:
+                node_name = m.group(1)
+            GraphRewriterHelper.node_name_cache[key] = node_name
+            return node_name
+
+        return GraphRewriterHelper.node_name_cache[node_name]
+
+    @staticmethod
+    def values_from_const(node_def):
+        """Extracts the values from a const NodeDef as a numpy ndarray.
+
+        Args:
+            node_def: Const NodeDef that has the values we want to access.
+
+        Returns:
+            Numpy ndarray containing the values.
+
+        Raises:
+            AssertionError: If the node isn't a Const.
+        """
+        assert node_def.op == "Const", "Node named '%s' should be a Const op." % node_def.name
+
+        input_tensor = node_def.attr["value"].tensor
+        tensor_value = tensor_util.MakeNdarray(input_tensor)
+        return tensor_value
+
+    @staticmethod
+    def generate_int32_bias_for_conv(
+        bias_tensor,
+        channel_size,
+        max_input,
+        min_input,
+        max_filter_tensor,
+        min_filter_tensor,
+        activation_range,
+        weights_range=127.0,
+    ):
+        """Static method that generates an int32 bias for a conv op.
+
+        Args:
+            bias_tensor: bias node tensor.
+            channel_size: channel size.
+            max_input: max activation input value.
+            min_input: min activation input value.
+            max_filter_tensor: max weight input tensor.
+            min_filter_tensor: min weight input tensor.
+            activation_range: activation range value.
+            weights_range: weight range value.
+ + Returns: + int32_bias: int32 bias + """ + bias_length = bias_tensor.shape[0] + scales = [] + if len(max_filter_tensor) > 1: + for i in range(channel_size): + scales.append( + activation_range + * weights_range + / (max(abs(max_input), abs(min_input)) * max(abs(max_filter_tensor[i]), abs(min_filter_tensor[i]))) + ) + else: + for i in range(channel_size): + scales.append( + activation_range + * weights_range + / (max(abs(max_input), abs(min_input)) * max(abs(max_filter_tensor[0]), abs(min_filter_tensor[0]))) + ) + int32_bias = [] + if channel_size > 1: + for i in range(bias_length): + int32_bias.append((int)(np.around(bias_tensor[i] * scales[i]))) + else: + for i in range(bias_length): + int32_bias.append((int)(np.around(bias_tensor[i] * scales[0]))) + + return int32_bias + + @staticmethod + def generate_int32_bias_for_matmul( + bias_tensor, + weights_tensor, + input_range, + max_input, + min_input, + max_filter_value, + min_filter_value, + ): + """Static method that generate int32 bias for matmul op. + + Args: + bias_tensor: bias node tensor. + weights_tensor: weights tensor. + input_range: activation range value. + max_input: max activation input value. + min_input: min activation input value. + max_filter_tensor: max weight input tensor. + min_filter_tensor: min weight input tensor. + + Returns: + int32_bias: int32 bias + """ + bias_scale = 255.0 * 127.0 / (input_range * max(abs(max_filter_value), abs(min_filter_value))) + relative_scale = 255 * min_input / (max_input - min_input) + int32_bias = [] + for bias_index, value in enumerate(np.sum(np.array(weights_tensor, dtype=np.int32), axis=0, dtype=np.int32)): + if bias_index >= bias_tensor.size: + continue + int32_bias.append(int(np.around(bias_tensor[bias_index] * bias_scale + value * relative_scale))) + + return int32_bias + + @staticmethod + def generate_int32_bias_for_matmul_per_channel( + bias_tensor, + weights_tensor, + max_input, + min_input, + max_filter_tensor, + min_filter_tensor, + ): # pragma: no cover + """Static method that generate per-channel int32 bias for matmul op. + + Args: + bias_tensor: bias node tensor. + weights_tensor: weights tensor. + max_input: max activation input value. + min_input: min activation input value. + max_filter_tensor: max weight input tensor. + min_filter_tensor: min weight input tensor. + + Returns: + int32_bias: int32 bias + """ + channel_size = bias_tensor.shape[0] + activation_range = 255.0 + weights_range = 127.0 + scales = [] + relative_scale = 255 * min_input / (max_input - min_input) + for i in range(channel_size): + scales.append( + activation_range + * weights_range + / ((max_input - min_input) * max(abs(max_filter_tensor[i]), abs(min_filter_tensor[i]))) + ) + int32_bias = [] + for i in range(channel_size): + value = np.sum(np.array(weights_tensor), axis=0, dtype=np.int32)[i] + int32_bias.append((int)(np.around(value * relative_scale + bias_tensor[i] * scales[i]))) + + return int32_bias + + @staticmethod + def gen_valid_sampling_log(log_path): + """Generate the valid sampling log. + + Args: + log_path: the valid sampling log file path. + + Returns: + the sampling min max value. 
+ """ + + def gen_per_iter(data): + res = [] + requant_tmp = [] + for i in data: + if i.find("__print__;__requant_") == -1: + res.append(i) + else: + requant_tmp.append(i) + sorted_requant = sorted(requant_tmp) + odd_list = sorted_requant[::2] + even_list = sorted_requant[1::2] + for index, value in enumerate(even_list): + min_value = min(0, float(value.split(":")[1][1:-1])) + max_value = float(odd_list[index].split(":")[1][1:-1]) + max_value = max_value if max_value > min_value else min_value + 1e-05 + mixed_str = value.split(":")[0] + "_max:[" + str(min_value) + "][" + str(max_value) + "]" + + res.append(mixed_str) + return res + + def separate(line): + """This function is to separate the strings. + + Example: + ';slice__print__;__max:[1];slice__print__;__min:[-1]' --> + [';slice__print__;__max:[1]', ';slice__print__;__min:[-1]'] + """ + separated_lines = [] + for subline in line.split("];"): + if not subline.startswith(";"): + subline = ";" + subline + if not subline.endswith("]"): + subline += "]" + separated_lines.append(subline) + return separated_lines + + with open(log_path) as f: + valid_data = [] + for i in f.readlines(): + if not i.startswith(";"): + continue + line = i.strip() + if line.find("];") != 0: + separated_lines = separate(line) + valid_data += separated_lines + else: + valid_data.append(line) + + first_line = valid_data[0].rsplit(":")[0] + + iterations = 0 + for i in valid_data: + if i.startswith(first_line): + iterations += 1 + + step = int(len(valid_data) / iterations) + if step % 2 == 1: + step -= 1 + iterations = int(len(valid_data) / step) + int(len(valid_data) % step > 0) + + final_res = [] + + for i in range(iterations): + final_res.extend(gen_per_iter(valid_data[int(i * step) : int(step * (i + 1))])) + if i + 1 == iterations and int(step * (i + 1)) < len(valid_data): + final_res.extend(gen_per_iter(valid_data[int(step * (i + 1)) : len(valid_data)])) + + return final_res + + @staticmethod + def analysis_rnn_model(graph_def, bf16_ops=[], fp32_ops=[]): + """Match the RNN and dynamic RNN patterns.""" + g = GraphAnalyzer() + g.graph = graph_def + graph_info = g.parse_graph() + rnn_pattern = [["TensorArrayV3"], ["Enter"], ["TensorArrayReadV3"], ["MatMul"], ["BiasAdd"]] + target_nodes = g.query_fusion_pattern_nodes(rnn_pattern) + res = {} + for i in target_nodes: + if i[-3] not in bf16_ops and i[-3] not in fp32_ops: + res[(i[-3], i[-2])] = graph_info[i[1]].node.attr["frame_name"].s.decode() + + dynamic_rnn_pattern = [["Enter"], ["MatMul"], ["BiasAdd"]] + target_nodes = g.query_fusion_pattern_nodes(dynamic_rnn_pattern) + for i in target_nodes: + if i[-3] not in bf16_ops and i[-3] not in fp32_ops: + res[(i[1], i[2])] = graph_info[i[0]].node.attr["frame_name"].s.decode() + + return res diff --git a/neural_compressor/tensorflow/quantization/utils/quantize_graph_common.py b/neural_compressor/tensorflow/quantization/utils/quantize_graph_common.py new file mode 100644 index 00000000000..9a210e695ff --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/quantize_graph_common.py @@ -0,0 +1,448 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Quantize Graph Common Utils Helper Class."""
+
+import re
+
+import numpy as np
+from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2
+from tensorflow.python.framework import dtypes, tensor_util
+
+
+class QuantizeGraphHelper:
+    """This class contains several staticmethod functions."""
+
+    node_name_cache = {}
+    node_name_port_cache = {}
+
+    def __init__(self):
+        """Initialization."""
+        pass
+
+    def _recursive_graph_sorting(self, node_name):
+        """Recursively sort the graph."""
+        if node_name in self.op_list or not self.node_name_mapping[node_name].input:
+            return
+
+        for input_name in self.node_name_mapping[node_name].input:
+            if input_name not in self.node_name_mapping:
+                continue
+            else:
+                self._recursive_graph_sorting(input_name)
+
+        if node_name not in self.op_list:
+            self.op_list.append(node_name)
+
+        return
+
+    def _get_op_list(self, output_node_names):
+        """Get the op list by recursively sorting the graph."""
+        for output_name in output_node_names:
+            self._recursive_graph_sorting(output_name)
+
+    def get_sorted_graph(self, input_graph, input_node_names, output_node_names):
+        """Return a sorted graphdef object.
+
+        Sometimes the input graphdef is composed of randomly ordered nodedef objects;
+        we reorder the graph to make the parsing easier.
+
+        Args:
+            input_graph (graphdef): the input graphdef object
+            input_node_names (string list): the input node names
+            output_node_names (string list): the output node names
+
+        Returns:
+            graphdef: the sorted graphdef object
+        """
+        self.node_name_mapping = {}
+        self.op_list = [input_node_name for input_node_name in input_node_names]
+        for node in input_graph.node:
+            self.node_name_mapping[node.name] = node
+        self._get_op_list(output_node_names)
+
+        all_ops = [i for i in list(self.node_name_mapping.keys()) if i not in self.op_list]
+        self.op_list.extend(sorted(set(all_ops), key=all_ops.index))
+
+        self.out_graph_def = graph_pb2.GraphDef()
+        for i in self.op_list:
+            new_node = node_def_pb2.NodeDef()
+            new_node.CopyFrom(self.node_name_mapping[i])
+            self.out_graph_def.node.extend([new_node])
+
+        return self.out_graph_def
+
+    @staticmethod
+    def split_shared_inputs(input_graph_def):
+        """Split shared inputs (like weights and bias) of the graph.
+
+        :param input_graph_def: input graphdef object.
+        :return: the graphdef object with shared const inputs split,
+                 or the original graphdef if nothing is shared.
+        """
+        node_map = {}
+        for node in input_graph_def.node:
+            if node.name not in node_map:
+                node_map[node.name] = node
+
+        output_graph_def = graph_pb2.GraphDef()
+        is_shared_input = False
+        # map of: input_name - op_name
+        input_map = {}
+        for node_name in node_map.keys():
+            node = node_map[node_name]
+            for input_idx, input_node_name in enumerate(node.input):
+                if node_map[QuantizeGraphHelper.node_name_from_input(input_node_name)].op == "Const":
+                    # the input is shared and the current node is not the first
+                    # one sharing it
+                    if input_node_name in input_map.keys():
+                        is_shared_input = True
+                        input_map[input_node_name].append(node.name)
+                        new_input_node = node_def_pb2.NodeDef()
+                        new_input_node.CopyFrom(node_map[input_node_name])
+                        new_input_node.name = input_node_name + "_" + str(len(input_map[input_node_name]))
+                        node.input[input_idx] = new_input_node.name
+                        output_graph_def.node.extend([new_input_node])
+                    else:
+                        input_map[input_node_name] = [node.name]
+            output_graph_def.node.extend([node])
+
+        return output_graph_def if is_shared_input else input_graph_def
+
+    @staticmethod
+    def remove_training_nodes(input_graph, protected_nodes=[], types_to_splice=["Identity", "CheckNumerics"]):
+        """Prunes out nodes that aren't needed for inference.
+
+        Args:
+            input_graph: Model to analyze and prune.
+            protected_nodes: An optional list of node names that must be kept
+                even if their types appear in types_to_splice.
+            types_to_splice: An optional list of types of nodes to be removed
+                unconditionally.
+
+        Returns:
+            An optimized graphdef object.
+        """
+        input_nodes = input_graph.node
+
+        control_input_names = set()
+        node_names_with_control_input = set()
+        for node in input_nodes:
+            for node_input in node.input:
+                if "^" in node_input:
+                    control_input_names.add(node_input.replace("^", ""))
+                    node_names_with_control_input.add(node.name)
+
+        names_to_splice = {}
+        for node in input_nodes:
+            if node.op in types_to_splice:
+                # We don't want to remove nodes that have control edge inputs, because
+                # they might be involved in subtle dependency issues that removing them
+                # will jeopardize.
+                if node.name not in node_names_with_control_input:
+                    names_to_splice[node.name] = node.input[0]
+
+        # We also don't want to remove nodes which are used as control edge inputs.
+        names_to_splice = {name: value for name, value in names_to_splice.items() if name not in control_input_names}
+
+        nodes_after_splicing = []
+
+        for node in input_nodes:
+            if node.name in names_to_splice and node.name not in protected_nodes:
+                continue
+
+            # Compare the op type (not the node name) against types_to_splice
+            # so that protected spliceable nodes are kept untouched.
+            if node.name in protected_nodes and node.op in types_to_splice:
+                nodes_after_splicing.append(node)
+                continue
+
+            new_node = node_def_pb2.NodeDef()
+            new_node.CopyFrom(node)
+            input_before_removal = node.input
+            del new_node.input[:]
+            for full_input_name in input_before_removal:
+                input_name = re.sub(r"^\^", "", full_input_name)
+                while input_name in names_to_splice:
+                    full_input_name = names_to_splice[input_name]
+                    input_name = re.sub(r"^\^", "", full_input_name)
+                new_node.input.append(full_input_name)
+            nodes_after_splicing.append(new_node)
+
+        output_graph = graph_pb2.GraphDef()
+        output_graph.node.extend(nodes_after_splicing)
+        return output_graph
+
+    @staticmethod
+    def create_node(op, name, inputs):
+        """Create a nodedef object.
+
+        Args:
+            op (string): op type
+            name (string): op name
+            inputs (string list): op's inputs name
+
+        Returns:
+            nodedef: the created nodedef object
+        """
+        new_node = node_def_pb2.NodeDef()
+        new_node.op = op
+        new_node.name = name
+        for input_name in inputs:
+            new_node.input.extend([input_name])
+        return new_node
+
+    @staticmethod
+    def create_constant_node(name, value, dtype, shape=None, device="cpu"):
+        """Create constant node.
+
+        Args:
+            name (string): op name
+            value (np.array): input data
+            dtype (datatype): data type of the input value
+            shape (int list, optional): the value's shape. Defaults to None.
+            device (str, optional): the device type, it may be the 'cpu' or 'gpu'.
+                                    Defaults to 'cpu'.
+
+        Returns:
+            nodedef: the created const nodedef object
+        """
+        node = QuantizeGraphHelper.create_node("Const" if device == "cpu" else "HostConst", name, [])
+        QuantizeGraphHelper.set_attr_dtype(node, "dtype", dtype)
+        QuantizeGraphHelper.set_attr_tensor(node, "value", value, dtype, shape)
+        return node
+
+    @staticmethod
+    def copy_attr(node, key, attr_value):
+        """Copy the specified attr value to node.
+
+        Args:
+            node (nodedef): a nodedef object
+            key (string): string name
+            attr_value (any): the specified attribute value
+        """
+        node.attr[key].CopyFrom(attr_value)
+
+    @staticmethod
+    def set_attr_dtype(node, key, value):
+        """Set the attribute data type."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(type=value.as_datatype_enum))
+
+    @staticmethod
+    def set_attr_tensor(node, key, value, dtype, shape=None):
+        """Set the tensor value to specified attribute field.
+
+        Args:
+            node (nodedef): the target nodedef object
+            key (string): attribute name
+            value (np.array): the content
+            dtype (dtypes): data type
+            shape (int list, optional): the input tensor's shape. Defaults to None.
+        """
+        node.attr[key].CopyFrom(
+            attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(value, dtype=dtype, shape=shape))
+        )
+
+    @staticmethod
+    def set_attr_string_list(node, key, value):
+        """Set the node's attr which data type is string list."""
+        list_value = attr_value_pb2.AttrValue.ListValue(s=value)
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
+
+    @staticmethod
+    def set_attr_type_list(node, key, value):
+        """Set the node's attr which data type is type list."""
+        list_value = attr_value_pb2.AttrValue.ListValue(type=value)
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(list=list_value))
+
+    @staticmethod
+    def set_attr_string(node, key, value):
+        """Set the node's attr which data type is string."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(s=value))
+
+    @staticmethod
+    def set_attr_bool(node, key, value):
+        """Set the node's attr which data type is bool."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(b=value))
+
+    @staticmethod
+    def set_attr_int(node, key, value):
+        """Set the node's attr which data type is int."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(i=value))
+
+    @staticmethod
+    def set_attr_float(node, key, value):
+        """Set the node's attr which data type is float."""
+        node.attr[key].CopyFrom(attr_value_pb2.AttrValue(f=value))
+
+    @staticmethod
+    def node_name_from_input(node_name):
+        """Static method that gets the valid node name from an input name.
+
+        Args:
+            node_name (string): node name defined in the input field.
+ + Returns: + string: node's name + """ + if node_name not in QuantizeGraphHelper.node_name_cache: + key = node_name + if node_name.startswith("^"): + node_name = node_name[1:] + m = re.search(r"(.*):\d+$", node_name) + if m: + node_name = m.group(1) + QuantizeGraphHelper.node_name_cache[key] = node_name + return node_name + + return QuantizeGraphHelper.node_name_cache[node_name] + + @staticmethod + def unique_node_name_from_input(node_name): + """Get the node name from other node name's input field.""" + return node_name.replace(":", "__port__").replace("^", "__hat__") + + @staticmethod + def ensure_tensor_name_has_port(node_name): + """Makes sure that a tensor name has :0 if no explicit port exists.""" + if node_name not in QuantizeGraphHelper.node_name_port_cache: + key = node_name + m = re.search(r"(.*):\d+$", node_name) + if not m: + node_name = node_name + ":0" + QuantizeGraphHelper.node_name_port_cache[key] = node_name + return node_name + + return QuantizeGraphHelper.node_name_port_cache[node_name] + + @staticmethod + def generate_quantized_weight_node( + host_op_type, input_node, per_channel, weight_bit=7.0, device="cpu", enter_node=None + ): + """Generated the quantized weight node.""" + base_name = input_node.name + "_" + qint8_const_name = base_name + "qint8_const" + min_name = base_name + "min" + max_name = base_name + "max" + float_tensor = tensor_util.MakeNdarray(input_node.attr["value"].tensor) + epsilon = 1e-4 # Needs to be set empirically if accuracy is not satisfactory + range_coefficent = 127 / (2**weight_bit - 1) + if host_op_type in ( + "Conv2D", + "MatMul", + "Conv3D", + "BatchMatMulV2", + "Conv2DBackpropInput", + "Conv3DBackpropInputV2", + ): + if per_channel: + if host_op_type in ("Conv3D", "Conv3DBackpropInputV2"): + ranges = np.abs(float_tensor).max(axis=(0, 1, 2, 3)) + elif host_op_type in ("Conv2D", "Conv2DBackpropInput"): + ranges = np.abs(float_tensor).max(axis=(0, 1, 2)) + elif host_op_type in ("MatMul"): + if "transpose_b" in input_node.attr and input_node.attr["transpose_b"].b: # pragma: no cover + ranges = np.abs(float_tensor).max(axis=(1)) + else: + ranges = np.abs(float_tensor).max(axis=(0)) + else: + ranges = np.abs(float_tensor).max(axis=(0, 1, 2)) + + ranges *= range_coefficent + min_value = -ranges + max_value = ranges + # nudging min-max values outside epsilon radius around zero + ranges[ranges < epsilon] = epsilon + min_value[np.abs(min_value) < epsilon] = -epsilon + max_value[np.abs(max_value) < epsilon] = epsilon + if "transpose_b" in input_node.attr and input_node.attr["transpose_b"].b: # pragma: no cover + # transpose for broadcasting + float_tensor = np.transpose(float_tensor, [1, 0]) + qint8_tensor = (np.around(float_tensor * 127.0 / ranges)).astype(np.int8) + qint8_tensor = np.transpose(qint8_tensor, [1, 0]) + else: + qint8_tensor = (np.around(float_tensor * 127.0 / ranges)).astype(np.int8) + else: + min_value = np.min(float_tensor) + max_value = np.max(float_tensor) + min_value *= range_coefficent + max_value *= range_coefficent + min_value = min(min_value, 0.0) + if min_value == max_value: + if abs(min_value) < 0.000001: + max_value = min_value + 1.0 + elif min_value > 0: + max_value = 2 * min_value + else: + max_value = min_value / 2.0 + range_value = np.max(np.abs([min_value, max_value])) + qint8_tensor = (np.around(float_tensor * 127.0 / range_value)).astype(np.int8) + qint8_tensor = np.clip(qint8_tensor, -127, 127).astype(np.int8) + min_value = -range_value + max_value = range_value + elif host_op_type == "DepthwiseConv2dNative": + # 
get the max values based on dim 0 and 1 for depthwise conv + # since, the output channel will be dim 2 * dim 3 + ranges = np.abs(float_tensor).max(axis=(0, 1)) + ranges = ranges.flatten() + min_value = -ranges + max_value = ranges + # nudging min-max values outside epsilon radius around zero + ranges[ranges < epsilon] = epsilon + min_value[np.abs(min_value) < epsilon] = -epsilon + max_value[np.abs(max_value) < epsilon] = epsilon + # Since output channel will be 1 dim which is dim 2 * dim 3 + # When divide by range, qint8_tensor needs to be 3 dim + # where, 3rd dim should be same dim of ranges + a, b, c, d = float_tensor.shape + qint8_tensor = (np.around(float_tensor.reshape(a, b, c * d) * 127.0 / ranges)).astype(np.int8) + # get the shape back to 4 dim + qint8_tensor = qint8_tensor.reshape(a, b, c, d) + shape = tensor_util.TensorShapeProtoToList(input_node.attr["value"].tensor.tensor_shape) + qint8_const_node = QuantizeGraphHelper.create_constant_node( + qint8_const_name, qint8_tensor, dtypes.qint8, shape=shape + ) + min_node = QuantizeGraphHelper.create_constant_node(min_name, min_value, dtypes.float32, device="cpu") + + max_node = QuantizeGraphHelper.create_constant_node(max_name, max_value, dtypes.float32, device="cpu") + + qint8_const_enter_node = None + min_enter_node = None + max_enter_node = None + + if enter_node: + qint8_const_enter_node = QuantizeGraphHelper.create_node( + "Enter", qint8_const_name + "_enter", [qint8_const_name] + ) + QuantizeGraphHelper.set_attr_string(qint8_const_enter_node, "frame_name", enter_node.attr["frame_name"].s) + QuantizeGraphHelper.set_attr_dtype(qint8_const_enter_node, "T", dtypes.qint8) + QuantizeGraphHelper.set_attr_bool(qint8_const_enter_node, "is_constant", True) + QuantizeGraphHelper.set_attr_int( + qint8_const_enter_node, "parallel_iterations", enter_node.attr["parallel_iterations"].i + ) + + min_enter_node = QuantizeGraphHelper.create_node("Enter", min_name + "_enter", [min_name]) + QuantizeGraphHelper.set_attr_string(min_enter_node, "frame_name", enter_node.attr["frame_name"].s) + QuantizeGraphHelper.set_attr_dtype(min_enter_node, "T", dtypes.float32) + QuantizeGraphHelper.set_attr_bool(min_enter_node, "is_constant", True) + QuantizeGraphHelper.set_attr_int( + min_enter_node, "parallel_iterations", enter_node.attr["parallel_iterations"].i + ) + + max_enter_node = QuantizeGraphHelper.create_node("Enter", max_name + "_enter", [max_name]) + QuantizeGraphHelper.set_attr_string(max_enter_node, "frame_name", enter_node.attr["frame_name"].s) + QuantizeGraphHelper.set_attr_dtype(max_enter_node, "T", dtypes.float32) + QuantizeGraphHelper.set_attr_bool(max_enter_node, "is_constant", True) + QuantizeGraphHelper.set_attr_int( + max_enter_node, "parallel_iterations", enter_node.attr["parallel_iterations"].i + ) + + return qint8_const_node, min_node, max_node, qint8_const_enter_node, min_enter_node, max_enter_node diff --git a/neural_compressor/tensorflow/quantization/utils/utility.py b/neural_compressor/tensorflow/quantization/utils/utility.py new file mode 100644 index 00000000000..ac6ceb54324 --- /dev/null +++ b/neural_compressor/tensorflow/quantization/utils/utility.py @@ -0,0 +1,755 @@ +# +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+"""Tensorflow Utils Helper functions."""
+
+import os
+from collections import OrderedDict, UserDict
+
+import numpy as np
+import tensorflow as tf
+from google.protobuf import text_format
+from tensorflow.core.framework import attr_value_pb2, graph_pb2, node_def_pb2, variable_pb2
+from tensorflow.core.protobuf import config_pb2, meta_graph_pb2
+from tensorflow.python.eager import context, wrap_function
+from tensorflow.python.framework import convert_to_constants
+from tensorflow.python.grappler import tf_optimizer
+from tensorflow.python.platform import gfile
+from tensorflow.python.saved_model import load, save, signature_constants, tag_constants
+from tensorflow.python.training import saver
+from tensorflow.python.util import nest
+
+from neural_compressor.common import logger
+from neural_compressor.tensorflow.quantization.utils.graph_util import GraphAnalyzer, GraphRewriterHelper
+
+
+def disable_random(seed=1):
+    """A decorator that disables eager execution, resets the default graph and fixes the TF random seed."""
+
+    def decorator(func):
+        def wrapper(*args, **kw):
+            tf.compat.v1.disable_eager_execution()
+            tf.compat.v1.reset_default_graph()
+            tf.compat.v1.set_random_seed(seed)
+            return func(*args, **kw)
+
+        return wrapper
+
+    return decorator
+
+
+def read_graph(in_graph, in_graph_is_binary=True):
+    """Reads input graph file as GraphDef.
+
+    :param in_graph: input graph file.
+    :param in_graph_is_binary: whether input graph is binary, default True.
+    :return: input graphDef.
+    """
+    assert gfile.Exists(in_graph), "Input graph pb file %s does not exist." % in_graph
+
+    input_graph_def = graph_pb2.GraphDef()
+    mode = "rb" if in_graph_is_binary else "r"
+    with gfile.Open(in_graph, mode) as f:
+        data = f.read()
+        if in_graph_is_binary:
+            input_graph_def.ParseFromString(data)
+        else:
+            text_format.Merge(data, input_graph_def)
+
+    return input_graph_def
+
+
+def write_graph(out_graph_def, out_graph_file):
+    """Write output graphDef to file.
+
+    :param out_graph_def: output graphDef.
+    :param out_graph_file: path to output graph file.
+    :return: None.
+    """
+    assert isinstance(out_graph_def, tf.compat.v1.GraphDef), "out_graph_def is not instance of TensorFlow GraphDef."
+
+    assert out_graph_file and os.path.exists(
+        os.path.dirname(out_graph_file)
+    ), '"output_graph" directory does not exist.'
+
+    f = gfile.GFile(out_graph_file, "wb")
+    f.write(out_graph_def.SerializeToString())
+
+
+def is_ckpt_format(model_path):
+    """Check whether the model_path format is ckpt or not.
+
+    Args:
+        model_path (string): the model folder path
+
+    Returns:
+        bool: True if the model_path contains ckpt format data, else False.
+    """
+    file_list = [os.path.splitext(i)[-1] for i in os.listdir(model_path)]
+    if file_list.count(".meta") == 1 and file_list.count(".index") == 1:
+        return True
+    return False
+
+
+def _parse_ckpt_bn_input(graph_def):
+    """Parse ckpt batch norm inputs to match correct moving mean and variance.
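+
+    Example (an illustrative sketch; the pb path is a placeholder):
+        graph_def = read_graph("./ckpt_frozen.pb")
+        # Rewire each FusedBatchNorm node's inputs 3/4 to the real
+        # moving_mean/moving_variance nodes when they resolve to Const stubs.
+        graph_def = _parse_ckpt_bn_input(graph_def)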
+
+    Args:
+        graph_def (graph_def): original graph_def
+
+    Returns:
+        graph_def: well linked graph_def
+    """
+    for node in graph_def.node:
+        if node.op == "FusedBatchNorm":
+            moving_mean_op_name = node.input[3]
+            moving_var_op_name = node.input[4]
+            moving_mean_op = _get_nodes_from_name(moving_mean_op_name, graph_def)[0]
+            moving_var_op = _get_nodes_from_name(moving_var_op_name, graph_def)[0]
+
+            if moving_mean_op.op == "Const":
+                name_part = moving_mean_op_name.rsplit("/", 1)[0]
+                real_moving_mean_op_name = name_part + "/moving_mean"
+                if len(_get_nodes_from_name(real_moving_mean_op_name, graph_def)) > 0:
+                    # replace with the real moving mean op name
+                    node.input[3] = real_moving_mean_op_name
+
+            if moving_var_op.op == "Const":
+                name_part = moving_var_op_name.rsplit("/", 1)[0]
+                real_moving_var_op_name = name_part + "/moving_variance"
+                if len(_get_nodes_from_name(real_moving_var_op_name, graph_def)) > 0:
+                    # replace with the real moving variance op name
+                    node.input[4] = real_moving_var_op_name
+
+    return graph_def
+
+
+def _get_nodes_from_name(node_name, graph_def):
+    """Get nodes from graph_def using node name.
+
+    Args:
+        graph_def (graph_def): graph_def
+        node_name (str): node name
+
+    Returns:
+        node (NodeDef): graph node
+    """
+    return [node for node in graph_def.node if node.name == node_name]
+
+
+def is_saved_model_format(model_path):
+    """Check whether the model_path format is saved_model or not.
+
+    Args:
+        model_path (string): the model folder path
+
+    Returns:
+        bool: True if the model_path contains saved_model format, else False.
+    """
+    file_list = [os.path.splitext(i)[-1] for i in os.listdir(model_path)]
+    # TF 2.11.0 added a new fingerprint.pb to the SavedModel directory.
+    return bool(file_list.count(".pb") in [1, 2, 3] and ("variables") in os.listdir(model_path))
+
+
+def get_tensor_by_name(graph, name, try_cnt=3):
+    """Get the tensor by name.
+
+    Considering the 'import' scope when a model may be imported more than once,
+    handle naming formats like both name:0 and name.
+
+    Args:
+        graph (tf.compat.v1.Graph): the graph to get the tensor from
+        name (string): tensor name, either tensor_name:0 or tensor_name without the port suffix
+        try_cnt: the times to add 'import/' to find tensor
+
+    Returns:
+        tensor: tensor got by name.
+    """
+    if name.find(":") == -1:
+        name = name + ":0"
+    for _ in range(try_cnt):
+        try:
+            return graph.get_tensor_by_name(name)
+        except BaseException:
+            name = "import/" + name
+    raise ValueError("cannot find tensor by name")
+
+
+def iterator_sess_run(sess, iter_op, feed_dict, output_tensor, iteration=-1, measurer=None):
+    """Run a graph that has an iterator integrated in it.
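+
+    Example (an illustrative sketch; the session, op and tensor names are placeholders):
+        iter_op = sess.graph.get_operation_by_name("MakeIterator")
+        output_tensor = [sess.graph.get_tensor_by_name("Identity:0")]
+        # Run two batches pulled from the dataset iterator.
+        preds = iterator_sess_run(sess, iter_op, feed_dict, output_tensor, iteration=2)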
+ + Args: + sess (tf.compat.v1.Session): the model sess to run the graph + iter_op (Operator): the MakeIterator op + feed_dict(dict): the feeds to initialize a new iterator + output_tensor(list): the output tensors + iteration(int): iterations to run, when -1 set, run to end of iterator + + Returns: + preds: the results of the predictions + """ + sess.run(iter_op, feed_dict) + preds = [] + idx = 0 + while idx + 1 != iteration: + try: + if measurer: + measurer.start() + prediction = sess.run(output_tensor) + measurer.end() + else: + prediction = sess.run(output_tensor) + preds.append(prediction) + idx += 1 + except tf.errors.OutOfRangeError: + break + + preds = collate_tf_preds(preds) + return preds + + +def collate_tf_preds(results): + """Collate the prediction results.""" + batch = results[0] + if isinstance(batch, list): + results = zip(*results) + collate_results = [] + for output in results: + if isinstance(output[0], np.ndarray): + collate_results.append(np.concatenate(output)) + elif np.isscalar(output[0]): + collate_results.extend(output) + elif isinstance(batch, np.ndarray): + collate_results = np.concatenate(results) + + return collate_results + + +def get_input_output_node_names(graph_def): + """Get the input node name and output node name of the graph_def.""" + g = GraphAnalyzer() + g.graph = graph_def + g.parse_graph() + return g.get_graph_input_output() + + +def fix_ref_type_of_graph_def(graph_def): + """Fix ref type of the graph_def.""" + # according to https://github.com/onnx/tensorflow-onnx/issues/77 + for node in graph_def.node: + if node.op == "RefSwitch": + node.op = "Switch" + for index in range(len(node.input)): + if "moving_" in node.input[index]: + node.input[index] = node.input[index] + "/read" + elif node.op == "AssignSub": + node.op = "Sub" + if "use_locking" in node.attr: + del node.attr["use_locking"] + elif node.op == "AssignAdd": + node.op = "Add" + if "use_locking" in node.attr: + del node.attr["use_locking"] + elif node.op == "Assign": + node.op = "Identity" + if "use_locking" in node.attr: + del node.attr["use_locking"] + if "validate_shape" in node.attr: + del node.attr["validate_shape"] + if len(node.input) == 2: + # input0: ref: Should be from a Variable node. May be uninitialized. + # input1: value: The value to be assigned to the variable. + node.input[0] = node.input[1] + del node.input[1] + return graph_def + + +def strip_unused_nodes(graph_def, input_node_names, output_node_names): + """Strip unused nodes of the graph_def. 
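+
+    Example (an illustrative sketch; node names are placeholders):
+        pruned_graph_def = strip_unused_nodes(graph_def, ["input"], ["logits"])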
+
+    The strip_unused_nodes pass is from tensorflow/python/tools/strip_unused_lib.py
+    of the official tensorflow r1.15 branch.
+    """
+    cur_graph = GraphAnalyzer()
+    cur_graph.graph = graph_def
+    graph_info = cur_graph.parse_graph()
+    type_attr = {"Sub": "T", "RealDiv": "T", "Identity": "T"}
+    # this op should not be stripped for table initialization
+    if "init_all_tables" in graph_info.keys():
+        output_node_names.append("init_all_tables")
+    not_found = {name for name in input_node_names}
+    for node_name in list(graph_info.keys()):
+        if node_name in not_found:
+            not_found.remove(node_name)
+            node = graph_info[node_name].node
+            # skip converting nodes that carry a type list to Placeholder
+            if "component_types" in node.attr:
+                continue
+            original_output = graph_info[node_name].outputs
+            placeholder_node = node_def_pb2.NodeDef()
+            placeholder_node.op = "Placeholder"
+            placeholder_node.name = node.name
+
+            if "dtype" in node.attr:
+                placeholder_node.attr["dtype"].CopyFrom(attr_value_pb2.AttrValue(type=node.attr["dtype"].type))
+            elif node.op in type_attr.keys():
+                placeholder_node.attr["dtype"].CopyFrom(
+                    attr_value_pb2.AttrValue(type=node.attr[type_attr[node.op]].type)
+                )
+            else:
+                raise KeyError("%s op's type attribute is not found, you should add it to the type_attr dict" % node.op)
+            if "_output_shapes" in node.attr:
+                placeholder_node.attr["_output_shapes"].CopyFrom(node.attr["_output_shapes"])
+            if "shape" in node.attr:
+                placeholder_node.attr["shape"].CopyFrom(node.attr["shape"])
+
+            cur_graph.remove_node(node_name)
+
+            cur_graph.replace_const_node(placeholder_node, [node_name], original_output)
+
+    return tf.compat.v1.graph_util.extract_sub_graph(cur_graph.dump_graph(), output_node_names)
+
+
+def strip_equivalent_nodes(graph_def, output_node_names):
+    """Strip nodes with the same input and attr."""
+    stripped_graph = GraphAnalyzer()
+    stripped_graph.graph = graph_def
+    stripped_graph_info = stripped_graph.parse_graph()
+
+    def is_equivalent_input(input_tensor_list_1, input_tensor_list_2):
+        if len(input_tensor_list_1) != len(input_tensor_list_2):
+            return False
+        const_num = 0
+        for input_tensor_1, input_tensor_2 in zip(input_tensor_list_1, input_tensor_list_2):
+            input_node_1 = stripped_graph_info[GraphRewriterHelper.node_name_from_input(input_tensor_1)].node
+            input_node_2 = stripped_graph_info[GraphRewriterHelper.node_name_from_input(input_tensor_2)].node
+            if input_node_1.op in ["Const", "HostConst"] and input_node_2.op in ["Const", "HostConst"]:
+                if input_node_1.attr != input_node_2.attr:
+                    return False
+                const_num += 1
+            elif input_tensor_1 != input_tensor_2:
+                return False
+        if const_num == len(input_tensor_list_1):
+            return False
+        return True
+
+    nodes_to_remove = []
+    replaced_nodes_type = {}
+    stripped_graph_node_names = list(stripped_graph_info.keys())
+    len_nodes = len(stripped_graph_node_names)
+    for idx_1 in range(len_nodes - 1):
+        node_name_1 = stripped_graph_node_names[idx_1]
+        node_1 = stripped_graph_info[node_name_1].node
+        if node_1.op in ["Const", "HostConst", "MatMul", "TensorArrayV3"] or node_name_1 in nodes_to_remove:
+            continue
+        for idx_2 in range(idx_1 + 1, len_nodes):
+            node_name_2 = stripped_graph_node_names[idx_2]
+            node_2 = stripped_graph_info[node_name_2].node
+            if (
+                node_1.op == node_2.op
+                and node_name_1 != node_name_2
+                and node_name_2 not in nodes_to_remove
+                and node_1.input
+                and is_equivalent_input(node_1.input, node_2.input)
+                and node_1.attr == node_2.attr
+            ):
+                for output_node_name in stripped_graph_info[node_name_2].outputs:
+                    output_node = stripped_graph_info[output_node_name].node
+                    for idx_output_node_input, output_node_input_name in enumerate(output_node.input):
+                        if GraphRewriterHelper.node_name_from_input(output_node_input_name) == node_name_2:
+                            new_input = output_node_input_name.replace(node_name_2, node_name_1)
+                            output_node.input[idx_output_node_input] = new_input
+                            logger.debug(
+                                "Replacing {} node '{}' with equivalent node '{}': "
+                                "set {} node '{}'.input[{}] = '{}'".format(
+                                    node_1.op,
+                                    node_name_2,
+                                    node_name_1,
+                                    output_node.op,
+                                    output_node.name,
+                                    idx_output_node_input,
+                                    new_input,
+                                )
+                            )
+                replaced_nodes_type[node_1.op] = replaced_nodes_type.get(node_1.op, 0) + 1
+                nodes_to_remove.append(node_name_2)
+    for node_to_remove in nodes_to_remove:
+        stripped_graph.remove_node(node_to_remove)
+    return (
+        tf.compat.v1.graph_util.extract_sub_graph(
+            stripped_graph.dump_graph(), list(set(stripped_graph_node_names).intersection(output_node_names))
+        ),
+        replaced_nodes_type,
+    )
+
+
+# THIS API IS TO BE DEPRECATED!
+def get_graph_def(model, outputs=[], auto_input_output=False):
+    """Get the model's graph_def."""
+    from neural_compressor.tensorflow.utils import BaseModel, Model
+
+    if not isinstance(model, BaseModel):
+        model = Model(model)
+        model.output_tensor_names = outputs
+    return model.graph_def
+
+
+def get_model_input_shape(model):
+    """Get the input shape of the input model."""
+    for node in model.graph_def.node:
+        if node.op == "Placeholder":
+            _shape = list(tf.compat.v1.TensorShape(node.attr["shape"].shape))
+            if tf.__version__ < "2.0.0":
+                _shape = [item.value for item in _shape]
+            if len(_shape) > 1 and isinstance(_shape[0], int):
+                return _shape[0]
+    return 1
+
+
+def get_tensor_val_from_graph_node(graph_node_name_mapping, node_name):
+    """Get the tensor value for given node name.
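+
+    Example (an illustrative sketch; the node name is a placeholder):
+        mapping = {node.name: node for node in graph_def.node}
+        weight = get_tensor_val_from_graph_node(mapping, "dense/kernel")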
+ + Args: + graph_node_name_mapping: key: node name, val: node + node_name: query node + + Returns: + tensor_val: numpy array + """ + from tensorflow.python.framework import tensor_util + + node = graph_node_name_mapping[node_name] + node_tensor = node.attr["value"].tensor + tensor_val = tensor_util.MakeNdarray(node_tensor) + return tensor_val + + +def int8_node_name_reverse(node): + """Reverse int8 node name.""" + int8_postfix = "_eightbit" + node_name = node.name + if "Quantized" in node.op: + index_postfix = node_name.find(int8_postfix) + if index_postfix != -1: + node_name = node_name[:index_postfix] + return node_name + + +def tf_diagnosis_helper(fp32_model, quan_model, tune_cfg, save_path): + """Tensorflow diagnosis helper function.""" + from ...utils.utility import dump_data_to_local + + fp32_node_mapping = {} + qnode_mapping = {} + for node in fp32_model.graph_def.node: + fp32_node_mapping[node.name] = node + for node in quan_model.graph_def.node: + qnode_mapping[node.name] = node + supported_op_lst = set(["Conv2D", "MatMul", "ConcatV2", "MaxPool", "AvgPool", "DepthwiseConv2dNative"]) + fp32_node_lst = set() + for node in fp32_model.graph_def.node: + if node.op in supported_op_lst: + fp32_node_lst.add(node.name) + int8_node_lst = set() + bf16_node_lst = set() + for node in quan_model.graph_def.node: + node_name = node.name + node_name = int8_node_name_reverse(node) + if "Quantized" in node.op: + int8_node_lst.add(node_name) + elif node.attr["value"].tensor.dtype == tf.dtypes.bfloat16.as_datatype_enum: # pragma: no cover + bf16_node_lst.add(node.name) + else: + continue + inspect_node_lst = fp32_node_lst.intersection(bf16_node_lst.union(int8_node_lst)) + activation_min_max, updated_cfg = _parse_config(quan_model.q_config, tune_cfg, inspect_node_lst) + dump_data_to_local(activation_min_max, save_path, "activation_min_max.pkl") + dump_data_to_local(updated_cfg, save_path, "cfg.pkl") + + return inspect_node_lst, updated_cfg + + +def _parse_config(q_config, cfg, op_list): + """Parse q_config and get dequantize min max value.""" + activation_min_max = {} + if "__requant_min_max" in q_config: + for node_name, val in q_config["__requant_min_max"].items(): + node_name = node_name.split("_eightbit_requant_range")[0] + if node_name in op_list: + activation_min_max[node_name] = {"min": val[0], "max": val[1]} + updated_cfg = {"op": {}} + for op_name_and_type in cfg["op"].keys(): + if op_name_and_type[0] in op_list: + updated_cfg["op"][op_name_and_type] = cfg["op"][op_name_and_type] + return activation_min_max, updated_cfg + + +def generate_feed_dict(input_tensor, inputs): + """Generate feed dict helper function.""" + if len(input_tensor) == 1: + feed_dict = {} + if isinstance(inputs, dict) or isinstance(inputs, OrderedDict) or isinstance(inputs, UserDict): + for name in inputs: + for tensor in input_tensor: + pos = tensor.name.rfind(":") + t_name = tensor.name if pos < 0 else tensor.name[:pos] + if name == t_name: + feed_dict[tensor] = inputs[name] + break + else: + feed_dict = {input_tensor[0]: inputs} # get raw tensor using index [0] + else: + assert len(input_tensor) == len(inputs), "inputs len must equal with input_tensor" + feed_dict = {} + if isinstance(inputs, dict) or isinstance(inputs, OrderedDict) or isinstance(inputs, UserDict): + for name in inputs: + for tensor in input_tensor: + pos = tensor.name.rfind(":") + t_name = tensor.name if pos < 0 else tensor.name[:pos] + if name in [tensor.name, t_name]: + feed_dict[tensor] = inputs[name] + break + else: + # sometimes the input_tensor 
is not the same order with inputs + # we should check and pair them + def check_shape(tensor, data): + # scalar or 1 dim default True + if tensor.shape is None or len(tensor.shape.dims) == 1 or not hasattr(data, "shape"): + return True + tensor_shape = tuple(tensor.shape) + data_shape = tuple(data.shape) + for tensor_dim, data_dim in zip(tensor_shape, data_shape): + if tensor_dim is not None and tensor_dim != data_dim: + return False + return True + + disorder_tensors = [] + disorder_inputs = [] + for idx, sort_tensor in enumerate(input_tensor): + sort_input = inputs[idx] + if check_shape(sort_tensor, sort_input): + feed_dict.update({sort_tensor: sort_input}) + else: + disorder_tensors.append(sort_tensor) + disorder_inputs.append(sort_input) + for i, dis_tensor in enumerate(disorder_tensors): + for j, dis_input in enumerate(disorder_inputs): + if check_shape(dis_tensor, dis_input): + feed_dict.update({dis_tensor: dis_input}) + break + return feed_dict + + +def get_weight_from_input_tensor(model, input_tensor_names, op_types): + """Extracts weight tensors and their associated nodes from a smooth quant node's input tensor. + + Args: + model: A TensorFlow model containing a `graph_def` attribute. + input_tensor_names: A list of input tensor names to search for weight tensors. + op_types: A list of operation types to search for when looking for weight tensors. + + Returns: + A tuple of two dictionaries: + - sq_weight_tensors: A dictionary mapping each input tensor name + to a dict of its associated weight tensors with weight name. + - sq_weights_nodes: A dictionary mapping each input tensor name + to a dict of its associated weight nodes with weight name. + """ + g_analyzer = GraphAnalyzer() + g_analyzer.graph = model.graph_def + graph_info = g_analyzer.parse_graph() + + sq_weight_tensors = {} + sq_weights_nodes = {} + + from tensorflow.python.framework import tensor_util + + for name in input_tensor_names: + # Use dict rather than list to fix the QKV/VQK misorder issue + curr_weight_tensors = {} + curr_weights_nodes = {} + next_node_names = graph_info[name].outputs + for node_name in next_node_names: + curr_node = graph_info[node_name].node + if curr_node.op not in op_types: + continue + if len(curr_node.input) >= 2: + weight_name = curr_node.input[1] + weight_node = graph_info[weight_name].node + weight_tensor = tensor_util.MakeNdarray(weight_node.attr["value"].tensor) + curr_weight_tensors[weight_name] = weight_tensor + curr_weights_nodes[weight_name] = weight_node + # {input node -> {xxx_q_proj_matmul: value1, xxx_v_proj_matmul: value2, ...}, ...} + sq_weight_tensors[name] = curr_weight_tensors + sq_weights_nodes[name] = curr_weights_nodes + return sq_weight_tensors, sq_weights_nodes + + +def apply_inlining(func): + """Apply an inlining optimization to the function's graph definition. + + Args: + func: A concrete function get from saved_model. + + Returns: + new_graph_def: The optimized graph in graph_def format. + """ + graph_def = func.graph.as_graph_def() + + # In some cases, a secondary implementation of the function (e.g. for GPU) is + # written to the "api_implements" attribute. (e.g. `tf.keras.layers.LSTM` in + # TF2 produces a CuDNN-based RNN for GPU). + # This function suppose to inline all functions calls, but "api_implements" + # prevents this from happening. Removing the attribute solves the problem. 
+ # To learn more about "api_implements", see: + # tensorflow/core/grappler/optimizers/implementation_selector.h + for function in graph_def.library.function: + if "api_implements" in function.attr: + del function.attr["api_implements"] + + meta_graph = saver.export_meta_graph(graph_def=graph_def, graph=func.graph) + + # Clear the initializer_name for the variables collections, since they are not + # needed after saved to saved_model. + for name in ["variables", "model_variables", "trainable_variables", "local_variables"]: + raw_list = [] + for raw in meta_graph.collection_def["variables"].bytes_list.value: + variable = variable_pb2.VariableDef() + variable.ParseFromString(raw) + variable.ClearField("initializer_name") + raw_list.append(variable.SerializeToString()) + meta_graph.collection_def[name].bytes_list.value[:] = raw_list + + # Add a collection 'train_op' so that Grappler knows the outputs. + fetch_collection = meta_graph_pb2.CollectionDef() + for array in func.inputs + func.outputs: + fetch_collection.node_list.value.append(array.name) + meta_graph.collection_def["train_op"].CopyFrom(fetch_collection) + + # Initialize RewriterConfig with everything disabled except function inlining. + config = config_pb2.ConfigProto() + rewrite_options = config.graph_options.rewrite_options + rewrite_options.min_graph_nodes = -1 # do not skip small graphs + rewrite_options.optimizers.append("function") + + new_graph_def = tf_optimizer.OptimizeGraph(config, meta_graph) + + return new_graph_def + + +def construct_function_from_graph_def(func, graph_def, frozen_func=None): + """Rebuild function from graph_def. + + Args: + func: The original concrete function get from saved_model. + graph_def: The optimized graph after applying inlining optimization. + + Returns: + new_func: The reconstructed function. + """ + if frozen_func is None: + frozen_func = func + + # If a function is converted, then the TF context contains the original + # function while the converted_graph_def contains the converted function. + # Remove the original function from the TF context in this case. + for f in graph_def.library.function: + while context.context().has_function(f.signature.name): + context.context().remove_function(f.signature.name) + + captures = {c[1].name.split(":")[0]: c[0] for c in frozen_func.graph.captures} + new_func = wrap_function.function_from_graph_def( + graph_def, + [tensor.name for tensor in frozen_func.inputs], + [tensor.name for tensor in frozen_func.outputs], + captures, + ) + new_func.graph.structured_outputs = nest.pack_sequence_as( + func.graph.structured_outputs, new_func.graph.structured_outputs + ) + # new_func._function_type = func.function_type # pylint: disable=protected-access + + # Copy structured input signature from original function (used during + # serialization) + new_func.graph.structured_input_signature = func.structured_input_signature + + return new_func + + +def parse_saved_model(model, freeze=False, input_tensor_names=[], output_tensor_names=[]): + """Parse a input saved_model. + + Args: + model(string or AutoTrackable object): The input saved_model. + + Returns: + graph_def: The graph_def parsed from saved_model. + _saved_model: TF AutoTrackable object loaded from saved_model. + func: The concrete function get from saved_model. + frozen_func: The reconstructed function from inlining optimized graph. 
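+
+    Example (an illustrative sketch; the SavedModel path is a placeholder):
+        graph_def, trackable, func, frozen_func, in_names, out_names = parse_saved_model("./my_saved_model")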
+ """ + config = tf.compat.v1.ConfigProto() + config.use_per_session_threads = 1 + config.inter_op_parallelism_threads = 1 + + if isinstance(model, str): + _saved_model = load.load(model, [tag_constants.SERVING]) + else: + _saved_model = model + + func = _saved_model.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + + if freeze: + frozen_func = convert_to_constants.convert_variables_to_constants_v2(func) + else: + inlined_graph_def = apply_inlining(func) + frozen_func = construct_function_from_graph_def(func, inlined_graph_def) + + if len(input_tensor_names) == 0: + # skip all inputs for ReadVariableOp + input_tensor_names = [i.name.split(":")[0] for i in frozen_func.inputs if "unknown" not in i.name] + if len(output_tensor_names) == 0: + output_tensor_names = [i.name.split(":")[0] for i in frozen_func.outputs] + + frozen_graph_def = frozen_func.graph.as_graph_def() + grappler_meta_graph_def = saver.export_meta_graph(graph_def=frozen_graph_def, graph=frozen_func.graph) + + # Add a collection 'train_op' so that Grappler knows the outputs. + fetch_collection = meta_graph_pb2.CollectionDef() + for array in frozen_func.inputs + frozen_func.outputs: + fetch_collection.node_list.value.append(array.name) + grappler_meta_graph_def.collection_def["train_op"].CopyFrom(fetch_collection) + + grappler_session_config = config_pb2.ConfigProto() + rewrite_options = grappler_session_config.graph_options.rewrite_options + rewrite_options.min_graph_nodes = -1 + graph_def = tf_optimizer.OptimizeGraph(grappler_session_config, grappler_meta_graph_def, graph_id=b"tf_graph") + return graph_def, _saved_model, func, frozen_func, input_tensor_names, output_tensor_names + + +def reconstruct_saved_model(graph_def, func, frozen_func, trackable, path): + """Reconstruct a saved_model. + + Args: + graph_def: The input graph_def. + func: The concrete function get from the original saved_model. + frozen_func: The reconstructed function from inlining optimized graph. + trackable: TF AutoTrackable object loaded from the original saved_model. + path: The destination path to save the reconstructed saved_model. + """ + converted_func = construct_function_from_graph_def(func, graph_def, frozen_func) + signatures = {signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: converted_func} + save.save(trackable, path, signatures, options=None) diff --git a/neural_compressor/tensorflow/utils.py b/neural_compressor/tensorflow/utils.py deleted file mode 100644 index 4497c1e9a7a..00000000000 --- a/neural_compressor/tensorflow/utils.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import logging -import time -from functools import reduce -from typing import Callable, Dict - -import numpy as np -import tensorflow as tf -from pkg_resources import parse_version - -# Dictionary to store a mapping between algorithm names and corresponding algo implementation(function) -algos_mapping: Dict[str, Callable] = {} - - -def version1_gte_version2(version1, version2): - """Check if version1 is greater than or equal to version2.""" - return parse_version(version1) > parse_version(version2) or parse_version(version1) == parse_version(version2) - - -def register_algo(name): - """Decorator function to register algorithms in the algos_mapping dictionary. - - Usage example: - @register_algo(name=example_algo) - def example_algo(model: torch.nn.Module, quant_config: RTNConfig) -> torch.nn.Module: - ... - Args: - name (str): The name under which the algorithm function will be registered. - Returns: - decorator: The decorator function to be used with algorithm functions. - """ - - def decorator(algo_func): - algos_mapping[name] = algo_func - return algo_func - - return decorator - - -def deep_get(dictionary, keys, default=None): - """Get the dot key's item in nested dict - eg person = {'person':{'name':{'first':'John'}}} - deep_get(person, "person.name.first") will output 'John'. - - Args: - dictionary (dict): The dict object to get keys - keys (dict): The deep keys - default (object): The return item if key not exists - Returns: - item: the item of the deep dot keys - """ - return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split("."), dictionary) - - -def dump_elapsed_time(customized_msg=""): - """Get the elapsed time for decorated functions. - - Args: - customized_msg (string, optional): The parameter passed to decorator. Defaults to None. - """ - - def f(func): - def fi(*args, **kwargs): - start = time.time() - res = func(*args, **kwargs) - end = time.time() - logging.getLogger("neural_compressor").info( - "%s elapsed time: %s ms" - % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2)) - ) - return res - - return fi - - return f diff --git a/neural_compressor/tensorflow/utils/__init__.py b/neural_compressor/tensorflow/utils/__init__.py new file mode 100644 index 00000000000..d77ad26fc47 --- /dev/null +++ b/neural_compressor/tensorflow/utils/__init__.py @@ -0,0 +1,50 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from neural_compressor.tensorflow.utils.model import Model, framework_specific_info +from neural_compressor.tensorflow.utils.data import BaseDataLoader, DummyDataset, DummyDatasetV2 +from neural_compressor.tensorflow.utils.constants import SPR_BASE_VERSIONS, DEFAULT_SQ_ALPHA_ARGS +from neural_compressor.tensorflow.utils.model_wrappers import ( + get_tf_model_type, + BaseModel, + KerasModel, + TensorflowLLMModel, + TensorflowBaseModel, + TensorflowSavedModelModel, +) +from neural_compressor.tensorflow.utils.utility import ( + disable_random, + algos_mapping, + version1_lt_version2, + version1_gt_version2, + version1_eq_version2, + version1_gte_version2, + version1_lte_version2, + register_algo, + deep_get, + itex_installed, + dump_elapsed_time, + combine_histogram, + get_all_fp32_data, + get_tensor_histogram, + Dequantize, + dequantize_weight, + dump_data_to_local, + load_data_from_pkl, + singleton, + CpuInfo, + Statistics, + CaptureOutputToFile, + LazyImport, +) diff --git a/neural_compressor/tensorflow/utils/constants.py b/neural_compressor/tensorflow/utils/constants.py new file mode 100644 index 00000000000..70a4188c7f2 --- /dev/null +++ b/neural_compressor/tensorflow/utils/constants.py @@ -0,0 +1,31 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +SPR_BASE_VERSIONS = ( + "2.11.0202242", + "2.11.0202250", + "2.11.0202317", + "2.11.0202323", + "2.14.0202335", + "2.14.dev202335", + "2.15.0202341", +) + +DEFAULT_SQ_ALPHA_ARGS = { + "alpha_min": 0.0, + "alpha_max": 1.0, + "alpha_step": 0.1, + "shared_criterion": "mean", + "do_blockwise": False, +} diff --git a/neural_compressor/tensorflow/utils/data.py b/neural_compressor/tensorflow/utils/data.py new file mode 100644 index 00000000000..8e0f7dc8cc0 --- /dev/null +++ b/neural_compressor/tensorflow/utils/data.py @@ -0,0 +1,388 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""BaseDataloder of all dataloaders.""" + +import sys +from abc import abstractmethod + +import numpy as np + +from neural_compressor.common import logger + + +class BaseDataLoader: # pragma: no cover + """Base class for all DataLoaders. + + _generate_dataloader is needed to create a dataloader object + from the general params like batch_size and sampler. The dynamic batching is just to + generate a new dataloader by setting batch_size and last_batch. 
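+
+    Example (a minimal sketch of a concrete subclass; `FixedBatchIterator` is a
+    hypothetical helper, not part of this module):
+        class MyDataLoader(BaseDataLoader):
+            def _generate_dataloader(self, dataset, batch_size, last_batch, collate_fn, sampler,
+                                     batch_sampler, num_workers, pin_memory, shuffle, distributed):
+                return FixedBatchIterator(dataset, batch_size)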
+ """ + + def __init__( + self, + dataset, + batch_size=1, + last_batch="rollover", + collate_fn=None, + sampler=None, + batch_sampler=None, + num_workers=0, + pin_memory=False, + shuffle=False, + distributed=False, + ): + """Initialize BaseDataLoader. + + Args: + dataset (object): dataset from which to load the data + batch_size (int, optional): number of samples per batch. Defaults to 1. + last_batch (str, optional): whether to drop the last batch if it is incomplete. + Support ['rollover', 'discard'], rollover means False, discard means True. + Defaults to 'rollover'. + collate_fn (callable, optional): merge data with outer dimension batch size. Defaults to None. + sampler (Sampler, optional): Sampler object to sample data. Defaults to None. + batch_sampler (BatchSampler, optional): BatchSampler object to generate batch of indices. Defaults to None. + num_workers (int, optional): number of subprocesses to use for data loading. Defaults to 0. + pin_memory (bool, optional): whether to copy data into pinned memory before returning. Defaults to False. + shuffle (bool, optional): whether to shuffle data. Defaults to False. + distributed (bool, optional): whether the dataloader is distributed. Defaults to False. + """ + self.dataset = dataset + self.collate_fn = collate_fn + self.sampler = sampler + self.batch_sampler = batch_sampler + self.num_workers = num_workers + self.pin_memory = pin_memory + self._batch_size = batch_size + self.shuffle = shuffle + self.distributed = distributed + self.last_batch = last_batch + self.drop_last = False if last_batch == "rollover" else True + + self.dataloader = self._generate_dataloader( + self.dataset, + batch_size=batch_size, + last_batch=last_batch, + collate_fn=collate_fn, + sampler=sampler, + batch_sampler=batch_sampler, + num_workers=num_workers, + pin_memory=pin_memory, + shuffle=shuffle, + distributed=distributed, + ) + + def batch(self, batch_size, last_batch=None): + """Set batch size for dataloader. + + Args: + batch_size (int): number of samples per batch. + last_batch (str, optional): whether to drop the last batch if it is incomplete. + Support ['rollover', 'discard'], rollover means False, discard means True. + Defaults to None. + """ + self._batch_size = batch_size + if last_batch is not None: + self.last_batch = last_batch + self.dataloader = self._generate_dataloader( + self.dataset, + batch_size, + self.last_batch, + self.collate_fn, + self.sampler, + self.batch_sampler, + self.num_workers, + self.pin_memory, + self.shuffle, + self.distributed, + ) + + @property + def batch_size(self): + """Get dataloader's batch_size. + + Returns: + int: batch_size + """ + return self._batch_size + + def __iter__(self): + """Yield data in iterative order. + + Returns: + iterator: iterator for dataloder + """ + return iter(self.dataloader) + + @abstractmethod + def _generate_dataloader( + self, + dataset, + batch_size, + last_batch, + collate_fn, + sampler, + batch_sampler, + num_workers, + pin_memory, + shuffle, + distributed, + ): + raise NotImplementedError + + +class DummyDataset: # pragma: no cover + """Dataset used for dummy data generation. + + This Dataset is to construct a dataset from a specific shape. + The value range is calculated from: low * stand_normal(0, 1) + high. + (TODO) construct dummy data from real dataset or iteration of data. + """ + + def __init__(self, shape, low=-128.0, high=127.0, dtype="float32", label=True, transform=None, filter=None): + """Initialize `DummyDataset` class. 
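+
+        Example (an illustrative sketch; the shape is a placeholder):
+            dataset = DummyDataset(shape=(10, 224, 224, 3), low=0.0, high=1.0, dtype="float32")
+            sample, label = dataset[0]  # label defaults to 0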
+ + Args: + shape (list or tuple): Support create multi shape tensors, use list of tuples + for each tuple in the list, will create a such size tensor. + low (list or float, default=-128.): Low out the tensor value range from [0, 1] + to [0, low] or [low, 0] if low < 0, if float, will implement all tensors with same low value. + high (list or float, default=127.): High the tensor value by add all tensor element + value high. If list, length of list should be same with shape list. + dtype (list or str, default='float32'): Support multi tensor dtype setting. + If list, length of list should be same with shape list. If str, all tensors will + use same dtype. dtype supports 'float32', 'float16', 'uint8', 'int8', 'int32', 'int64', 'bool'. + label (bool, default=True): Whether to return 0 as label. + transform (transform object, default=None): Dummy dataset does not need transform. + If transform is not None, it will ignore it. + filter (Filter objects, default=None): Filter out examples according to specific conditions. + """ + dtype_map = { + "float32": np.float32, + "float16": np.float16, + "uint8": np.uint8, + "int8": np.int8, + "int32": np.int32, + "int64": np.int64, + "bool": bool, + "string": str, + } + + np.random.seed(9527) + self.transform = transform + self.label = label + if len(shape) == 0: + logger.info("No data in the dummy dataset.") + elif isinstance(shape, list): + # list tensor should same first dimension n + n = shape[0][0] + assert all( + isinstance(elem, tuple) and elem[0] == n for elem in shape + ), "each tensor shape should be tuple and same first dimension" + + if isinstance(low, list): + assert len(low) == len(shape) and all( + isinstance(elem, float) for elem in low + ), "low list should have same length with shape with element data type float" + else: + low = (low * np.ones(len(shape))).astype(float) + + if isinstance(high, list): + assert len(high) == len(shape) and all( + isinstance(elem, float) for elem in high + ), "high list should have same length with shape with element data type float" + else: + high = (high * np.ones(len(shape))).astype(float) + + if isinstance(dtype, list): + assert len(dtype) == len(shape) and all( + elem in dtype_map.keys() for elem in dtype + ), "high list should have same length with shape with element data type float" + else: + dtype = [dtype for i in range(0, len(shape))] + + elif isinstance(shape, tuple): + shape = [shape] + if isinstance(low, float): + low = [low] + else: + assert ( + isinstance(low, list) and len(low) == 1 and isinstance(low[0], float) + ), "low should be float or list of float with length 1" + + if isinstance(high, float): + high = [high] + else: + assert ( + isinstance(high, list) and len(high) == 1 and isinstance(high[0], float) + ), "high should be float or list of float with length 1" + + if isinstance(dtype, str): + assert dtype in dtype_map.keys(), "dtype only support {}".format(dtype_map.keys()) + dtype = [dtype] + else: + assert ( + isinstance(dtype, list) and len(dtype) == 1 and dtype[0] in dtype_map.keys() + ), "dtype should be str or list of str in supported dtypes" + + self.dataset = [] + for idx in range(0, len(shape)): + tensor = np.random.uniform(low=low[idx], high=high[idx], size=shape[idx]) + tensor = tensor.astype(dtype_map[dtype[idx]]) + self.dataset.append(tensor) + + if len(self.dataset) == 1: + self.dataset = self.dataset[0] + else: + self.dataset = [elem for elem in zip(*self.dataset)] + + def __len__(self): + """Return the length of dataset.""" + return len(self.dataset) + + def 
__getitem__(self, index):
+        """Return the item of dataset according to the given index."""
+        sample = self.dataset[index]
+        if self.transform is not None:
+            logger.warning("Dummy dataset does not need transform.")
+
+        if self.label:
+            return sample, 0
+        else:
+            return sample
+
+
+class DummyDatasetV2:  # pragma: no cover
+    """Dataset used for dummy_v2 data generation.
+
+    This Dataset is to construct a dataset from an input shape and a label shape.
+    The tensor values are drawn from a uniform distribution over [low, high).
+    """
+
+    def __init__(
+        self, input_shape, label_shape=None, low=-128.0, high=127.0, dtype="float32", transform=None, filter=None
+    ):
+        """Initialize `DummyDatasetV2` class.
+
+        Args:
+            input_shape (list or tuple): Create single or multi input tensors.
+                A tuple represents the sample shape of the dataset, e.g. an image size should be
+                represented as (224, 224, 3); a list contains multiple tuples and represents multi input tensors.
+            label_shape (list or tuple): Create single or multi label tensors.
+                A tuple represents the label shape of the dataset, e.g. a label size should be
+                represented as (1,); a list contains multiple tuples and represents multi label tensors.
+            low (list or float, default=-128.): Lower bound of the generated tensor values.
+                If a float, all tensors share the same low value. If a list, its length should be
+                the same as the number of input and label tensors.
+            high (list or float, default=127.): Upper bound of the generated tensor values.
+                If a float, all tensors share the same high value. If a list, its length should be
+                the same as the number of input and label tensors.
+            dtype (list or str, default='float32'): Support multi tensor dtype setting.
+                If list, length of list should be same with shape list.
+                If str, all tensors will use same dtype.
+                dtype supports 'float32', 'float16', 'uint8', 'int8', 'int32', 'int64', 'bool'.
+            transform (transform object, default=None): dummy_v2 dataset does not need transform.
+                If transform is not None, it will be ignored.
+            filter (Filter objects, default=None): Filter out examples according to specific conditions.
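+
+        Example (an illustrative sketch; shapes are placeholders):
+            dataset = DummyDatasetV2(input_shape=(224, 224, 3), label_shape=(1,))
+            data, label = next(iter(dataset))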
+ """ + self.dtype_map = { + "float32": np.float32, + "float16": np.float16, + "uint8": np.uint8, + "int8": np.int8, + "int32": np.int32, + "int64": np.int64, + "bool": bool, + } + + np.random.seed(9527) + self.transform = transform + self.input_shape = input_shape + self.label_shape = label_shape + self.low = low + self.high = high + self.dtype = dtype + + if label_shape is None: + self.label_dim = 0 + elif isinstance(label_shape, tuple): + self.label_dim = 1 + else: + self.label_dim = len(label_shape) + + self.input_dim = 1 if isinstance(input_shape, tuple) else len(input_shape) + self.total_dim = self.input_dim + self.label_dim + + if isinstance(high, list): + assert len(high) == self.total_dim and all( + isinstance(elem, float) for elem in high + ), "high value list length should same with label dim + input_dim" + else: + self.high = (high * np.ones(self.total_dim)).astype(np.float32) + + if isinstance(low, list): + assert len(low) == self.total_dim and all( + isinstance(elem, float) for elem in low + ), "low value list length should same with label dim + input_dim" + else: + self.low = (low * np.ones(self.total_dim)).astype(np.float32) + + if isinstance(dtype, list): + assert len(dtype) == self.total_dim and all( + elem in self.dtype_map.keys() for elem in dtype + ), "dtype list length should same with label dim + input_dim" + else: + self.dtype = [self.dtype for i in range(0, self.total_dim)] + + if isinstance(input_shape, tuple): + self.input_shape = [input_shape] + + if isinstance(label_shape, tuple): + self.label_shape = [label_shape] + + def __iter__(self): + """Yield data in iterative order.""" + while True: + input_data = [] + for idx in range(0, self.input_dim): + tensor = np.random.uniform(low=self.low[idx], high=self.high[idx], size=self.input_shape[idx]) + tensor = tensor.astype(self.dtype_map[self.dtype[idx]]) + input_data.append(tensor) + + label = [] + for idx in range(0, self.label_dim): + shift_idx = self.input_dim + idx + tensor = np.random.uniform( + low=self.low[shift_idx], high=self.high[shift_idx], size=self.label_shape[idx] + ) + tensor = tensor.astype(self.dtype_map[self.dtype[shift_idx]]) + label.append(tensor) + + if len(input_data) == 1: + input_data = input_data[0] + + if len(label) == 1: + label = label[0] + + if len(label) > 0: + yield input_data, label + else: + yield input_data + + def __len__(self): + """Return the length of dataset.""" + return sys.maxsize diff --git a/neural_compressor/tensorflow/utils/model.py b/neural_compressor/tensorflow/utils/model.py new file mode 100644 index 00000000000..b05032b53bc --- /dev/null +++ b/neural_compressor/tensorflow/utils/model.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from neural_compressor.common.utils import DEFAULT_WORKSPACE +from neural_compressor.tensorflow.utils.model_wrappers import BaseModel, KerasModel, TensorflowModel, get_tf_model_type + +framework_specific_info = { + "device": "cpu", + "backend": "default", + "approach": "post_training_static_quant", + "random_seed": 1978, + "workspace_path": DEFAULT_WORKSPACE, + "format": "default", +} + + +class Model(object): + """A wrapper to construct a Neural Compressor TF Model.""" + + def __new__(cls, root, **kwargs): + """Create a new instance object of Model. + + Args: + root (object): raw model format. For Tensorflow model, could be path to frozen pb file, + path to ckpt or savedmodel folder, loaded estimator/graph_def/graph/keras model object. + + Returns: + BaseModel: neural_compressor built-in model + """ + from neural_compressor.tensorflow.utils import itex_installed + + if isinstance(root, BaseModel): + return root + + if kwargs.get("approach", None) == "quant_aware_training": + model_type = "keras_qat" + elif "modelType" in kwargs: + model_type = kwargs["modelType"] + else: + model_type = get_tf_model_type(root) + + if model_type == "keras" and not itex_installed(): + model_type = "saved_model" + + model = TensorflowModel(model_type, root, **kwargs) + conf = kwargs.pop("conf", "NA") + cls.set_framework_info(conf, model) + + return model + + @staticmethod + def set_framework_info(conf, model): + if conf == "NA": + return + framework = "keras" if isinstance(model, KerasModel) else "tensorflow" + + if conf.device: + framework_specific_info["device"] = conf.device + if conf.approach: + framework_specific_info["approach"] = conf.approach + if conf.random_seed: + framework_specific_info["random_seed"] = conf.random_seed + if conf.inputs: + framework_specific_info["inputs"] = conf.inputs + if conf.outputs: + framework_specific_info["outputs"] = conf.outputs + + if framework == "keras": + framework_specific_info["backend"] = "itex" + return + + from neural_compressor.tensorflow.utils import itex_installed + + if conf.performance_only: + framework_specific_info["performance_only"] = conf.performance_only + if itex_installed(): + framework_specific_info["backend"] = "itex" + if conf.workspace_path: + framework_specific_info["workspace_path"] = conf.workspace_path + if conf.recipes: + framework_specific_info["recipes"] = conf.recipes + + framework_specific_info["use_bf16"] = conf.use_bf16 if conf.use_bf16 else False + + for item in ["scale_propagation_max_pooling", "scale_propagation_concat"]: + if framework_specific_info["recipes"] and item not in framework_specific_info["recipes"]: + framework_specific_info["recipes"].update({item: True}) diff --git a/neural_compressor/tensorflow/utils/model_wrappers.py b/neural_compressor/tensorflow/utils/model_wrappers.py new file mode 100644 index 00000000000..8781bd2aa8b --- /dev/null +++ b/neural_compressor/tensorflow/utils/model_wrappers.py @@ -0,0 +1,1589 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Class for Tensorflow model."""
+
+import copy
+import datetime
+import importlib
+import json
+import os
+import shutil
+import sys
+import tempfile
+import time
+from abc import abstractmethod
+
+import numpy as np
+import tensorflow as tf
+
+from neural_compressor.common import logger
+from neural_compressor.common.utils import DEFAULT_WORKSPACE
+from neural_compressor.tensorflow.utils.utility import version1_lt_version2
+
+tensor_to_node = lambda s: list(set([x.split(":")[0] for x in s]))
+
+
+def get_tf_model_type(model):
+    """Detect the type of a TensorFlow model, masking CUDA devices during detection."""
+    try:
+        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
+        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
+        model_type = get_model_type(model)
+    except:
+        os.environ.pop("CUDA_DEVICE_ORDER")
+        os.environ.pop("CUDA_VISIBLE_DEVICES")
+        raise TypeError(
+            "Tensorflow model format is not correctly detected. This could be "
+            "caused by an unsupported model or an inappropriate framework installation."
+        )
+    else:
+        return model_type
+
+
+def get_model_type(model):
+    """Get the TensorFlow model type.
+
+    Args:
+        model (string or model object): model path or model object.
+
+    Returns:
+        string: model type
+    """
+    from neural_compressor.tensorflow.quantization.utils.utility import is_ckpt_format, is_saved_model_format
+
+    if isinstance(model, str):
+        model = os.path.abspath(os.path.expanduser(model))
+        if (
+            (model.endswith(".h5") and os.path.isfile(model))
+            or is_saved_model_format(os.path.dirname(model))
+            or (os.path.isdir(model) and is_saved_model_format(model))
+        ):
+            if version1_lt_version2(tf.version.VERSION, "2.10.0"):  # pragma: no cover
+                logger.warning("Keras models running on TensorFlow 2.10.0 and lower do not support Intel ITEX.")
+            try:
+                model = tf.keras.models.load_model(model)
+                if isinstance(model, tf.keras.Model) and hasattr(model, "to_json"):
+                    return "keras"
+                return "saved_model"
+            except:
+                pass
+    if isinstance(model, tf.keras.Model) and hasattr(model, "to_json"):
+        if json.loads(model.to_json())["class_name"] in ["Sequential", "Functional"]:
+            # Keras adaptor only supports Sequential or Functional models
+            return "keras"
+        else:
+            # otherwise, the backend will fall back to tensorflow_itex
+            return "AutoTrackable"
+    if isinstance(model, tf.Graph):
+        return "graph"
+    elif isinstance(model, tf.compat.v1.GraphDef):
+        return "graph_def"
+    elif isinstance(model, tf.compat.v1.estimator.Estimator):
+        return "estimator"
+    elif isinstance(model, str):
+        model = os.path.abspath(os.path.expanduser(model))
+        if model.endswith(".pb") and os.path.isfile(model):
+            if is_saved_model_format(os.path.dirname(model)):
+                return "saved_model"
+            else:
+                return "frozen_pb"
+        elif model.endswith(".ckpt") and os.path.isfile(model):
+            return "slim"
+        elif os.path.isdir(model):
+            if is_ckpt_format(model):
+                return "checkpoint"
+            elif is_saved_model_format(model):
+                return "saved_model"
+        elif os.path.isfile(model + ".pb"):
+            return "frozen_pb"
+
+    raise ValueError("model {} has an unrecognized model type".format(model))
+
+
+def validate_graph_node(graph_def, node_names):
+    """Validate nodes exist in the graph_def.
+
+    Args:
+        graph_def (tf.compat.v1.GraphDef): tf.compat.v1.GraphDef object.
+        node_names (list of string): node names to be validated.
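+
+    Example (an illustrative sketch; node names are placeholders):
+        all_present = validate_graph_node(graph_def, ["input", "logits"])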
+ """ + if len(node_names) == 0: + return False + all_node_name = [node.name for node in graph_def.node] + for user_name in node_names: + if user_name not in all_node_name: + logger.warning(str("Node name {} specified in yaml doesn't exist in the model.").format(user_name)) + return False + return True + + +def validate_and_inference_input_output(graph_def, input_tensor_names, output_tensor_names): + """Validate and inference the input and output tensor names of graph_def. + + Args: + graph_def (tf.compat.v1.GraphDef): tf.compat.v1.GraphDef object. + input_tensor_names (list of string): input_tensor_names of graph_def. + output_tensor_names (list of string): output_tensor_names of graph_def. + + Returns: + input_tensor_names (list of string): validated input_tensor_names. + output_tensor_names (list of string): validated output_tensor_names. + """ + from neural_compressor.tensorflow.quantization.utils.utility import get_input_output_node_names + + temp_output_tensor_names = [] + if validate_graph_node(graph_def, tensor_to_node(input_tensor_names)): + input_tensor_names = input_tensor_names + else: + input_tensor_names, temp_output_tensor_names = get_input_output_node_names(graph_def) + + if validate_graph_node(graph_def, tensor_to_node(output_tensor_names)): + output_tensor_names = output_tensor_names + elif temp_output_tensor_names: + output_tensor_names = temp_output_tensor_names + else: + _, output_tensor_names = get_input_output_node_names(graph_def) + + return input_tensor_names, output_tensor_names + + +def graph_session(model, input_tensor_names, output_tensor_names, **kwargs): + """Helper to build session with tf.compat.v1.Graph. + + Args: + model (tf.compat.v1.Graph): tf.compat.v1.Graph object. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. + + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object. + input_tensor_names (list of string): validated input_tensor_names. + output_tensor_names (list of string): validated output_tensor_names. + """ + config = tf.compat.v1.ConfigProto() + config.use_per_session_threads = 1 + config.inter_op_parallelism_threads = 1 + sess = tf.compat.v1.Session(graph=model, config=config) + + input_tensor_names, output_tensor_names = validate_and_inference_input_output( + model.as_graph_def(), input_tensor_names, output_tensor_names + ) + + return sess, input_tensor_names, output_tensor_names + + +def graph_def_session(model, input_tensor_names, output_tensor_names, **kwargs): + """Build session with tf.compat.v1.GraphDef. + + Args: + model (tf.compat.v1.GraphDef): tf.compat.v1.GraphDef object. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. 
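+
+    Example (an illustrative sketch; tensor names are placeholders):
+        sess, input_names, output_names = graph_def_session(graph_def, ["input"], ["logits"])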
+ + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object + input_tensor_names (list of string): validated input_tensor_names + output_tensor_names (list of string): validated output_tensor_names + """ + device = kwargs.get("device") + graph = tf.Graph() + if version1_lt_version2(tf.version.VERSION, "2.0.0"): # pragma: no cover + from tensorflow._api.v1.config import experimental + + list_physical_devices = experimental.list_physical_devices + else: + list_physical_devices = tf.config.list_physical_devices + + try: + with graph.as_default(): + tf.import_graph_def(model, name="") + except: + input_tensor_names, output_tensor_names = validate_and_inference_input_output( + model, input_tensor_names, output_tensor_names + ) + from neural_compressor.tensorflow.quantization.utils.utility import ( + fix_ref_type_of_graph_def, + strip_unused_nodes, + ) + + model = fix_ref_type_of_graph_def(model) + input_node_names = tensor_to_node(input_tensor_names) + output_node_names = tensor_to_node(output_tensor_names) + model = strip_unused_nodes(model, input_node_names, output_node_names) + with graph.as_default(): + tf.import_graph_def(model, name="") + + return graph_session(graph, input_tensor_names, output_tensor_names, **kwargs) + + +def frozen_pb_session(model, input_tensor_names, output_tensor_names, **kwargs): + """Build session with frozen pb. + + Args: + model (string): model path. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. + + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object. + input_tensor_names (list of string): validated input_tensor_names. + output_tensor_names (list of string): validated output_tensor_names. + """ + graph_def = tf.compat.v1.GraphDef() + model = model if model.endswith(".pb") else model + ".pb" + with open(model, "rb") as f: + graph_def.ParseFromString(f.read()) + return graph_def_session(graph_def, input_tensor_names, output_tensor_names, **kwargs) + + +def _contains_function_with_implements_attr(saved_model_proto): + meta_graph = saved_model_proto.meta_graphs[0] + for function in meta_graph.graph_def.library.function: + if function.attr.get("_implements", None) or function.attr.get("api_implements", None): + return True + return False + + +def load_saved_model(model, saved_model_tags, input_tensor_names, output_tensor_names): # pragma: no cover + """Load graph_def from saved model with the default serving signature key. + + Args: + model: Directory of the SavedModel. + saved_model_tags: Set of tags identifying the MetaGraphDef within the + SavedModel to analyze. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. + + Returns: + graph_def: The loaded GraphDef. + input_tensors: List of input tensors. + output_tensors: List of output tensors. 
+ """ + config = tf.compat.v1.ConfigProto() + config.use_per_session_threads = 1 + config.inter_op_parallelism_threads = 1 + if not os.listdir(os.path.join(model, "variables")): + sess = tf.compat.v1.Session(graph=tf.Graph(), config=config) + loader = tf.compat.v1.saved_model.loader.load(sess, ["serve"], model) + if len(input_tensor_names) == 0: + input_tensor_names = [i.name for _, i in loader.signature_def["serving_default"].inputs.items()] + else: + assert validate_graph_node( + sess.graph.as_graph_def(), tensor_to_node(input_tensor_names) + ), "tensor names {} not in the graph".format(input_tensor_names) + + if len(output_tensor_names) == 0: + output_tensor_names = [i.name for _, i in loader.signature_def["serving_default"].outputs.items()] + else: + assert validate_graph_node( + sess.graph.as_graph_def(), tensor_to_node(output_tensor_names) + ), "tensor names {} not in the graph".format(output_tensor_names) + + return sess.graph.as_graph_def(), input_tensor_names, output_tensor_names + else: + from tensorflow.core.protobuf import config_pb2, meta_graph_pb2 + from tensorflow.python.eager import context + from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2 + from tensorflow.python.grappler import tf_optimizer + from tensorflow.python.saved_model import load, signature_constants, tag_constants + from tensorflow.python.training import saver + + _saved_model = load.load(model, [tag_constants.SERVING]) + func = _saved_model.signatures[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + frozen_func = convert_variables_to_constants_v2(func) + grappler_meta_graph_def = saver.export_meta_graph( + graph_def=frozen_func.graph.as_graph_def(), graph=frozen_func.graph + ) + if len(input_tensor_names) == 0: + input_tensor_names = [i.name.split(":")[0] for i in frozen_func.inputs] + if len(output_tensor_names) == 0: + output_tensor_names = [i.name.split(":")[0] for i in frozen_func.outputs] + # Add a collection 'train_op' so that Grappler knows the outputs. 
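+        # Grappler treats the 'train_op' collection as the fetch set and prunes
+        # everything unreachable from it, so pinning the inputs/outputs here keeps
+        # them from being optimized away.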
+ fetch_collection = meta_graph_pb2.CollectionDef() + for array in frozen_func.inputs + frozen_func.outputs: + fetch_collection.node_list.value.append(array.name) + grappler_meta_graph_def.collection_def["train_op"].CopyFrom(fetch_collection) + grappler_session_config = config_pb2.ConfigProto() + rewrite_options = grappler_session_config.graph_options.rewrite_options + rewrite_options.min_graph_nodes = -1 + opt = tf_optimizer.OptimizeGraph(grappler_session_config, grappler_meta_graph_def, graph_id=b"tf_graph") + return opt, input_tensor_names, output_tensor_names + + +def _get_graph_from_saved_model_v2(saved_model_dir, input_tensor_names, output_tensor_names): + from tensorflow.python.saved_model import signature_constants, tag_constants + + from neural_compressor.tensorflow.quantization.utils.utility import parse_saved_model + + saved_model_exported_names = [signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] + saved_model_tags = set([tag_constants.SERVING]) + try: + graph_def, _saved_model, _, _, input_names, output_names = parse_saved_model( + saved_model_dir, True, input_tensor_names, output_tensor_names + ) + except: + return load_saved_model(saved_model_dir, saved_model_tags, input_tensor_names, output_tensor_names) + return graph_def, input_names, output_names + + +def _get_graph_from_original_keras_v2(model, output_dir): + from tensorflow.lite.python.convert import OpsSet + from tensorflow.lite.python.util import ( + get_grappler_config, + model_input_signature, + run_graph_optimizations, + trace_model_call, + ) + from tensorflow.python.eager import def_function + from tensorflow.python.framework import convert_to_constants, dtypes + + input_signature = None + # If the model's call is not a `tf.function`, then we need to first get its + # input signature from `model_input_signature` method. + if not isinstance(model.call, def_function.Function): + input_signature = model_input_signature(model, keep_original_batch_size=False) + + func = trace_model_call(model, input_signature) + concrete_func = func.get_concrete_function() + funcs = [concrete_func] + + frozen_func, graph_def = convert_to_constants.convert_variables_to_constants_v2_as_graph( + funcs[0], lower_control_flow=False + ) + + input_tensors = [tensor for tensor in frozen_func.inputs if tensor.dtype != dtypes.resource] + output_tensors = frozen_func.outputs + # Grappler will also try to lower while loop into switch merge + # representation which is undesired for Ophints, so we simply remove + # those attributes to prevent Grappler from doing so. + graph = convert_to_constants.disable_lower_using_switch_merge(graph_def) + # Run function inlining optimization to ensure any models generated + # through the from_frozen_graph path have been inlined. 
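+    # Note: the inlining pass below is currently disabled; the graph_def produced
+    # by convert_variables_to_constants_v2_as_graph is returned unchanged.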
+ # grappler_config = get_grappler_config(['function']) + # graph_def = run_graph_optimizations( + # graph, + # input_tensors, + # output_tensors, + # config=grappler_config) + input_names = [tensor.name.split(":")[0] for tensor in input_tensors] + output_names = [tensor.name.split(":")[0] for tensor in output_tensors] + return graph_def, input_names, output_names + + +def _check_keras_format(model, saved_model_dir): + from tensorflow.python import saved_model + from tensorflow.python.saved_model import save_options + from tensorflow.python.saved_model.load import load + from tensorflow.python.saved_model.loader_impl import parse_saved_model_with_debug_info + + version = "saved_model_v2" + try: + saved_model.save(model, saved_model_dir, options=save_options.SaveOptions(save_debug_info=True)) + except: + return "trackable_object" + saved_model_proto, _ = parse_saved_model_with_debug_info(saved_model_dir) + saved_model_version = saved_model_proto.saved_model_schema_version + if saved_model_version == 0: + return "saved_model_v1" + if saved_model_version not in [1, 2]: + raise ValueError("SavedModel file format({0}) is not supported".format(saved_model_version)) + return version + + +def _get_graph_from_saved_model_v1(model): + from tensorflow.lite.python.convert_saved_model import get_inputs_outputs, get_meta_graph_def, get_signature_def + from tensorflow.python.client import session + from tensorflow.python.framework import ops + from tensorflow.python.saved_model import constants, signature_constants, tag_constants + + saved_model_tags = set([tag_constants.SERVING]) + signature_key = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY + + meta_graph = get_meta_graph_def(model, saved_model_tags) + signature_def = get_signature_def(meta_graph, signature_key) + inputs, outputs = get_inputs_outputs(signature_def) + # Check SavedModel for assets directory. + collection_def = meta_graph.collection_def + if constants.ASSETS_KEY in collection_def: + raise ValueError("SavedModels with assets/ directory are not supported.") + + from tensorflow.compat.v1 import graph_util as tf_graph_util + from tensorflow.python.saved_model import loader + + graph = ops.Graph() + import tensorflow as tf + + with session.Session(graph=graph) as sess: + loader.load(sess, meta_graph.meta_info_def.tags, model) + sess.run(tf.compat.v1.global_variables_initializer()) + sess.run(tf.compat.v1.tables_initializer()) + output_nodes = list(set([output.split(":")[0] for output in outputs])) + node_ops = [node.op for node in graph.as_graph_def().node] + if "MakeIterator" in node_ops: + output_nodes.append("MakeIterator") + table_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.TABLE_INITIALIZERS) + # For table initialization + for table_op in table_ops: + output_nodes.append(table_op.name) + if len(table_ops) > 0: + output_nodes.append("init_all_tables") + graph_def = tf_graph_util.convert_variables_to_constants(sess, graph.as_graph_def(), output_nodes) + return graph_def, inputs, outputs + + +def keras_session(model, input_tensor_names, output_tensor_names, **kwargs): + """Build session with keras model. + + Args: + model (string or tf.keras.Model): model path or tf.keras.Model object. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. + + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object. + input_tensor_names (list of string): validated input_tensor_names. 
+ output_tensor_names (list of string): validated output_tensor_names. + """ + temp_dir = tempfile.mkdtemp() + if tf.version.VERSION > "2.1.0": + if not isinstance(model, tf.keras.Model): + model = tf.keras.models.load_model(model) + keras_format = _check_keras_format(model, temp_dir) + if keras_format == "saved_model_v2": + try: + graph_def, input_names, output_names = _get_graph_from_saved_model_v2( + temp_dir, input_tensor_names, output_tensor_names + ) + if "_FusedBatchNormEx" in [node.op for node in graph_def.node]: + keras_format = "trackable_object" + except: + keras_format = "trackable_object" + if keras_format == "trackable_object": + try: + graph_def, input_names, output_names = _get_graph_from_original_keras_v2(model, temp_dir) + except: + keras_format = "saved_model_v1" + if keras_format == "saved_model_v1": # pragma: no cover + try: + tf.keras.backend.set_learning_phase(0) + graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) + except: + raise ValueError("Not supported keras model type...") + + # tensorflow 1.x use v1 convert method + else: + tf.keras.backend.set_learning_phase(0) + graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model) + shutil.rmtree(temp_dir, True) + return graph_def_session(graph_def, input_names, output_names, **kwargs) + + +def slim_session(model, input_tensor_names, output_tensor_names, **kwargs): # pragma: no cover + """Build session with slim model. + + Args: + model (string): model path. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. + + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object. + input_tensor_names (list of string): validated input_tensor_names. + output_tensor_names (list of string): validated output_tensor_names. + """ + assert version1_lt_version2(tf.version.VERSION, "2.0.0"), "slim model only used in tensorflow 1.x" + from neural_compressor.tensorflow.utils.nets_factory import TFSlimNetsFactory + + factory = TFSlimNetsFactory() + assert "name" in kwargs, "model name should be set in slim checkpoint...." + assert kwargs["name"] in factory.default_slim_models, "only support topology {}".format(factory.default_slim_models) + net = copy.deepcopy(factory.networks_map[kwargs["name"]]) + model_func = net.pop("model") + arg_scope = net.pop("arg_scope")() + inputs_shape = net.pop("input_shape") + kwargs = net + import tf_slim as slim + + with tf.Graph().as_default(): + images = tf.compat.v1.placeholder(name="input", dtype=tf.float32, shape=inputs_shape) + with tf.compat.v1.Session() as sess: + with slim.arg_scope(arg_scope) as scope: # pylint: disable=not-context-manager + model_func(images, is_training=False, **kwargs) + graph_def = sess.graph.as_graph_def() + output_tensor_names = output_tensor_names if len(output_tensor_names) > 0 else [graph_def.node[-1].name] + + from tensorflow.python.tools.freeze_graph import freeze_graph_with_def_protos + + graph_def = freeze_graph_with_def_protos( + input_graph_def=graph_def, + input_saver_def=None, + input_checkpoint=model, + output_node_names=",".join(output_tensor_names), + restore_op_name="save/restore_all", + filename_tensor_name="save/Const:0", + output_graph="", + clear_devices=True, + initializer_nodes="", + ) + + return graph_def_session(graph_def, ["input"], output_tensor_names, **kwargs) + + +def checkpoint_session(model, input_tensor_names, output_tensor_names, **kwargs): + """Build session with ckpt model. 
+ + Args: + model (string): model path. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): validated output_tensor_names of model. + + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object. + input_tensor_names (list of string): validated input_tensor_names. + output_tensor_names (list of string): validated output_tensor_names. + """ + assert ( + output_tensor_names is not None and len(output_tensor_names) > 0 + ), "outputs should not be None of checkpoint...." + + ckpt_prefix = [os.path.splitext(i)[0] for i in os.listdir(model) if i.endswith(".meta")][0] + + config = tf.compat.v1.ConfigProto() + config.use_per_session_threads = 1 + config.inter_op_parallelism_threads = 1 + graph = tf.Graph() + sess = tf.compat.v1.Session(graph=graph, config=config) + if version1_lt_version2(tf.version.VERSION, "2.0.0"): # pragma: no cover + from tensorflow._api.v1.config import experimental + + list_physical_devices = experimental.list_physical_devices + else: + list_physical_devices = tf.config.list_physical_devices + + with graph.as_default(): + device = kwargs.get("device") + if device == "cpu": + cpus = list_physical_devices("CPU") + node_device = cpus[0].name.replace("physical_device:", "") + with graph.device(node_device): + saver = tf.compat.v1.train.import_meta_graph( + os.path.join(model, ckpt_prefix + ".meta"), clear_devices=True + ) + else: # pragma: no cover + saver = tf.compat.v1.train.import_meta_graph(os.path.join(model, ckpt_prefix + ".meta"), clear_devices=True) + + sess.run(tf.compat.v1.global_variables_initializer()) + saver.restore(sess, os.path.join(model, ckpt_prefix)) + + from neural_compressor.tensorflow.quantization.utils.utility import get_input_output_node_names + + if validate_graph_node(sess.graph.as_graph_def(), tensor_to_node(input_tensor_names)): + input_tensor_names = input_tensor_names + else: + input_tensor_names, _ = get_input_output_node_names(sess.graph.as_graph_def()) + return sess, input_tensor_names, output_tensor_names + + +def estimator_session(model, input_tensor_names, output_tensor_names, **kwargs): + """Build session with estimator model. + + Args: + model (tf.estimator.Estimator): tf.estimator.Estimator object. + input_tensor_names (list of string): input_tensor_names of model. + output_tensor_names (list of string): output_tensor_names of model. + kwargs (dict): other required parameters, like input_fn. + + Returns: + sess (tf.compat.v1.Session): tf.compat.v1.Session object. + input_tensor_names (list of string): validated input_tensor_names. + output_tensor_names (list of string): validated output_tensor_names. + """ + assert "input_fn" in kwargs, "input func should be supplied for estimator session...." 
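+    # A minimal input_fn sketch for illustration (hypothetical shapes, not part of
+    # this module):
+    #
+    #     def input_fn():
+    #         data = np.zeros((1, 224, 224, 3), dtype=np.float32)
+    #         return tf.data.Dataset.from_tensor_slices(data).batch(1)
+    #
+    #     sess, inputs, outputs = estimator_session(estimator, [], [], input_fn=input_fn)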
+    with tf.Graph().as_default() as g:
+        features, input_hooks = model._get_features_from_input_fn(kwargs["input_fn"], tf.estimator.ModeKeys.PREDICT)
+        estimator_spec = model._call_model_fn(features, None, tf.estimator.ModeKeys.PREDICT, model.config)
+
+        if len(output_tensor_names) == 0:
+            outputs = (
+                [tensor.name for tensor in estimator_spec.predictions.values()]
+                if isinstance(estimator_spec.predictions, dict)
+                else [estimator_spec.predictions.name]
+            )
+        else:
+            outputs = output_tensor_names
+
+        logger.info("Estimator output tensor names are {}.".format(outputs))
+        with tf.compat.v1.Session(graph=g) as sess:
+            sess.run(tf.compat.v1.global_variables_initializer())
+            # Freezing a graph requires output_node_names, which can be found in
+            # estimator_spec.predictions, a dictionary containing the prediction
+            # tensors.
+            # When a model uses an Iterator, we need to have 'MakeIterator' (default
+            # name used by TF) in the output_node_names as well.
+            output_nodes = list(set([output.split(":")[0] for output in outputs]))
+            if "MakeIterator" in [node.op for node in g.as_graph_def().node]:
+                output_nodes.append("MakeIterator")
+
+            graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(sess, g.as_graph_def(), output_nodes)
+
+    return graph_def_session(graph_def, input_tensor_names, outputs, **kwargs)
+
+
+def saved_model_session(model, input_tensor_names, output_tensor_names, **kwargs):
+    """Build session with saved model.
+
+    Args:
+        model (string): model path.
+        input_tensor_names (list of string): input_tensor_names of model.
+        output_tensor_names (list of string): output_tensor_names of model.
+
+    Returns:
+        sess (tf.compat.v1.Session): tf.compat.v1.Session object.
+        input_tensor_names (list of string): validated input_tensor_names.
+        output_tensor_names (list of string): validated output_tensor_names.
+    """
+    try:
+        graph_def, input_names, output_names = _get_graph_from_saved_model_v2(
+            model, input_tensor_names, output_tensor_names
+        )
+    except:
+        graph_def, input_names, output_names = _get_graph_from_saved_model_v1(model)
+    assert graph_def is not None, "Can not parse the saved model..."
+    return graph_def_session(graph_def, input_names, output_names, **kwargs)
+
+
+# A session with input/output tensors is necessary to run the model.
+SESSIONS = {
+    "frozen_pb": frozen_pb_session,
+    "graph_def": graph_def_session,
+    "graph": graph_session,
+    "saved_model": saved_model_session,
+    "llm_saved_model": saved_model_session,
+    "keras": keras_session,
+    "checkpoint": checkpoint_session,
+    "estimator": estimator_session,
+    "slim": slim_session,
+}
+
+
+class BaseModel:
+    """Base class of all neural_compressor.model, will play graph role."""
+
+    def __init__(self, model, **kwargs):
+        """Initialize a BaseModel.
+
+        Args:
+            model (object): raw model format. For Tensorflow model, could be path to frozen pb file,
+                path to ckpt or savedmodel folder, loaded estimator/graph_def/graph/keras model object.
+        """
+        self.component = None
+
+    @property
+    def model(self):
+        """Return model itself."""
+        raise NotImplementedError
+
+    @property
+    def graph_info(self):
+        """Return a dict with content 'Node: Node_type'."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def save(self, root, *args, **kwargs):
+        """Abstract method of model saving."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def export(
+        self,
+        save_path: str,
+        conf,
+    ):
+        """Abstract method of model conversion to ONNX."""
+        raise NotImplementedError
+
+    @abstractmethod
+    def framework(self):
+        """Abstract method of model framework."""
+        raise NotImplementedError
+
+
+class TensorflowBaseModel(BaseModel):
+    """Build Tensorflow Base Model."""
+
+    def __init__(self, model, **kwargs):
+        """Initialize a Tensorflow model.
+
+        Args:
+            model (string or tensorflow model object): model path or model object.
+        """
+        self._model = model
+        self._name = ""
+        self._weights = None
+        self.kwargs = kwargs
+        self._graph_info = {}
+        self._input_tensor_names = []
+        self._output_tensor_names = []
+        self._model_type = ""
+        self._sess = None
+        self._iter_op = None
+        self._workspace_path = ""
+        self._q_config = None
+        self._model_path = None if not isinstance(model, str) else model
+
+    @property
+    def model_path(self):
+        """Return model path."""
+        return self._model_path
+
+    @model_path.setter
+    def model_path(self, path):
+        """Set model path."""
+        self._model_path = path
+
+    def framework(self):
+        """Return framework."""
+        return "tensorflow"
+
+    @property
+    def name(self):
+        """Return name."""
+        return self._name
+
+    @name.setter
+    def name(self, name):
+        """Set name."""
+        self.kwargs.update({"name": name})
+        self._name = name
+
+    @property
+    def weights(self):
+        """Return weights."""
+        return self._weights
+
+    @weights.setter
+    def weights(self, new_weights):
+        """Set weights."""
+        self._weights = new_weights
+
+    @property
+    def q_config(self):
+        """Return q_config."""
+        return self._q_config
+
+    @q_config.setter
+    def q_config(self, q_config):
+        """Set q_config."""
+        self._q_config = q_config
+
+    @property
+    def workspace_path(self):
+        """Return workspace path."""
+        return self._workspace_path
+
+    @workspace_path.setter
+    def workspace_path(self, path):
+        """Set workspace path."""
+        self._workspace_path = path
+
+    @property
+    def model_type(self):
+        """Return model type."""
+        return self._model_type
+
+    @model_type.setter
+    def model_type(self, model_type):
+        """Set model type."""
+        assert model_type in SESSIONS, "model type not supported...."
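+        # The SESSIONS dict defined above is the dispatch table: every value accepted
+        # here has a matching session builder.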
+ self._model_type = model_type + + @property + def model(self): + """Return model itself.""" + return self.graph + + @property + def graph_def(self): + """Return graph definition.""" + return self.graph.as_graph_def() + + @property + def graph_info(self): + """Return graph info.""" + self._graph_info = {} + for node in self.graph_def.node: + self._graph_info[node.name] = node.op + return self._graph_info + + @property + def sess(self): + """Return Session object.""" + if self._sess is None: + self._load_sess(self._model, **self.kwargs) + return self._sess + + @property + def graph(self): + """Return model graph.""" + return self.sess.graph + + @graph_def.setter + def graph_def(self, graph_def): + """Set graph definition.""" + if self._sess is not None: + self._sess.close() + output_sess = SESSIONS["graph_def"]( + graph_def, self._input_tensor_names, self._output_tensor_names, **self.kwargs + ) + + self._sess = output_sess[0] + self._input_tensor_names = output_sess[1] + self._output_tensor_names = output_sess[2] + self.model_type = "graph_def" + + def _load_sess(self, model, **kwargs): + if self.name: + kwargs.update({"name": self.name}) + # assert self.model_type, 'model type not set....' + output_sess = SESSIONS[self.model_type](model, self._input_tensor_names, self._output_tensor_names, **kwargs) + self._sess = output_sess[0] + self._input_tensor_names = output_sess[1] + self._output_tensor_names = output_sess[2] + + tf.compat.v1.get_variable_scope().reuse_variables() + return self._sess + + @property + def iter_op(self): + """Return model iter op list.""" + self._iter_op = [] + if self._sess is None: + self._load_sess(self._model, **self.kwargs) + op_list = [node.op for node in self._sess.graph.as_graph_def().node] + if "MakeIterator" in op_list: + self._iter_op.append(self._sess.graph.get_operation_by_name("MakeIterator")) + return self._iter_op + + @property + def input_tensor_names(self): + """Return input tensor names.""" + if self._sess is None: + self._load_sess(self._model, **self.kwargs) + return copy.deepcopy(self._input_tensor_names) + + @input_tensor_names.setter + def input_tensor_names(self, tensor_names): + """Set input tensor names.""" + if len(tensor_names) == 0: + logger.warning("Input tensor names is empty.") + return + if self._sess is not None: + assert validate_graph_node( + self.graph_def, tensor_to_node(tensor_names) + ), "tensor names {} not in graph".format(tensor_names) + self._input_tensor_names = tensor_names + + @property + def output_tensor_names(self): + """Return output tensor names.""" + if len(self._output_tensor_names) == 0: + self._load_sess(self._model, **self.kwargs) + return copy.deepcopy(self._output_tensor_names) + + @output_tensor_names.setter + def output_tensor_names(self, tensor_names): + """Set output tensor names.""" + if len(tensor_names) == 0: + logger.warning("Output tensor names should not be empty.") + return + if self._sess is not None: + assert validate_graph_node( + self.graph_def, tensor_to_node(tensor_names) + ), "tensor names {} not in graph".format(tensor_names) + self._output_tensor_names = tensor_names + + # input/output node names and input/output tensor + # come from input/output tensor names, so do not support assign these values + @property + def input_node_names(self): + """Return input node names.""" + return copy.deepcopy(tensor_to_node(self.input_tensor_names)) + + @property + def output_node_names(self): + """Return output node names.""" + output_node_names = tensor_to_node(self.output_tensor_names) + iter_op_list = 
self.iter_op + if iter_op_list != []: + output_node_names += [iter_op.name for iter_op in iter_op_list] + return copy.deepcopy(output_node_names) + + @property + def input_tensor(self): + """Return input tensor.""" + from neural_compressor.tensorflow.quantization.utils.utility import get_tensor_by_name + + return [get_tensor_by_name(self.graph, x) for x in self.input_tensor_names] + + @property + def output_tensor(self): + """Return output tensor.""" + from neural_compressor.tensorflow.quantization.utils.utility import get_tensor_by_name + + return [get_tensor_by_name(self.graph, x) for x in self.output_tensor_names] + + def save(self, root=None): + """Save Tensorflow model.""" + if not root: + root = DEFAULT_WORKSPACE + "/save.pb" + root = os.path.abspath(os.path.expanduser(root)) + # if not have suffix, default append .pb + os.makedirs(os.path.dirname(root), exist_ok=True) + pb_file = root if os.path.split(root)[-1].endswith(".pb") else root + ".pb" + f = tf.io.gfile.GFile(pb_file, "wb") + f.write(self.graph_def.SerializeToString()) + logger.info("Save quantized model to {}.".format(pb_file)) + + +class TensorflowSavedModelModel(TensorflowBaseModel): + """Build Tensorflow saved model.""" + + def __init__(self, model, **kwargs): + """Initialize a Tensorflow model. + + Args: + model (string or tensorflow model object): model path or model object. + """ + super(TensorflowSavedModelModel, self).__init__(model, **kwargs) + self._auto_trackable = None + + def get_all_weight_names(self): + """Get weight names of model. + + Returns: + list: weight names list. + """ + import tensorflow as tf + + names = [] + for index, layer in enumerate(tf.keras.models.load_model(self._model).layers): + if len(layer.weights): + names.append(index) + return names + + def update_weights(self, tensor_name, new_tensor): + """Update model weights.""" + pass + + def get_weight(self, tensor_name): + """Return model weight with a given tensor name. + + Args: + tensor_name (str): name of a tensor. + """ + return self.weights[tensor_name] + + @property + def model(self): + """Return model in AutoTrackable object.""" + if self._auto_trackable: + return self._auto_trackable + + root = os.path.abspath(os.path.expanduser(DEFAULT_WORKSPACE)) + root += str(time.time()) + if os.path.exists(root): + shutil.rmtree(root) + os.makedirs(root, exist_ok=True) + if not self._sess: + self._load_sess(self._model, **self.kwargs) + _, builder = self.build_saved_model(root) + builder.save() + model = tf.saved_model.load(root) + shutil.rmtree(root) + self._auto_trackable = model + return model + + @model.setter + def model(self, input_model): + """Set model in AutoTrackable object.""" + self._auto_trackable = input_model + + def compute_sparsity(self, tensor): + """Compute the sparsity. + + Args: + tensor: Tensorflow tensor + + Return: + (the original tensor size, number of zero elements, number of non-zero elements) + """ + mask = np.ones_like(tensor) + tensor_size = tensor.size + dense_mask = tensor != 0 + dense_size = dense_mask.sum() + return tensor_size, tensor_size - dense_size, dense_size + + def report_sparsity(self): + """Get sparsity of the model. + + Returns: + df (DataFrame): DataFrame of sparsity of each weight. + total_sparsity (float): total sparsity of model. 
+ """ + import numpy as np + import pandas as pd + import tensorflow as tf + + df = pd.DataFrame(columns=["Name", "Shape", "NNZ (dense)", "NNZ (sparse)", "Sparsity(%)"]) + pd.set_option("display.precision", 2) + param_dims = [2, 4] + params_size = 0 + sparse_params_size = 0 + for index, layer in enumerate(tf.keras.models.load_model(self._model).layers): + if not len(layer.weights): + continue + # Extract just the actual parameter's name, which in this context we treat + # as its "type" + weights = layer.get_weights()[0] + if weights.ndim in param_dims: + param_size, sparse_param_size, dense_param_size = self.compute_sparsity(weights) + density = dense_param_size / param_size + params_size += param_size + sparse_params_size += sparse_param_size + df.loc[len(df.index)] = [ + index, + list(weights.shape), + dense_param_size, + sparse_param_size, + (1 - density) * 100, + ] + + total_sparsity = sparse_params_size / params_size * 100 + + df.loc[len(df.index)] = [ + "Total sparsity:", + "-", + params_size, + sparse_params_size, + total_sparsity, + ] + + return df, total_sparsity + + def build_saved_model(self, root=None): + """Build Tensorflow saved model. + + Args: + root (str, optional): path to saved model. Defaults to None. + + Returns: + root (str): path to saved model. + builder (tf.compat.v1.saved_model.builder.SavedModelBuilder): builds + the SavedModel protocol buffer and saves variables and assets. + """ + if not root: + root = DEFAULT_WORKSPACE + root = os.path.abspath(os.path.expanduser(root)) + if os.path.exists(root): + import shutil + + shutil.rmtree(root) + + os.makedirs(root, exist_ok=True) + + from tensorflow.python.saved_model import signature_constants, tag_constants + + from neural_compressor.tensorflow.quantization.utils.utility import get_tensor_by_name + + builder = tf.compat.v1.saved_model.builder.SavedModelBuilder(root) + sigs = {} + with tf.compat.v1.Session(graph=tf.Graph()) as sess: + # (TODO) not directly use self._sess.graph, use self.graph + tf.import_graph_def(self.graph.as_graph_def(), name="") + g = tf.compat.v1.get_default_graph() + inp = [get_tensor_by_name(g, x) for x in self._input_tensor_names] + out = [get_tensor_by_name(g, x) for x in self._output_tensor_names] + sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = ( + tf.compat.v1.saved_model.signature_def_utils.predict_signature_def( + {k: v for k, v in zip(self._input_tensor_names, inp)}, + {k: v for k, v in zip(self._output_tensor_names, out)}, + ) + ) + builder.add_meta_graph_and_variables(sess, [tag_constants.SERVING], signature_def_map=sigs) + return root, builder + + def save(self, root=None): + """Save Tensorflow model.""" + root, builder = self.build_saved_model(root) + builder.save() + logger.info("Save quantized model to {}.".format(root)) + + +class TensorflowLLMModel(TensorflowSavedModelModel): + """The class Tensorflow saved model whose GraphDef exceeding maximum protobuf size of 2GB.""" + + def __init__(self, model, **kwargs): + """Initialize a Tensorflow model. + + Args: + model (string or tensorflow model object): model path or model object. 
+        """
+        super(TensorflowLLMModel, self).__init__(model, **kwargs)
+
+        self._model_path = self.kwargs.get("model_path", None)
+        self._weight_name_mapping = self.kwargs.get("weight_name_mapping", None)
+        self._sq_weight_scale_dict = self.kwargs.get("sq_weight_scale_dict", None)
+        self._weight_tensor_minmax_dict = {}
+        self._model_type = "llm_saved_model"
+
+        from neural_compressor.tensorflow.quantization.utils.utility import parse_saved_model
+
+        (
+            self._graph_def,
+            self._saved_model,
+            self.func,
+            self.frozen_func,
+            self._input_tensor_names,
+            self._output_tensor_names,
+        ) = parse_saved_model(model)
+
+    @property
+    def model_path(self):
+        """Return model path.
+
+        The model path in this class is used as a temp path for the intermediate model.
+        """
+        return self._model_path
+
+    @model_path.setter
+    def model_path(self, path):
+        """Set model path.
+
+        The model path in this class is used as a temp path for the intermediate model.
+        """
+        self.kwargs.update({"model_path": path})
+        self._model_path = path
+
+    @property
+    def graph_def(self):
+        """Return graph_def."""
+        return self._graph_def
+
+    @graph_def.setter
+    def graph_def(self, graph_def):
+        """Set graph definition."""
+        self._graph_def = graph_def
+        # The attributes of some nodes can't be correctly read unless the graph_def is imported.
+        tf.import_graph_def(self._graph_def, name="")
+
+    @property
+    def model(self):
+        """Return model in AutoTrackable format."""
+        if self._sq_weight_scale_dict:
+            self.adjust_weight(self.graph_def)
+        if not self._auto_trackable:
+            self._auto_trackable = tf.saved_model.load(self._model)
+        return self._auto_trackable
+
+    @property
+    def weight_name_mapping(self):
+        """Return weight_name_mapping function."""
+        if not self._weight_name_mapping:
+            self._weight_name_mapping = self.kwargs.get("weight_name_mapping", None)
+            assert self._weight_name_mapping is not None, "weight_name_mapping should not be None!"
+        return self._weight_name_mapping
+
+    @weight_name_mapping.setter
+    def weight_name_mapping(self, weight_name_mapping):
+        """Set weight_name_mapping function."""
+        self.kwargs.update({"weight_name_mapping": weight_name_mapping})
+        self._weight_name_mapping = weight_name_mapping
+
+    @property
+    def sq_weight_scale_dict(self):
+        """Return dict of weight scaler for smooth quantization."""
+        if not self._sq_weight_scale_dict:
+            self._sq_weight_scale_dict = self.kwargs.get("sq_weight_scale_dict", None)
+            assert self._sq_weight_scale_dict is not None, "sq_weight_scale_dict should not be None!"
+        return self._sq_weight_scale_dict
+
+    @sq_weight_scale_dict.setter
+    def sq_weight_scale_dict(self, sq_weight_scale_dict):
+        """Set dict of weight scaler for smooth quantization."""
+        self.kwargs.update({"sq_weight_scale_dict": sq_weight_scale_dict})
+        self._sq_weight_scale_dict = sq_weight_scale_dict
+
+    @property
+    def weight_tensor_minmax_dict(self):
+        """Return dict of weight tensor min/max values recorded during weight adjustment."""
+        return self._weight_tensor_minmax_dict
+
+    @property
+    def input_tensor_names(self):
+        """Return input tensor names."""
+        return copy.deepcopy(self._input_tensor_names)
+
+    @input_tensor_names.setter
+    def input_tensor_names(self, tensor_names):
+        """Set input tensor names."""
+        if len(tensor_names) == 0:  # pragma: no cover
+            logger.warning("Input tensor names is empty.")
+            return
+
+        assert validate_graph_node(
+            self._graph_def, tensor_to_node(tensor_names)
+        ), "tensor names {} not in graph".format(tensor_names)
+        self._input_tensor_names = tensor_names
+
+    @property
+    def output_tensor_names(self):
+        """Return output tensor names."""
+        return copy.deepcopy(self._output_tensor_names)
+
+    @output_tensor_names.setter
+    def output_tensor_names(self, tensor_names):
+        """Set output tensor names."""
+        if len(tensor_names) == 0:  # pragma: no cover
+            logger.warning("Output tensor names is empty.")
+            return
+        if self._graph_def is not None:
+            assert validate_graph_node(
+                self.graph_def, tensor_to_node(tensor_names)
+            ), "tensor names {} not in graph".format(tensor_names)
+        self._output_tensor_names = tensor_names
+
+    @property
+    def output_node_names(self):
+        """Return output node names."""
+        output_node_names = tensor_to_node(self.output_tensor_names)
+        return copy.deepcopy(output_node_names)
+
+    def adjust_weight(self, graph_def):
+        """Adjust weight of LLM saved_model by scale."""
+        from tensorflow.python.saved_model import load, tag_constants
+
+        from neural_compressor.tensorflow.quantization.utils.utility import reconstruct_saved_model
+
+        reconstruct_saved_model(graph_def, self.func, self.frozen_func, self._saved_model, self.model_path)
+        model = load.load(self.model_path, [tag_constants.SERVING])
+
+        for idx, weight_tensor in enumerate(model.variables):
+            parsed_weight_name = self.weight_name_mapping(weight_tensor.name)
+            if parsed_weight_name in self.sq_weight_scale_dict:
+                weight_array = np.transpose(weight_tensor, [1, 0])
+                weight_array *= self.sq_weight_scale_dict[parsed_weight_name]
+                weight_array = np.transpose(weight_array, [1, 0])
+                tf.compat.v1.assign(model.variables[idx], weight_array)
+            else:
+                weight_array = weight_tensor
+
+            if parsed_weight_name not in self._weight_tensor_minmax_dict:
+                self._weight_tensor_minmax_dict[parsed_weight_name] = [np.min(weight_array), np.max(weight_array)]
+        self._auto_trackable = model
+
+    def save(self, root=None):
+        """Save the model to the root path."""
+        import shutil
+
+        from neural_compressor.tensorflow.quantization.utils.utility import parse_saved_model, reconstruct_saved_model
+
+        if not root:
+            root = DEFAULT_WORKSPACE
+        root = os.path.abspath(os.path.expanduser(root))
+        if os.path.exists(root):
+            shutil.rmtree(root)
+        os.makedirs(root, exist_ok=True)
+
+        self.adjust_weight(self._graph_def)
+        graph_def, _saved_model, func, frozen_func, _, _ = parse_saved_model(self._auto_trackable)
+        reconstruct_saved_model(graph_def, func, frozen_func, _saved_model, root)
+        logger.info("Save quantized model to {}.".format(root))
+        # delete the LLM file saved in this temporary path
+        shutil.rmtree(self.model_path, ignore_errors=True)
+
+
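+# The weight adjustment in TensorflowLLMModel.adjust_weight applies a smooth-quant
+# scale along the first weight dimension via a transpose/multiply/transpose. A
+# numpy-only sketch of that transform (hypothetical 2x3 weight, scale of length 2):
+#
+#     weight = np.ones((2, 3), dtype=np.float32)
+#     scale = np.array([0.5, 2.0], dtype=np.float32)
+#     adjusted = np.transpose(np.transpose(weight, [1, 0]) * scale, [1, 0])
+#     assert adjusted.shape == weight.shape
+
+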
+class TensorflowQATModel(TensorflowSavedModelModel):
+    """Build Tensorflow QAT model."""
+
+    def __init__(self, model="", **kwargs):
+        """Initialize a Tensorflow QAT model.
+
+        Args:
+            model (string or tf.keras.Model object): model path or model object.
+        """
+        assert isinstance(model, tf.keras.Model) or isinstance(
+            model, str
+        ), "The TensorflowQATModel should be initialized either by a string or a tf.keras.Model."
+        super(TensorflowQATModel, self).__init__(model)
+        self.keras_model = None
+        self.model_type = "keras"
+
+    @property
+    def model(self):
+        """Return model itself."""
+        if self.keras_model is None:
+            if isinstance(self._model, tf.keras.Model):
+                self.keras_model = self._model
+            else:
+                self.keras_model = tf.keras.models.load_model(self._model)
+
+        return self.keras_model
+
+    @model.setter
+    def model(self, q_model):
+        """Set model itself."""
+        self.keras_model = q_model
+
+    @property
+    def frozen_graph_def(self):
+        """Get frozen graph_def."""
+        graph_def = tf.compat.v1.graph_util.convert_variables_to_constants(
+            self.sess, self.sess.graph_def, self.output_node_names
+        )
+        return graph_def
+
+    def save(self, root=None):
+        """Save Tensorflow QAT model."""
+        if not root:
+            root = DEFAULT_WORKSPACE + "/saved_model"
+        root = os.path.abspath(os.path.expanduser(root))
+        os.makedirs(os.path.dirname(root), exist_ok=True)
+        if root.endswith(".pb"):
+            saved_format = "pb file"
+            graph_def = self.frozen_graph_def
+            f = tf.io.gfile.GFile(root, "wb")
+            f.write(graph_def.SerializeToString())
+        else:
+            q_aware_model = self.keras_model
+            q_aware_model.save(root)
+            saved_format = "saved_model"
+            if root.endswith(".h5"):
+                saved_format = "h5 file"
+        logger.info("Save quantized model as {} to {}.".format(saved_format, root))
+        return root
+
+
+class TensorflowCheckpointModel(TensorflowBaseModel):
+    """Build Tensorflow checkpoint model."""
+
+    @property
+    def graph_def(self):
+        """Return graph definition."""
+        if self.model_type == "graph_def":
+            return self.sess.graph.as_graph_def()
+        from tensorflow.compat.v1 import graph_util
+
+        from neural_compressor.tensorflow.quantization.utils.utility import _parse_ckpt_bn_input
+
+        graph_def = self.sess.graph.as_graph_def()
+        graph_def = _parse_ckpt_bn_input(graph_def)
+        return graph_util.convert_variables_to_constants(
+            sess=self._sess, input_graph_def=graph_def, output_node_names=self.output_node_names
+        )
+
+    @graph_def.setter
+    def graph_def(self, graph_def):
+        """Set graph definition."""
+        if self._sess is not None:
+            self._sess.close()
+        output_sess = SESSIONS["graph_def"](
+            graph_def, self._input_tensor_names, self._output_tensor_names, **self.kwargs
+        )
+        self._sess = output_sess[0]
+        self._input_tensor_names = output_sess[1]
+        self._output_tensor_names = output_sess[2]
+        self.model_type = "graph_def"
+
+    @property
+    def model(self):
+        """Return the model itself to avoid the initialization issue."""
+        return self
+
+
+class KerasModel(BaseModel):
+    """Build Keras model."""
+
+    def __init__(self, model, **kwargs):
+        """Initialize a Keras model.
+
+        Args:
+            model (string or keras model object): model path or model object.
+ """ + self.component = None + self._model = model + if not isinstance(model, tf.keras.Model): + self._model_object = tf.keras.models.load_model(self._model) + else: + self._model_object = self._model + self._q_config = None + + @property + def q_config(self): + """Return q_config.""" + return self._q_config + + @q_config.setter + def q_config(self, q_config): + """Set q_config.""" + self._q_config = q_config + + @property + def model(self): + """Return model itself.""" + return self._model_object + + @property + def graph_info(self): + """Return graph info.""" + # (TODO) get the graph info + return None + + @abstractmethod + def save(self, root, *args, **kwargs): + """Save Keras model.""" + self._model_object.save(root) + + @abstractmethod + def _export( + self, + save_path: str, + conf, + ): + pass + + @abstractmethod + def framework(self): + """Return framework.""" + return "keras" + + def get_all_weight_names(self): + """Get weight names of model. + + Returns: + list: weight names list. + """ + names = [] + for index, layer in enumerate(self.model.layers): + if len(layer.weights): + names.append(index) + return names + + def compute_sparsity(self, tensor): + """Compute the sparsity. + + Args: + tensor: Tensorflow tensor + + Return: + (the original tensor size, number of zero elements, number of non-zero elements) + """ + mask = np.ones_like(tensor) + tensor_size = tensor.size + dense_mask = tensor != 0 + dense_size = dense_mask.sum() + return tensor_size, tensor_size - dense_size, dense_size + + def report_sparsity(self): + """Get sparsity of the model. + + Returns: + df (DataFrame): DataFrame of sparsity of each weight. + total_sparsity (float): total sparsity of model. + """ + import numpy as np + import pandas as pd + import tensorflow as tf + + df = pd.DataFrame(columns=["Name", "Shape", "NNZ (dense)", "NNZ (sparse)", "Sparsity(%)"]) + pd.set_option("display.precision", 2) + param_dims = [2, 4] + params_size = 0 + sparse_params_size = 0 + for index, layer in enumerate(self.model.layers): + if not len(layer.weights): + continue + # Extract just the actual parameter's name, which in this context we treat + # as its "type" + weights = layer.get_weights()[0] + if weights.ndim in param_dims: + param_size, sparse_param_size, dense_param_size = self.compute_sparsity(weights) + density = dense_param_size / param_size + params_size += param_size + sparse_params_size += sparse_param_size + df.loc[len(df.index)] = [ + index, + list(weights.shape), + dense_param_size, + sparse_param_size, + (1 - density) * 100, + ] + + total_sparsity = sparse_params_size / params_size * 100 + + df.loc[len(df.index)] = [ + "Total sparsity:", + "-", + params_size, + sparse_params_size, + total_sparsity, + ] + + return df, total_sparsity + + @property + def input_node_names(self): + """Return input node names.""" + return self.model.input_names + + @property + def output_node_names(self): + """Return output node names.""" + return self.model.output_names + + +TENSORFLOW_MODELS = { + "frozen_pb": TensorflowBaseModel, + "graph_def": TensorflowBaseModel, + "graph": TensorflowBaseModel, + "checkpoint": TensorflowCheckpointModel, + "estimator": TensorflowBaseModel, + "slim": TensorflowBaseModel, + "saved_model": TensorflowSavedModelModel, + "AutoTrackable": TensorflowSavedModelModel, + "llm_saved_model": TensorflowLLMModel, + "keras": KerasModel, + "keras_qat": TensorflowQATModel, +} + + +class TensorflowModel(object): + """A wrapper to construct a Tensorflow Model.""" + + def __new__(cls, model_type, root, 
**kwargs): + """Create a new instance object of TensorflowModel. + + Args: + model_type (str): model type. + root (str): model path. + + Returns: + tensorflow model object: tensorflow model. + """ + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + model = TENSORFLOW_MODELS[model_type](root, **kwargs) + model.model_type = model_type + return model diff --git a/neural_compressor/tensorflow/utils/nets_factory.py b/neural_compressor/tensorflow/utils/nets_factory.py new file mode 100644 index 00000000000..d09ef4ba1d1 --- /dev/null +++ b/neural_compressor/tensorflow/utils/nets_factory.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2021 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""TF-Slim nets factory.""" + +from neural_compressor.tensorflow.utils.utility import singleton + + +@singleton +class TFSlimNetsFactory(object): + """TF-Slim nets factory.""" + + def __init__(self): + """Initialize a TFSlimNetsFactory.""" + # tf_slim only support specific models by default + self.default_slim_models = [ + "alexnet_v2", + "overfeat", + "vgg_a", + "vgg_16", + "vgg_19", + "inception_v1", + "inception_v2", + "inception_v3", + "resnet_v1_50", + "resnet_v1_101", + "resnet_v1_152", + "resnet_v1_200", + "resnet_v2_50", + "resnet_v2_101", + "resnet_v2_152", + "resnet_v2_200", + ] + + from tf_slim.nets import alexnet, inception, overfeat, resnet_v1, resnet_v2, vgg + + self.networks_map = { + "alexnet_v2": { + "model": alexnet.alexnet_v2, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": alexnet.alexnet_v2_arg_scope, + }, + "overfeat": { + "model": overfeat.overfeat, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": overfeat.overfeat_arg_scope, + }, + "vgg_a": { + "model": vgg.vgg_a, + "input_shape": [None, 224, 224, 3], + "num_classes": 1000, + "arg_scope": vgg.vgg_arg_scope, + }, + "vgg_16": { + "model": vgg.vgg_16, + "input_shape": [None, 224, 224, 3], + "num_classes": 1000, + "arg_scope": vgg.vgg_arg_scope, + }, + "vgg_19": { + "model": vgg.vgg_19, + "input_shape": [None, 224, 224, 3], + "num_classes": 1000, + "arg_scope": vgg.vgg_arg_scope, + }, + "inception_v1": { + "model": inception.inception_v1, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": inception.inception_v1_arg_scope, + }, + "inception_v2": { + "model": inception.inception_v2, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": inception.inception_v2_arg_scope, + }, + "inception_v3": { + "model": inception.inception_v3, + "input_shape": [None, 299, 299, 3], + "num_classes": 1001, + "arg_scope": inception.inception_v3_arg_scope, + }, + "resnet_v1_50": { + "model": resnet_v1.resnet_v1_50, + "input_shape": [None, 224, 224, 3], + "num_classes": 1000, + "arg_scope": resnet_v1.resnet_arg_scope, + }, + "resnet_v1_101": { + "model": resnet_v1.resnet_v1_101, + "input_shape": [None, 224, 224, 3], + "num_classes": 1000, + "arg_scope": 
resnet_v1.resnet_arg_scope, + }, + "resnet_v1_152": { + "model": resnet_v1.resnet_v1_152, + "input_shape": [None, 224, 224, 3], + "num_classes": 1000, + "arg_scope": resnet_v1.resnet_arg_scope, + }, + "resnet_v1_200": { + "model": resnet_v1.resnet_v1_200, + "input_shape": [None, 224, 224, 3], + "num_classes": 1000, + "arg_scope": resnet_v1.resnet_arg_scope, + }, + "resnet_v2_50": { + "model": resnet_v2.resnet_v2_50, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": resnet_v2.resnet_arg_scope, + }, + "resnet_v2_101": { + "model": resnet_v2.resnet_v2_101, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": resnet_v2.resnet_arg_scope, + }, + "resnet_v2_152": { + "model": resnet_v2.resnet_v2_152, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": resnet_v2.resnet_arg_scope, + }, + "resnet_v2_200": { + "model": resnet_v2.resnet_v2_200, + "input_shape": [None, 224, 224, 3], + "num_classes": 1001, + "arg_scope": resnet_v2.resnet_arg_scope, + }, + } + + def register(self, name, model_func, input_shape, arg_scope, **kwargs): + """Register a model to TFSlimNetsFactory. + + Args: + name (str): name of a model. + model_func (_type_): model that built from slim. + input_shape (_type_): input tensor shape. + arg_scope (_type_): slim arg scope that needed. + """ + net_info = {"model": model_func, "input_shape": input_shape, "arg_scope": arg_scope} + net = {name: {**net_info, **kwargs}} + self.networks_map.update(net) + self.default_slim_models.append(name) diff --git a/neural_compressor/tensorflow/utils/utility.py b/neural_compressor/tensorflow/utils/utility.py new file mode 100644 index 00000000000..ed1fc88aee8 --- /dev/null +++ b/neural_compressor/tensorflow/utils/utility.py @@ -0,0 +1,417 @@ +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
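+"""Utility functions for the TensorFlow backend of Neural Compressor."""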
+
+import importlib
+import logging
+import os
+import pickle
+import subprocess
+import sys
+import time
+from functools import reduce
+from typing import Callable, Dict
+
+import cpuinfo
+import numpy as np
+import prettytable as pt
+import psutil
+from pkg_resources import parse_version
+
+from neural_compressor.common import logger
+
+# Mapping between algorithm names and their implementations (functions).
+algos_mapping: Dict[str, Callable] = {}
+
+
+def version1_lt_version2(version1, version2):
+    """Check whether version1 is less than version2."""
+    return parse_version(version1) < parse_version(version2)
+
+
+def version1_gt_version2(version1, version2):
+    """Check whether version1 is greater than version2."""
+    return parse_version(version1) > parse_version(version2)
+
+
+def version1_eq_version2(version1, version2):
+    """Check whether version1 is equal to version2."""
+    return parse_version(version1) == parse_version(version2)
+
+
+def version1_gte_version2(version1, version2):
+    """Check whether version1 is greater than or equal to version2."""
+    return parse_version(version1) >= parse_version(version2)
+
+
+def version1_lte_version2(version1, version2):
+    """Check whether version1 is less than or equal to version2."""
+    return parse_version(version1) <= parse_version(version2)
+
+
+def register_algo(name):
+    """Decorator function to register algorithms in the algos_mapping dictionary.
+
+    Usage example:
+        @register_algo(name="example_algo")
+        def example_algo(model: tf.keras.Model, quant_config: StaticQuantConfig) -> tf.keras.Model:
+            ...
+
+    Args:
+        name (str): The name under which the algorithm function will be registered.
+
+    Returns:
+        decorator: The decorator function to be used with algorithm functions.
+    """
+
+    def decorator(algo_func):
+        algos_mapping[name] = algo_func
+        return algo_func
+
+    return decorator
+
+
+def deep_get(dictionary, keys, default=None):
+    """Get the item at a dot-separated key path in a nested dict.
+
+    For example, with person = {'person': {'name': {'first': 'John'}}},
+    deep_get(person, "person.name.first") returns 'John'.
+
+    Args:
+        dictionary (dict): The dict object to get keys from
+        keys (str): The dot-separated key path
+        default (object): The value returned if a key does not exist
+
+    Returns:
+        item: the item at the deep dot key path
+    """
+    return reduce(lambda d, key: d.get(key, default) if isinstance(d, dict) else default, keys.split("."), dictionary)
+
+
+def itex_installed():
+    """Check if the Intel® Extension for TensorFlow has been installed."""
+    try:
+        import intel_extension_for_tensorflow
+
+        return True
+    except:
+        return False
+
+
+def dump_elapsed_time(customized_msg=""):
+    """Get the elapsed time for decorated functions.
+
+    Args:
+        customized_msg (string, optional): The message prefix used in the log. Defaults to "".
+    """
+
+    def f(func):
+        def fi(*args, **kwargs):
+            start = time.time()
+            res = func(*args, **kwargs)
+            end = time.time()
+            logging.getLogger("neural_compressor").info(
+                "%s elapsed time: %s ms"
+                % (customized_msg if customized_msg else func.__qualname__, round((end - start) * 1000, 2))
+            )
+            return res
+
+        return fi
+
+    return f
+
+
+def combine_histogram(old_hist, arr):
+    """Collect layer histogram for arr and combine it with old histogram."""
+    new_max = np.max(arr)
+    new_min = np.min(arr)
+    new_th = max(abs(new_min), abs(new_max))
+    (old_hist, old_hist_edges, old_min, old_max, old_th) = old_hist
+    if new_th <= old_th:
+        hist, _ = np.histogram(arr, bins=len(old_hist), range=(-old_th, old_th))
+        return (old_hist + hist, old_hist_edges, min(old_min, new_min), max(old_max, new_max), old_th)
+    else:
+        old_num_bins = len(old_hist)
+        old_step = 2 * old_th / old_num_bins
+        half_increased_bins = int((new_th - old_th) // old_step + 1)
+        new_num_bins = half_increased_bins * 2 + old_num_bins
+        new_th = half_increased_bins * old_step + old_th
+        hist, hist_edges = np.histogram(arr, bins=new_num_bins, range=(-new_th, new_th))
+        hist[half_increased_bins : new_num_bins - half_increased_bins] += old_hist
+        return (hist, hist_edges, min(old_min, new_min), max(old_max, new_max), new_th)
+
+
+def get_all_fp32_data(data):
+    """Get all the fp32 data."""
+    return [float(i) for i in data.replace("[", " ").replace("]", " ").split(" ") if i.strip() and len(i) < 32]
+
+
+def get_tensor_histogram(tensor_data, bins=2048):
+    """Get the histogram of the tensor data."""
+    max_val = np.max(tensor_data)
+    min_val = np.min(tensor_data)
+    th = max(abs(min_val), abs(max_val))
+    hist, hist_edges = np.histogram(tensor_data, bins=bins, range=(-th, th))
+    return (hist, hist_edges, min_val, max_val, th)
+
+
+def Dequantize(data, scale_info):
+    """Dequantize the data with the scale_info."""
+    original_shape = data.shape
+    max_value = 255.0 if scale_info[0].find("Relu") != -1 else 127.0
+    _scale = (np.array(scale_info[2]) - np.array(scale_info[1])) / max_value
+    de_scale = np.ones(original_shape) * _scale
+    de_data = np.multiply(data, de_scale).astype(np.float32)
+    return de_data
+
+
+def dequantize_weight(weight_tensor, min_filter_tensor, max_filter_tensor):
+    """Dequantize the weight with min-max filter tensors."""
+    weight_channel = weight_tensor.shape[-1]
+    if len(min_filter_tensor) == 1:
+        weight_tensor = weight_tensor * ((max_filter_tensor[0] - min_filter_tensor[0]) / 127.0)
+    else:
+        # TODO: calculate the dequantized result in a parallel way
+        for i in range(weight_channel):
+            weight_tensor[:, :, :, i] = weight_tensor[:, :, :, i] * (
+                (max_filter_tensor[i] - min_filter_tensor[i]) / 127.0
+            )
+    return weight_tensor
+
+
+def dump_data_to_local(data, path, filename):
+    """Dump data to local as pkl file.
+
+    Args:
+        data: Data used to dump
+        path: The directory to save data
+        filename: The filename to dump
+    """
+    from pathlib import Path
+
+    if not os.path.exists(path):
+        Path(path).mkdir(parents=True, exist_ok=True)
+    file_path = os.path.join(path, filename)
+    with open(file_path, "wb") as fp:
+        pickle.dump(data, fp)
+    logging.getLogger("neural_compressor").info("Dumped data to %s" % file_path)
+
+
+def load_data_from_pkl(path, filename):
+    """Load data from local pkl file.
+
+    Args:
+        path: The directory to load data
+        filename: The filename to load
+    """
+    file_path = os.path.join(path, filename)
+    try:
+        with open(file_path, "rb") as fp:
+            data = pickle.load(fp)
+        return data
+    except FileNotFoundError:
+        logging.getLogger("neural_compressor").info("Can not open %s." % file_path)
+
+
+def singleton(cls):
+    """Not displayed in API Docs.
+
+    Singleton decorator.
+    """
+    instances = {}
+
+    def _singleton(*args, **kw):
+        """Create a singleton object."""
+        if cls not in instances:
+            instances[cls] = cls(*args, **kw)
+        return instances[cls]
+
+    return _singleton
+
+
+def disable_random(seed=1):
+    """A decorator that fixes the TensorFlow random seed so decorated functions run deterministically."""
+    import tensorflow as tf
+
+    def decorator(func):
+        def wrapper(*args, **kw):
+            tf.compat.v1.disable_eager_execution()
+            tf.compat.v1.reset_default_graph()
+            tf.compat.v1.set_random_seed(seed)
+            return func(*args, **kw)
+
+        return wrapper
+
+    return decorator
+
+
+@singleton
+class CpuInfo(object):
+    """Get CPU Info."""
+
+    def __init__(self):
+        """Detect whether the CPU supports bf16/vnni and get the number of sockets, cores and cores per socket."""
+        self._bf16 = False
+        self._vnni = False
+        info = cpuinfo.get_cpu_info()
+        if "arch" in info and "X86" in info["arch"]:
+            cpuid = cpuinfo.CPUID()
+            max_extension_support = cpuid.get_max_extension_support()
+            if max_extension_support >= 7:
+                ecx = cpuid._run_asm(
+                    b"\x31\xC9",  # xor ecx, ecx
+                    b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\x89\xC8" b"\xC3",  # mov eax, 7 # cpuid # mov eax, ecx # ret
+                )
+                self._vnni = bool(ecx & (1 << 11))
+                eax = cpuid._run_asm(
+                    b"\xB9\x01\x00\x00\x00",  # mov ecx, 1
+                    b"\xB8\x07\x00\x00\x00" b"\x0f\xa2" b"\xC3",  # mov eax, 7 # cpuid # ret
+                )
+                self._bf16 = bool(eax & (1 << 5))
+        if "arch" in info and "ARM" in info["arch"]:  # pragma: no cover
+            self._sockets = 1
+        else:
+            self._sockets = self.get_number_of_sockets()
+        self._cores = psutil.cpu_count(logical=False)
+        self._cores_per_socket = int(self._cores / self._sockets)
+
+    @property
+    def bf16(self):
+        """Get whether the CPU supports bf16."""
+        return self._bf16
+
+    @property
+    def vnni(self):
+        """Get whether the CPU supports vnni."""
+        return self._vnni
+
+    @property
+    def cores_per_socket(self):
+        """Get the number of cores per socket."""
+        return self._cores_per_socket
+
+    def get_number_of_sockets(self) -> int:
+        """Get number of sockets in platform."""
+        cmd = "cat /proc/cpuinfo | grep 'physical id' | sort -u | wc -l"
+        if psutil.WINDOWS:
+            cmd = r'wmic cpu get DeviceID | C:\Windows\System32\find.exe /C "CPU"'
+
+        with subprocess.Popen(
+            args=cmd,
+            shell=True,
+            stdout=subprocess.PIPE,
+            stderr=subprocess.STDOUT,
+            universal_newlines=False,
+        ) as proc:
+            proc.wait()
+            if proc.stdout:
+                for line in proc.stdout:
+                    return int(line.decode("utf-8", errors="ignore").strip())
+        return 0
+
+
+class Statistics:
+    """The statistics printer."""
+
+    def __init__(self, data, header, field_names, output_handle=logger.info):
+        """Init a Statistics object.
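+
+        An illustrative sketch (the header, field names, and rows are made up):
+            Statistics(
+                data=[["Conv2D", 10, 8], ["MatMul", 4, 4]],
+                header="Quantization Statistics",
+                field_names=["Op Type", "Total", "INT8"],
+            ).print_stat()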
+
+        Args:
+            data: The statistics data
+            header: The table header
+            field_names: The field names
+            output_handle: The output logging method
+        """
+        self.field_names = field_names
+        self.header = header
+        self.data = data
+        self.output_handle = output_handle
+        self.tb = pt.PrettyTable(min_table_width=40)
+
+    def print_stat(self):
+        """Print the statistics."""
+        valid_field_names = []
+        for index, value in enumerate(self.field_names):
+            if index < 2:
+                valid_field_names.append(value)
+                continue
+
+            if any(i[index] for i in self.data):
+                valid_field_names.append(value)
+        self.tb.field_names = valid_field_names
+        for i in self.data:
+            tmp_data = []
+            for index, value in enumerate(i):
+                if self.field_names[index] in valid_field_names:
+                    tmp_data.append(value)
+            if any(tmp_data[1:]):
+                self.tb.add_row(tmp_data)
+        lines = self.tb.get_string().split("\n")
+        self.output_handle("|" + self.header.center(len(lines[0]) - 2, "*") + "|")
+        for i in lines:
+            self.output_handle(i)
+
+
+class CaptureOutputToFile(object):
+    """Not displayed in API Docs.
+
+    Capture the output to file.
+    """
+
+    def __init__(self, tmp_file_path, stream=sys.stderr):
+        """Open the temporary file and record the original stream's file descriptor."""
+        self.orig_stream_fileno = stream.fileno()
+        self.tmp_file = open(tmp_file_path, "w")
+
+    def __enter__(self):
+        """Save a duplicate of the original stream descriptor, then redirect the stream to the file."""
+        self.orig_stream_dup = os.dup(self.orig_stream_fileno)
+        os.dup2(self.tmp_file.fileno(), self.orig_stream_fileno)
+
+    def __exit__(self, exc_type, exc_value, exc_traceback):
+        """Restore the original stream descriptor and close the file."""
+        os.close(self.orig_stream_fileno)
+        os.dup2(self.orig_stream_dup, self.orig_stream_fileno)
+        os.close(self.orig_stream_dup)
+        self.tmp_file.close()
+
+
+class LazyImport(object):
+    """Lazily import a python module until it is actually used."""
+
+    def __init__(self, module_name):
+        """Init LazyImport object.
+
+        Args:
+            module_name (string): The name of the module imported later
+        """
+        self.module_name = module_name
+        self.module = None
+
+    def __getattr__(self, name):
+        """Get the attributes of the module by name."""
+        try:
+            self.module = importlib.import_module(self.module_name)
+            mod = getattr(self.module, name)
+        except (ImportError, AttributeError):
+            spec = importlib.util.find_spec(self.module_name + "." + name)
+            mod = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(mod)
+        return mod
+
+    def __call__(self, *args, **kwargs):
+        """Call the function in that module."""
+        function_name = self.module_name.split(".")[-1]
+        module_name = self.module_name.split(f".{function_name}")[0]
+        self.module = importlib.import_module(module_name)
+        function = getattr(self.module, function_name)
+        return function(*args, **kwargs)
diff --git a/requirements_tf.txt b/requirements_tf.txt
index ed3e55f62f0..5fbd34f6ae7 100644
--- a/requirements_tf.txt
+++ b/requirements_tf.txt
@@ -1,3 +1,5 @@
-intel-extension-for-tensorflow[cpu]
+prettytable
+psutil
+py-cpuinfo
 pyyaml
 tensorflow
diff --git a/test/3x/tensorflow/keras/requirements.txt b/test/3x/tensorflow/keras/requirements.txt
new file mode 100644
index 00000000000..2b9ebaa4d0e
--- /dev/null
+++ b/test/3x/tensorflow/keras/requirements.txt
@@ -0,0 +1 @@
+intel-extension-for-tensorflow[cpu]
diff --git a/test/3x/tensorflow/test_config.py b/test/3x/tensorflow/keras/test_config.py
similarity index 98%
rename from test/3x/tensorflow/test_config.py
rename to test/3x/tensorflow/keras/test_config.py
index 6a7bd7afeab..8d7a0dcc340 100644
--- a/test/3x/tensorflow/test_config.py
+++ b/test/3x/tensorflow/keras/test_config.py
@@ -318,10 +318,9 @@ def test_expand_config(self):
     def test_config_set_api(self):
         # *Note: this test is only for improving the code coverage and can be removed once the test_common is enabled.
         from neural_compressor.common.base_config import config_registry, get_all_config_set_from_config_registry
-        from neural_compressor.tensorflow.quantization.config import FRAMEWORK_NAME
 
-        config_set = get_all_config_set_from_config_registry(fwk_name=FRAMEWORK_NAME)
-        self.assertEqual(len(config_set), len(config_registry.registered_configs[FRAMEWORK_NAME]))
+        config_set = get_all_config_set_from_config_registry(fwk_name="keras")
+        self.assertEqual(len(config_set), len(config_registry.registered_configs["keras"]))
 
 
 if __name__ == "__main__":
diff --git a/test/3x/tensorflow/quantization/smooth_quant/test_smooth_quant.py b/test/3x/tensorflow/quantization/smooth_quant/test_smooth_quant.py
new file mode 100644
index 00000000000..19e06c63e59
--- /dev/null
+++ b/test/3x/tensorflow/quantization/smooth_quant/test_smooth_quant.py
@@ -0,0 +1,196 @@
+import math
+import unittest
+
+import numpy as np
+import tensorflow as tf
+from tensorflow.compat.v1 import graph_util
+
+from neural_compressor.common import set_random_seed
+from neural_compressor.tensorflow import SmoothQuantConfig, get_default_sq_config, quantize_model
+from neural_compressor.tensorflow.utils import DummyDataset, disable_random
+
+
+def build_conv_graph():
+    tf.compat.v1.disable_eager_execution()
+    x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input")
+    top_relu = tf.nn.relu(x)
+    paddings = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]])
+    x_pad = tf.pad(top_relu, paddings, "CONSTANT")
+    conv_weights = tf.compat.v1.get_variable(
+        "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer()
+    )
+    conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID")
+    normed = tf.compat.v1.layers.batch_normalization(conv)
+
+    conv_weights2 = tf.compat.v1.get_variable(
+        "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer()
+    )
+    conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME")
+    normed2 = tf.compat.v1.layers.batch_normalization(conv2)
+    add = tf.raw_ops.Add(x=normed, y=normed2, name="addv2")
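+    # The two batch-normalized convolution branches merge at this Add node; the
+    # tests below check SmoothQuant by counting the scaling Mul nodes it folds
+    # in front of each branch.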
+    relu = tf.nn.relu(add)
+    relu6 = tf.nn.relu6(relu, name="op_to_store")
+
+    out_name = relu6.name.split(":")[0]
+    with tf.compat.v1.Session() as sess:
+        sess.run(tf.compat.v1.global_variables_initializer())
+        output_graph_def = graph_util.convert_variables_to_constants(
+            sess=sess, input_graph_def=sess.graph_def, output_node_names=[out_name]
+        )
+    return output_graph_def
+
+
+class MyDataLoader:
+    def __init__(self, dataset, batch_size=1):
+        self.dataset = dataset
+        self.batch_size = batch_size
+        self.length = math.ceil(len(dataset) / self.batch_size)
+
+    def __iter__(self):
+        # Note: a trailing batch smaller than batch_size is dropped.
+        images_list = []
+        labels_list = []
+        for images, labels in self.dataset:
+            images = np.expand_dims(images, axis=0)
+            labels = np.expand_dims(labels, axis=0)
+            images_list.append(images[0])
+            labels_list.append(labels[0])
+            if self.batch_size == len(images_list):
+                yield (images_list, labels_list)
+                images_list = []
+                labels_list = []
+
+    def __len__(self):
+        return self.length
+
+
+class TestSmoothQuantTF3xNewApi(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        cls.conv_graph = build_conv_graph()
+
+    @classmethod
+    def tearDownClass(cls):
+        pass
+
+    def test_conv(self):
+        set_random_seed(9527)
+        quant_config = SmoothQuantConfig(alpha=0.5)
+        dataset = DummyDataset(shape=(100, 56, 56, 16), label=True)
+        calib_dataloader = MyDataLoader(dataset=dataset, batch_size=1)
+        q_model = quantize_model(self.conv_graph, quant_config, calib_dataloader, calib_iteration=500)
+
+        mul_count = 0
+        for i in q_model.graph_def.node:
+            if i.op == "Mul":
+                mul_count += 1
+
+        self.assertEqual(mul_count, 2)
+
+    def test_sq_from_class_beginner(self):
+        set_random_seed(9527)
+        quant_config = get_default_sq_config()
+        dataset = DummyDataset(shape=(100, 56, 56, 16), label=True)
+        calib_dataloader = MyDataLoader(dataset=dataset, batch_size=1)
+        q_model = quantize_model(self.conv_graph, quant_config, calib_dataloader, calib_iteration=500)
+
+        mul_count = 0
+        for i in q_model.graph_def.node:
+            if i.op == "Mul":
+                mul_count += 1
+
+        self.assertEqual(mul_count, 2)
+
+    def test_sq_from_dict_beginner(self):
+        quant_config = {
+            "smooth_quant": {
+                "global": {
+                    "alpha": 0.5,
+                },
+            }
+        }
+        dataset = DummyDataset(shape=(100, 56, 56, 16), label=True)
+        calib_dataloader = MyDataLoader(dataset=dataset, batch_size=1)
+        q_model = quantize_model(self.conv_graph, quant_config, calib_dataloader, calib_iteration=500)
+
+        mul_count = 0
+        for i in q_model.graph_def.node:
+            if i.op == "Mul":
+                mul_count += 1
+
+        self.assertEqual(mul_count, 2)
+
+    @disable_random()
+    def test_matmul(self):
+        x_data = np.random.rand(1024, 1024).astype(np.float32)
+        y_data = np.random.rand(1024, 1024).astype(np.float32)
+        import tensorflow.compat.v1 as tf
+
+        x = tf.placeholder(tf.float32, shape=[1024, 1024], name="x")
+        y = tf.constant(y_data, dtype=tf.float32, shape=[1024, 1024])
+        z = tf.matmul(x, y)
+        bias = np.random.rand(1024).astype(np.float32)
+        z = tf.nn.bias_add(z, bias)
+        z = tf.nn.relu(z, name="op_to_store")
+
+        with tf.Session() as sess:
+            sess.run(z, feed_dict={x: x_data, y: y_data})
+            output_graph_def = sess.graph.as_graph_def()
+
+        set_random_seed(9527)
+        quant_config = SmoothQuantConfig(alpha=0.5)
+        dataset = DummyDataset(shape=(1024, 1024), label=True)
+        calib_dataloader = MyDataLoader(dataset=dataset, batch_size=1024)
+        q_model = quantize_model(output_graph_def, quant_config, calib_dataloader, calib_iteration=1)
+
+        mul_count = 0
+        for i in q_model.graph_def.node:
+            if i.op == "Mul":
+                mul_count += 1
+
+        self.assertEqual(mul_count, 1)
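+
+    # A note on the expected Mul counts: SmoothQuant folds one scaling Mul in
+    # front of each smoothed Conv2D/MatMul input, so the two-conv graph yields
+    # 2, the single MatMul yields 1, and the conv+matmul+conv graph below
+    # yields 3.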
+ @disable_random() + def test_conv_matmul(self): + x = tf.compat.v1.placeholder(tf.float32, [1, 56, 56, 16], name="input") + top_relu = tf.nn.relu(x) + paddings = tf.constant([[0, 0], [1, 1], [1, 1], [0, 0]]) + x_pad = tf.pad(top_relu, paddings, "CONSTANT") + conv1_weights = tf.compat.v1.get_variable( + "weight_conv1", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() + ) + conv1 = tf.nn.conv2d(x_pad, conv1_weights, strides=[1, 2, 2, 1], padding="VALID") + matmul_weights = tf.compat.v1.get_variable( + "weight_matmul", [28 * 28 * 16, 7 * 7 * 32], initializer=tf.compat.v1.random_normal_initializer() + ) + conv1_reshaped = tf.reshape(conv1, shape=[-1, 28 * 28 * 16]) + matmul = tf.matmul(conv1_reshaped, matmul_weights) + reshape = tf.reshape(matmul, (1, 7, 7, 32)) + conv2_weights = tf.compat.v1.get_variable( + "weight_conv2", [7, 7, 32, 1], initializer=tf.compat.v1.random_normal_initializer() + ) + conv2 = tf.nn.conv2d(reshape, conv2_weights, strides=[1, 2, 2, 1], padding="VALID") + leaky_relu = tf.nn.leaky_relu(conv2, name="op_to_store") + + out_name = leaky_relu.name.split(":")[0] + with tf.compat.v1.Session() as sess: + sess.run(tf.compat.v1.global_variables_initializer()) + output_graph_def = graph_util.convert_variables_to_constants( + sess=sess, input_graph_def=sess.graph_def, output_node_names=[out_name] + ) + + set_random_seed(9527) + quant_config = SmoothQuantConfig(alpha=0.6) + dataset = DummyDataset(shape=(100, 56, 56, 16), label=True) + calib_dataloader = MyDataLoader(dataset=dataset, batch_size=1) + q_model = quantize_model(output_graph_def, quant_config, calib_dataloader, calib_iteration=500) + + mul_count = 0 + for i in q_model.graph_def.node: + if i.op == "Mul": + mul_count += 1 + + self.assertEqual(mul_count, 3) + + +if __name__ == "__main__": + unittest.main()