From aafb6280768719390d62bb9c3ed99ee6683c762f Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Tue, 5 Mar 2024 11:52:31 +0800 Subject: [PATCH 01/13] refacor keras Signed-off-by: zehao-intel --- .../algorithms/static_quant/keras.py | 350 ++++++++---------- 1 file changed, 162 insertions(+), 188 deletions(-) diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index 442ba1b3d48..a963127bdc5 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -16,8 +16,8 @@ # limitations under the License. import copy -import json import math +import re import os from collections import OrderedDict, UserDict from typing import Callable, Dict @@ -28,6 +28,8 @@ import yaml from neural_compressor.common import logger +from neural_compressor.common.utils import DEFAULT_WORKSPACE + from neural_compressor.tensorflow.keras.layers import ( DeQuantize, FakeQuant, @@ -91,6 +93,10 @@ def __init__(self, framework_specific_info): self.callbacks = [] self.conv_format = {} + self.layer_name_mapping = {} + self.input_layer_dict = {} + self.custom_layers = _add_supported_quantized_objects({}) + self.tmp_dir = DEFAULT_WORKSPACE + "/tmp_saved_model_dir" def _check_itex(self): """Check if the Intel® Extension for TensorFlow has been installed.""" @@ -154,32 +160,22 @@ def _pre_optimize(self, model): def _check_quantize_format(self, model): """The function that checks format for conv ops.""" - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - fp32_layers = config["layers"] - name_op_map = {} - - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - name_op_map[layer["config"]["name"]] = layer - - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - layer_config = layer["config"] - if layer["class_name"] in self.supported_op: - if "inbound_nodes" in layer: - check_layer = name_op_map[layer["inbound_nodes"][0][0][0]] - else: - check_layer = fp32_layers[idx - 1] - if check_layer["class_name"] in ["Activation"] and check_layer["config"]["activation"] in ["relu"]: - self.conv_format[layer["config"]["name"]] = "u8" - else: - self.conv_format[layer["config"]["name"]] = "s8" + for layer in model.layers: + if layer.__class__.__name__ in self.supported_op: + self.conv_format[layer.name] = "s8" + input_layer_names = self.input_layer_dict[layer.name] + for input_layer_name in input_layer_names: + check_layer = self.layer_name_mapping[input_layer_name] + if check_layer.__class__.__name__ == "Activation" \ + and check_layer.activation.__name__ in ["relu"]: + self.conv_format[layer.name] = "u8" + break + return model def _fuse_bn(self, model): """Fusing Batch Normalization.""" - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - fp32_layers = config["layers"] + fp32_layers = model.layers def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): assert conv_type in [ @@ -225,60 +221,56 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): bias = bias.reshape(-1) return [depth_weight, weight, bias] if conv_type == "SeparableConv2D" else [weight, bias] - node_map = {} - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - layer_config = layer["config"] - if "inbound_nodes" in layer: - node_map[layer["name"]] = layer - fuse_layers = [] fold_conv = [] - for idx, layer in enumerate(copy.deepcopy(fp32_layers)): - layer_config = layer["config"] - if "inbound_nodes" in layer: - if layer["class_name"] in ["BatchNormalization"]: - bn_inbound_node = node_map[layer_config["name"]]["inbound_nodes"][0][0] - if bn_inbound_node[0] in self.conv_weights.keys(): - conv_weight = self.conv_weights[bn_inbound_node[0]] - conv_layer = node_map[bn_inbound_node[0]] - bn_weight = self.bn_weights[layer_config["name"]] - self.layer_weights[bn_inbound_node[0]] = fuse_conv_bn( - conv_weight, bn_weight, conv_layer["class_name"], layer["config"]["epsilon"] + for idx, layer in enumerate(fp32_layers): + if hasattr(layer, "inbound_nodes"): + if layer.__class__.__name__ in ("BatchNormalization"): + bn_inbound_node = layer.inbound_nodes[0] + inbound_layer = bn_inbound_node.inbound_layers + if inbound_layer.name in self.conv_weights.keys(): + conv_layer = inbound_layer + conv_weight = self.conv_weights[conv_layer.name] + bn_weight = self.bn_weights[layer.name] + self.layer_weights[conv_layer.name] = fuse_conv_bn( + conv_weight, bn_weight, conv_layer.__class__.__name__, layer.epsilon ) - fold_conv.append(bn_inbound_node[0]) + fold_conv.append(conv_layer.name) else: fuse_layers.append(layer) - elif len(layer["inbound_nodes"]): + elif len(layer.inbound_nodes): new_bound_nodes = [] # OpLambda node will have different bound node - if layer["class_name"] in ["TFOpLambda", "SlicingOpLambda"]: + if layer.__class__.__name__ in ("TFOpLambda", "SlicingOpLambda"): fuse_layers.append(layer) else: - for bound_node in layer["inbound_nodes"][0]: - if bound_node[0] in self.bn_weights.keys(): - bn_inbound_node = node_map[bound_node[0]]["inbound_nodes"][0][0] - if bn_inbound_node[0] in self.conv_weights.keys(): + for bound_node in layer.inbound_nodes[0]: + inbound_layer = bound_node.inbound_layers + if inbound_layer in self.bn_weights.keys(): + bn_inbound_node = inbound_layer.inbound_nodes[0] + bn_inbound_layer = bn_inbound_node.inbound_layers + if bn_inbound_layer.name in self.conv_weights.keys(): new_bound_nodes.append(bn_inbound_node) else: new_bound_nodes.append(bound_node) else: new_bound_nodes.append(bound_node) - layer["inbound_nodes"] = [new_bound_nodes] + layer.inbound_nodes = new_bound_nodes fuse_layers.append(layer) else: fuse_layers.append(layer) else: if ( idx > 0 - and layer["class_name"] in ["BatchNormalization"] - and fp32_layers[idx - 1]["class_name"] in ["Conv2D"] + and layer.__class__.__name__ == "BatchNormalization" + and fp32_layers[idx - 1].__class__.__name__ == "Conv2D" ): - conv_name = fp32_layers[idx - 1]["config"]["name"] + conv_name = fp32_layers[idx - 1].name conv_weight = self.conv_weights[conv_name] - bn_weight = self.bn_weights[layer_config["name"]] - conv_type = fp32_layers[idx - 1]["class_name"] + bn_weight = self.bn_weights[layer.name] + conv_type = fp32_layers[idx - 1].__class__.__name__ self.layer_weights[conv_name] = fuse_conv_bn( - conv_weight, bn_weight, conv_type, layer["config"]["epsilon"] + conv_weight, bn_weight, conv_type, layer.epsilon ) fold_conv.append(conv_name) else: @@ -286,15 +278,13 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): # bn folding will have a shift bias for idx, layer in enumerate(fuse_layers): - layer_config = layer["config"] if ( - layer["class_name"] in ["Conv2D", "DepthwiseConv2D", "SeparableConv2D"] - and layer_config["name"] in fold_conv + layer.__class__.__name__ in ("Conv2D", "DepthwiseConv2D", "SeparableConv2D") + and layer.name in fold_conv ): - layer_config["use_bias"] = True + layer.use_bias = True - json_model["config"]["layers"] = fuse_layers - fused_model = self._restore_model_from_json(json_model) + fused_model = self._rebuild_model_from_layers(model, fuse_layers) return fused_model @dump_elapsed_time("Pass quantize model") @@ -318,8 +308,8 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): converted_model = self.convert_bf16() return converted_model - # if self.backend == "itex": - # self._check_itex() + if self.backend == "itex": + self._check_itex() logger.debug("Dump quantization configurations:") logger.debug(self.quantize_config) calib_sampling_size = tune_cfg.get("calib_sampling_size", 1) @@ -337,23 +327,20 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): q_layers = [] self.inbound_nodes_map = {} for idx, layer in enumerate(copy.deepcopy(self.fp32_layers)): - layer_config = layer["config"] if ( - layer["class_name"] in self.supported_op - and layer["config"]["name"] in self.quantize_config["op_wise_config"] + layer.__class__.__name__ in self.supported_op + and layer.name in self.quantize_config["op_wise_config"] ): - op_config = self.quantize_config["op_wise_config"][layer["config"]["name"]] + op_config = self.quantize_config["op_wise_config"][layer.name] mode = "per_channel" if op_config[0] else "per_tensor" fake_q_name = "fake_quant_" + str(idx) - fake_q_layer = { - "class_name": "FakeQuant", - "name": fake_q_name, - "T": self.conv_format[layer["config"]["name"]], - "config": {"mode": "per_tensor", "name": fake_q_name}, - } - if "inbound_nodes" in layer: - fake_q_layer["inbound_nodes"] = layer["inbound_nodes"] - layer["inbound_nodes"] = [[[fake_q_name, 0, 0, {}]]] + fake_q_layer = FakeQuant(name=fake_q_name, + T=self.conv_format[layer.name], + mode="per_tensor" + ) + if hasattr(layer, "inbound_nodes"): + fake_q_layer.inbound_nodes[0].inbound_layers = layer.inbound_nodes[0].inbound_layers + layer.inbound_nodes[0].inbound_layers = fake_q_layer self.inbound_nodes_map[fake_q_name] = layer q_layers.append(fake_q_layer) @@ -361,10 +348,7 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): else: q_layers.append(layer) - json_model = copy.deepcopy(json.loads(self.pre_optimized_object.to_json())) - json_model["config"]["layers"] = q_layers - quantized_model = self._restore_model_from_json(json_model) - + quantized_model = self._rebuild_model_from_layers(self.pre_optimized_object, q_layers) converted_model = self._calibrate(quantized_model, dataloader, self.quantize_config["calib_iteration"]) return converted_model @@ -376,141 +360,133 @@ def _calibrate(self, model, dataloader, calib_interation): results = {} for idx, (inputs, labels) in enumerate(dataloader): outputs = model.predict_on_batch(inputs) - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - layers = config["layers"] - for layer in layers: - if layer["class_name"] == "FakeQuant": - min_value = layer["config"]["min_value"] - max_value = layer["config"]["max_value"] - if layer["config"]["name"] not in results: - results[layer["config"]["name"]] = {"min": [min_value], "max": [max_value]} + for layer in model.layers: + if layer.__class__.__name__ == "FakeQuant": + min_value = layer.min_value + max_value = layer.max_value + if layer.name not in results: + results[layer.name] = {"min": [min_value], "max": [max_value]} else: - results[layer["config"]["name"]]["min"].append(min_value) - results[layer["config"]["name"]]["max"].append(max_value) + results[layer.name]["min"].append(min_value) + results[layer.name]["max"].append(max_value) if idx + 1 == calib_interation: break - # insert the calibrated min/max to Q/DQ - json_model = copy.deepcopy(json.loads(model.to_json())) - config = json_model["config"] - layers = config["layers"] q_layers = [] - # quantize_mode = self._check_quantize_mode(json_model) inbound_reverse_map = {} - for idx, layer in enumerate(layers): - layer_config = copy.deepcopy(layer["config"]) - if layer["class_name"] == "FakeQuant": - min_value = min(results[layer["config"]["name"]]["min"]) - max_value = max(results[layer["config"]["name"]]["max"]) - quantize_layer = { - "class_name": "Quantize", - "name": "quantize_" + str(idx), - "config": { - "min_range": min_value, - "max_range": max_value, - "T": layer_config["T"], - "name": "quantize_" + str(idx), - }, - } - dequantize_layer = { - "class_name": "DeQuantize", - "name": "dequantize_" + str(idx), - "config": { - "min_range": min_value, - "max_range": max_value, - # 'mode': quantize_mode, - "name": "dequantize_" + str(idx), - }, - } - if "inbound_nodes" in layer: - quantize_layer["inbound_nodes"] = layer["inbound_nodes"] - dequantize_layer["inbound_nodes"] = [[["quantize_" + str(idx), 0, 0, {}]]] + for idx, layer in enumerate(model.layers): + if layer.__class__.__name__ == "FakeQuant": + min_value = min(results[layer.name]["min"]) + max_value = max(results[layer.name]["max"]) + quantize_layer = Quantize( + name = "quantize_" + str(idx), + min_range = min_value, + max_range = max_value, + T = layer.T, + ) + dequantize_layer = DeQuantize( + name = "dequantize_" + str(idx), + min_range = min_value, + max_range = max_value, + ) + + if hasattr(layer, "inbound_nodes"): + quantize_layer.inbound_nodes = layer.inbound_nodes + dequantize_layer.inbound_nodes[0].inbound_layers = quantize_layer # find the conv/dense layer from fake quant map and # change the conv/dense node inbound to dequantize - layer_name = self.inbound_nodes_map[layer["name"]]["name"] - inbound_reverse_map[layer_name] = [[["dequantize_" + str(idx), 0, 0, {}]]] + layer_name = self.inbound_nodes_map[layer.name].name + inbound_reverse_map[layer_name] = dequantize_layer q_layers.append(quantize_layer) q_layers.append(dequantize_layer) elif ( - layer["class_name"] in self.supported_op - and layer["config"]["name"] in self.quantize_config["op_wise_config"] + layer.__class__.__name__ in self.supported_op + and layer.name in self.quantize_config["op_wise_config"] ): # index 0 is weight, index 1 is bias - q_layer_name = "Q" + layer["class_name"] + q_layer_class = "Q" + layer.__class__.__name__ # this is for inbounds search - q_name = layer["config"]["name"] + q_name = layer.name # for layers that have weights - if layer["config"]["name"] in self.layer_weights: - kernel = self.layer_weights[layer["config"]["name"]][0] + if layer.name in self.layer_weights: + kernel = self.layer_weights[layer.name][0] dim = list(range(0, kernel.ndim)) t_dim = [dim.pop(-1)] t_dim.extend(dim) channel_size = kernel.shape[-1] kernel_channel = kernel.transpose(t_dim).reshape(channel_size, -1) - layer_config["min_value"] = json.dumps(np.min(kernel_channel, axis=1).tolist()) - layer_config["max_value"] = json.dumps(np.max(kernel_channel, axis=1).tolist()) + layer.min_value = np.min(kernel_channel, axis=1).tolist() + layer.max_value = np.max(kernel_channel, axis=1).tolist() else: # default value, but never expected to be used # cause no kernel weights for this layer - layer_config["min_value"] = json.dumps([-10000]) - layer_config["max_value"] = json.dumps([10000]) + layer.min_value = [-10000] + layer.max_value = [10000] + layer.name = q_name + layer_config = layer.get_config() layer_config["name"] = q_name - q_layer = {"class_name": q_layer_name, "name": q_name, "config": layer_config} - if "inbound_nodes" in layer: - q_layer["inbound_nodes"] = inbound_reverse_map[layer["name"]] + q_layer = self.custom_layers[q_layer_class].from_config(layer_config) + if hasattr(layer, "inbound_nodes"): + q_layer.inbound_nodes = inbound_reverse_map[layer.name] q_layers.append(q_layer) else: q_layers.append(layer) - json_model["config"]["layers"] = q_layers - quantized_model = self._restore_model_from_json(json_model) + quantized_model = self._rebuild_model_from_layers(model ,q_layers) return quantized_model def convert_bf16(self): """Execute the BF16 conversion.""" tf.keras.mixed_precision.set_global_policy("mixed_bfloat16") - json_model = copy.deepcopy(json.loads(self.pre_optimized_object.to_json())) + model = self.pre_optimized_object - for layer in json_model["config"]["layers"]: - if layer["config"]["name"] in self.bf16_ops: - layer["config"]["dtype"] = "mixed_bfloat16" + for layer in model.layers: + if layer.name in self.bf16_ops: + layer.dtype = "mixed_bfloat16" - converted_model = self._restore_model_from_json(json_model) + model.save(self.tmp_dir) + converted_model = tf.keras.saved_model.load(self.tmp_dir) tf.keras.mixed_precision.set_global_policy("float32") return converted_model # (TODO) choose the properly quantize mode - def _check_quantize_mode(self, json_model): + def _check_quantize_mode(self, model): """Check what quantize mode to use.""" - config = json_model["config"] - layers = config["layers"] - for idx, layer in enumerate(layers): - if "ReLU" in layer["class_name"]: + for idx, layer in enumerate(model.layers): + if "ReLU" in layer.__class__.__name__: return "MIN_FIRST" return "SCALED" - def _restore_model_from_json(self, json_model): - """Generate a keras model from json files.""" - from tensorflow.keras.models import model_from_json - - from neural_compressor.tensorflow.utils import version1_gte_version2 - - if version1_gte_version2(keras.__version__, "2.13.1"): - from keras.src.saving import serialization_lib - - serialization_lib.enable_unsafe_deserialization() - - custom_objects = {} - # We need to keep a dictionary of custom objects as our quantized library - # is not recognized by keras. - custom_objects = _add_supported_quantized_objects(custom_objects) - json_model_file = json.dumps(json_model) - qmodel = model_from_json(json_model_file, custom_objects=custom_objects) - qmodel = self._set_weights(qmodel, self.layer_weights) - return qmodel + def _rebuild_model_from_layers(self, model, layers): + """Insert a layer before the target layer.""" + model_outputs = [] + input_layer_dict = {} + output_tensor_dict = {layers[0].name: model.input} + + for layer in layers: + for node in layer._outbound_nodes: + layer_name = node.outbound_layer.name + if layer_name not in input_layer_dict: + input_layer_dict[layer_name] = [layer.name] + else: + input_layer_dict[layer_name].append(layer.name) + + for layer in layers[1:]: + input_tensors = [output_tensor_dict[input_layer] + for input_layer in input_layer_dict[layer.name]] + if len(input_tensors) == 1: + input_tensors = input_tensors[0] + x = layer(input_tensors) + output_tensor_dict[layer.name] = x + if layer_name in model.output_names: + model_outputs.append(x) + + new_model = tf.keras.models.Model(inputs=model.inputs, outputs=model_outputs) + new_model = self._set_weights(new_model, self.layer_weights) + new_model.save(self.tmp_dir) + return tf.keras.saved_model.load(self.tmp_dir) # set fp32 weights to qmodel def _set_weights(self, qmodel, layer_weights): @@ -622,28 +598,26 @@ def query_fw_capability(self, model): self.layer_weights[layer.name] = copy.deepcopy(layer.get_weights()) self.pre_optimized_object = self._pre_optimize(keras_object) - json_model = copy.deepcopy(json.loads(self.pre_optimized_object.to_json())) - config = json_model["config"] - self.fp32_layers = config["layers"] - - quantizable_op_details = OrderedDict() - for details in self.fp32_layers: - node_op = details["class_name"] - node_name = details["config"]["name"] - if node_op == "Conv2D": - quantizable_op_details[(node_name, node_op)] = [conv_config, bf16_config, fp32_config] - elif node_op == "Dense": - quantizable_op_details[(node_name, node_op)] = [dense_config, bf16_config, fp32_config] - elif node_op in {"AveragePooling2D", "AvgPool2D"}: - quantizable_op_details[(node_name, node_op)] = [avgpool_config, bf16_config, fp32_config] - elif node_op in {"MaxPooling2D", "MaxPool2D"}: - quantizable_op_details[(node_name, node_op)] = [maxpool_config, bf16_config, fp32_config] + + self.fp32_layers = keras_object.layers + + quantizable_layer_details = OrderedDict() + for layer in self.fp32_layers: + layer_class = layer.__class__.__name__ + if layer_class == "Conv2D": + quantizable_layer_details[(layer.name, layer_class)] = [conv_config, bf16_config, fp32_config] + elif layer_class == "Dense": + quantizable_layer_details[(layer.name, layer_class)] = [dense_config, bf16_config, fp32_config] + elif layer_class in {"AveragePooling2D", "AvgPool2D"}: + quantizable_layer_details[(layer.name, layer_class)] = [avgpool_config, bf16_config, fp32_config] + elif layer_class in {"MaxPooling2D", "MaxPool2D"}: + quantizable_layer_details[(layer.name, layer_class)] = [maxpool_config, bf16_config, fp32_config] else: - quantizable_op_details[(node_name, node_op)] = [bf16_config, fp32_config] + quantizable_layer_details[(layer.name, layer_class)] = [bf16_config, fp32_config] capability = { - "opwise": copy.deepcopy(quantizable_op_details), - "optypewise": self.get_optype_wise_ability(quantizable_op_details), + "opwise": copy.deepcopy(quantizable_layer_details), + "optypewise": self.get_optype_wise_ability(quantizable_layer_details), } logger.debug("Dump framework quantization capability:") logger.debug(capability) From dd4b1b9fb300bc567774f08848954e3d8c58f50e Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Tue, 26 Mar 2024 15:52:58 +0800 Subject: [PATCH 02/13] support tf2.16.1 Signed-off-by: zehao-intel --- .../algorithms/static_quant/keras.py | 557 +++++++++++------- .../tensorflow/keras/layers/__init__.py | 3 +- .../tensorflow/keras/layers/conv2d.py | 89 ++- .../tensorflow/keras/layers/dense.py | 70 ++- .../keras/layers/depthwise_conv2d.py | 374 ++++++++---- .../keras/layers/layer_initializer.py | 33 ++ .../tensorflow/keras/layers/pool2d.py | 89 ++- .../tensorflow/keras/layers/quantizer.py | 19 +- .../keras/layers/separable_conv2d.py | 371 +++++++++--- 9 files changed, 1160 insertions(+), 445 deletions(-) create mode 100644 neural_compressor/tensorflow/keras/layers/layer_initializer.py diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index a963127bdc5..1a76023f6f7 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2023 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -16,13 +16,11 @@ # limitations under the License. import copy -import math -import re +import json import os from collections import OrderedDict, UserDict from typing import Callable, Dict -import keras import numpy as np import tensorflow as tf import yaml @@ -42,36 +40,13 @@ Quantize, ) from neural_compressor.tensorflow.quantization.config import StaticQuantConfig -from neural_compressor.tensorflow.utils import deep_get, dump_elapsed_time - - -def _add_supported_quantized_objects(custom_objects): - """Map all the quantized objects.""" - custom_objects["Quantize"] = Quantize - custom_objects["DeQuantize"] = DeQuantize - custom_objects["FakeQuant"] = FakeQuant - custom_objects["QConv2D"] = QConv2D - custom_objects["QDepthwiseConv2D"] = QDepthwiseConv2D - custom_objects["QSeparableConv2D"] = QSeparableConv2D - custom_objects["QDense"] = QDense - custom_objects["QMaxPool2D"] = QMaxPool2D - custom_objects["QAvgPool2D"] = QAvgPool2D - custom_objects["QMaxPooling2D"] = QMaxPool2D - custom_objects["QAveragePooling2D"] = QAvgPool2D - return custom_objects +from neural_compressor.tensorflow.utils import deep_get, dump_elapsed_time, version1_gte_version2 class KerasAdaptor: """The keras class of framework adaptor layer.""" - def __init__(self, framework_specific_info): - self.framework_specific_info = framework_specific_info - self.approach = deep_get(self.framework_specific_info, "approach", False) - self.quantize_config = {"op_wise_config": {}} - self.device = self.framework_specific_info["device"] - self.backend = self.framework_specific_info["backend"] - self.recipes = deep_get(self.framework_specific_info, "recipes", {}) - self.supported_op = [ + supported_op = [ "Conv2D", "Dense", "SeparableConv2D", @@ -81,8 +56,31 @@ def __init__(self, framework_specific_info): "AvgPool2D", "MaxPool2D", ] + + custom_layers = { + "Quantize": Quantize, + "DeQuantize": DeQuantize, + "FakeQuant": FakeQuant, + "QConv2D": QConv2D, + "QDepthwiseConv2D": QDepthwiseConv2D, + "QSeparableConv2D": QSeparableConv2D, + "QDense": QDense, + "QMaxPool2D": QMaxPool2D, + "QAvgPool2D": QAvgPool2D, + "QMaxPooling2D": QMaxPool2D, + "QAveragePooling2D": QAvgPool2D, + } + + def __init__(self, framework_specific_info): + """Initialize the KerasAdaptor class with framework specific information.""" + self.framework_specific_info = framework_specific_info + self.approach = deep_get(self.framework_specific_info, "approach", False) + self.quantize_config = {"op_wise_config": {}} + self.device = self.framework_specific_info["device"] + self.backend = self.framework_specific_info["backend"] + self.recipes = deep_get(self.framework_specific_info, "recipes", {}) - self.pre_optimized_object = None + self.pre_optimized_model = None self.pre_optimizer_handle = None self.bf16_ops = [] self.fp32_ops = [] @@ -93,10 +91,9 @@ def __init__(self, framework_specific_info): self.callbacks = [] self.conv_format = {} - self.layer_name_mapping = {} - self.input_layer_dict = {} - self.custom_layers = _add_supported_quantized_objects({}) - self.tmp_dir = DEFAULT_WORKSPACE + "/tmp_saved_model_dir" + self.fold_conv = [] + self.tmp_dir = (DEFAULT_WORKSPACE + "/tmp_model.keras") if version1_gte_version2(tf.__version__, "2.16.1") \ + else (DEFAULT_WORKSPACE + "/tmp_model") def _check_itex(self): """Check if the Intel® Extension for TensorFlow has been installed.""" @@ -108,74 +105,74 @@ def _check_itex(self): "Please install it to run models on ITEX backend" ) - def tuning_cfg_to_fw(self, tuning_cfg): - """Parse tune_config and set framework variables.""" - self.quantize_config["calib_iteration"] = tuning_cfg["calib_iteration"] - self.quantize_config["device"] = self.device - self.quantize_config["advance"] = deep_get(tuning_cfg, "advance") - fp32_ops = [] - bf16_ops = [] - bf16_type = set(self.query_handler.get_op_types_by_precision(precision="bf16")) - dispatched_op_names = [j[0] for j in tuning_cfg["op"]] - invalid_op_names = [i for i in self.quantize_config["op_wise_config"] if i not in dispatched_op_names] - - for op_name in invalid_op_names: - self.quantize_config["op_wise_config"].pop(op_name) + def convert_bf16(self): + """Execute the BF16 conversion.""" + tf.keras.mixed_precision.set_global_policy("mixed_bfloat16") + model = self.pre_optimized_model - for each_op_info in tuning_cfg["op"]: - op_name = each_op_info[0] + for layer in model.layers: + if layer.name in self.bf16_ops: + layer.dtype = "mixed_bfloat16" - if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16": - if each_op_info[1] in bf16_type: - bf16_ops.append(op_name) - continue + model.save(self.tmp_dir) + converted_model = tf.keras.models.load_model(self.tmp_dir) + tf.keras.mixed_precision.set_global_policy("float32") - if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "fp32": - if op_name in self.quantize_config["op_wise_config"]: - self.quantize_config["op_wise_config"].pop(op_name) - fp32_ops.append(op_name) - continue + return converted_model - is_perchannel = False - bit = None - if "weight" in tuning_cfg["op"][each_op_info]: - is_perchannel = tuning_cfg["op"][each_op_info]["weight"]["granularity"] == "per_channel" - # bit = tuning_cfg['op'][each_op_info]['weight']['bit'] - weight_bit = bit if bit else 7.0 - algorithm = tuning_cfg["op"][each_op_info]["activation"]["algorithm"] - is_asymmetric = False - if "activation" in tuning_cfg["op"][each_op_info]: - is_asymmetric = tuning_cfg["op"][each_op_info]["activation"]["scheme"] == "asym" - self.quantize_config["op_wise_config"][op_name] = (is_perchannel, algorithm, is_asymmetric, weight_bit) - self.bf16_ops = bf16_ops - if self.bf16_ops: - self.bf16_ops.pop(-1) - self.fp32_ops = fp32_ops + # (TODO) choose the properly quantize mode + def _check_quantize_mode(self, model): + """Check what quantize mode to use.""" + for layer in model.layers: + if "ReLU" in layer.__class__.__name__: + return "MIN_FIRST" + return "SCALED" - def _pre_optimize(self, model): - """Apply pre-optimization.""" - model = self._check_quantize_format(model) - model = self._fuse_bn(model) - return model + def _set_weights(self, qmodel, layer_weights): + """Set fp32 weights to qmodel""" + for qlayer in qmodel.layers: + if qlayer.get_weights(): + if qlayer.name in layer_weights: + qlayer.set_weights(layer_weights[qlayer.name]) + else: + hit_layer = False + for sub_layer in qlayer.submodules: + if sub_layer.name in layer_weights: + qlayer.set_weights(layer_weights[sub_layer.name]) + hit_layer = True + break + if not hit_layer: + raise ValueError("Can not match the module weights....") + return qmodel def _check_quantize_format(self, model): """The function that checks format for conv ops.""" + input_layer_dict = {} + layer_name_mapping = {} + for layer in model.layers: + layer_name_mapping[layer.name] = layer + for node in layer._outbound_nodes: + layer_name = node.outbound_layer.name + if layer_name not in input_layer_dict: + input_layer_dict[layer_name] = [layer.name] + else: + input_layer_dict[layer_name].append(layer.name) + for layer in model.layers: if layer.__class__.__name__ in self.supported_op: self.conv_format[layer.name] = "s8" - input_layer_names = self.input_layer_dict[layer.name] + input_layer_names = input_layer_dict[layer.name] for input_layer_name in input_layer_names: - check_layer = self.layer_name_mapping[input_layer_name] + check_layer = layer_name_mapping[input_layer_name] if check_layer.__class__.__name__ == "Activation" \ and check_layer.activation.__name__ in ["relu"]: self.conv_format[layer.name] = "u8" break - return model - def _fuse_bn(self, model): """Fusing Batch Normalization.""" - fp32_layers = model.layers + fuse_bn_model = copy.deepcopy(model) + fp32_layers = fuse_bn_model.layers def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): assert conv_type in [ @@ -222,7 +219,6 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): return [depth_weight, weight, bias] if conv_type == "SeparableConv2D" else [weight, bias] fuse_layers = [] - fold_conv = [] for idx, layer in enumerate(fp32_layers): if hasattr(layer, "inbound_nodes"): if layer.__class__.__name__ in ("BatchNormalization"): @@ -232,10 +228,11 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): conv_layer = inbound_layer conv_weight = self.conv_weights[conv_layer.name] bn_weight = self.bn_weights[layer.name] + self.layer_weights[conv_layer.name] = fuse_conv_bn( conv_weight, bn_weight, conv_layer.__class__.__name__, layer.epsilon ) - fold_conv.append(conv_layer.name) + self.fold_conv.append(conv_layer.name) else: fuse_layers.append(layer) elif len(layer.inbound_nodes): @@ -244,9 +241,10 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): if layer.__class__.__name__ in ("TFOpLambda", "SlicingOpLambda"): fuse_layers.append(layer) else: - for bound_node in layer.inbound_nodes[0]: + for bound_node in layer.inbound_nodes: inbound_layer = bound_node.inbound_layers - if inbound_layer in self.bn_weights.keys(): + + if not isinstance(inbound_layer, list) and inbound_layer in self.bn_weights.keys(): bn_inbound_node = inbound_layer.inbound_nodes[0] bn_inbound_layer = bn_inbound_node.inbound_layers if bn_inbound_layer.name in self.conv_weights.keys(): @@ -255,7 +253,10 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): new_bound_nodes.append(bound_node) else: new_bound_nodes.append(bound_node) - layer.inbound_nodes = new_bound_nodes + + for idx, bound_node in enumerate(new_bound_nodes): + layer.inbound_nodes[idx] = new_bound_nodes[idx] + fuse_layers.append(layer) else: fuse_layers.append(layer) @@ -269,23 +270,31 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): conv_weight = self.conv_weights[conv_name] bn_weight = self.bn_weights[layer.name] conv_type = fp32_layers[idx - 1].__class__.__name__ + self.layer_weights[conv_name] = fuse_conv_bn( conv_weight, bn_weight, conv_type, layer.epsilon ) - fold_conv.append(conv_name) + self.fold_conv.append(conv_name) else: fuse_layers.append(layer) - # bn folding will have a shift bias for idx, layer in enumerate(fuse_layers): if ( layer.__class__.__name__ in ("Conv2D", "DepthwiseConv2D", "SeparableConv2D") - and layer.name in fold_conv + and layer.name in self.fold_conv ): - layer.use_bias = True + conv_config = layer.get_config() + conv_config["use_bias"] = True + conv_layer = type(layer).from_config(conv_config) + conv_layer._outbound_nodes.append(layer._outbound_nodes[0]) + fuse_layers[idx] = conv_layer - fused_model = self._rebuild_model_from_layers(model, fuse_layers) - return fused_model + bn_surgery = KerasSurgery(model) + fused_model = bn_surgery.fuse_bn_layers(fuse_layers, self.conv_weights.keys()) + fused_model = self._set_weights(fused_model, self.layer_weights) + fused_model.save(self.tmp_dir) + + return tf.keras.models.load_model(self.tmp_dir) @dump_elapsed_time("Pass quantize model") def quantize(self, quant_config, model, dataloader, iteration, q_func=None): @@ -310,6 +319,7 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): if self.backend == "itex": self._check_itex() + logger.debug("Dump quantization configurations:") logger.debug(self.quantize_config) calib_sampling_size = tune_cfg.get("calib_sampling_size", 1) @@ -324,9 +334,9 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): ) ) - q_layers = [] - self.inbound_nodes_map = {} - for idx, layer in enumerate(copy.deepcopy(self.fp32_layers)): + fq_layers_dict = {} + fq_output_layers = {} + for idx, layer in enumerate(self.pre_optimized_model.layers): if ( layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"] @@ -338,76 +348,87 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): T=self.conv_format[layer.name], mode="per_tensor" ) - if hasattr(layer, "inbound_nodes"): - fake_q_layer.inbound_nodes[0].inbound_layers = layer.inbound_nodes[0].inbound_layers - layer.inbound_nodes[0].inbound_layers = fake_q_layer - self.inbound_nodes_map[fake_q_name] = layer - - q_layers.append(fake_q_layer) - q_layers.append(layer) - else: - q_layers.append(layer) - - quantized_model = self._rebuild_model_from_layers(self.pre_optimized_object, q_layers) - converted_model = self._calibrate(quantized_model, dataloader, self.quantize_config["calib_iteration"]) - - return converted_model + fq_layers_dict[layer.name] = [fake_q_layer] + fq_output_layers[fake_q_layer.name] = layer.name + self.pre_optimized_model.save(self.tmp_dir) + + fq_surgery = KerasSurgery(self.pre_optimized_model) + calibration_model = fq_surgery.insert_quant_layers(fq_layers_dict) + calibration_model = self._set_weights(calibration_model, self.layer_weights) + + quantized_model = self._calibrate(calibration_model, + dataloader, + self.quantize_config["calib_iteration"], + fq_output_layers, + ) + + return quantized_model - def _calibrate(self, model, dataloader, calib_interation): - """Apply calibration.""" + def _calibrate(self, + model, + dataloader, + calib_interation, + fq_output_layers): + """Apply calibration. + + Args: + model (tf.keras.Model): The model inserted with FakeQuant layers for calibration. + dataloader(object): The calibration dataloader used to load quantization dataset. + iteration(int): The iteration of calibration. + fq_output_layers (dict): A dict mapping from names of FakeQuant layers to + names of their output layers. + """ # run eagerly to fetch the numpy min/max - model.compile(run_eagerly=True) results = {} + model.compile(run_eagerly=True) for idx, (inputs, labels) in enumerate(dataloader): - outputs = model.predict_on_batch(inputs) - for layer in model.layers: - if layer.__class__.__name__ == "FakeQuant": - min_value = layer.min_value - max_value = layer.max_value - if layer.name not in results: - results[layer.name] = {"min": [min_value], "max": [max_value]} + _ = model.predict_on_batch(inputs) + json_model = copy.deepcopy(json.loads(model.to_json())) + config = json_model["config"] + layers = config["layers"] + for layer in layers: + if layer["class_name"] == "FakeQuant": + min_value = layer["config"]["min_value"] + max_value = layer["config"]["max_value"] + assert min_value < max_value, "The min value must be lower than the max value in quantization." + + if layer["config"]["name"] not in results: + results[layer["config"]["name"]] = {"min": [min_value], "max": [max_value]} else: - results[layer.name]["min"].append(min_value) - results[layer.name]["max"].append(max_value) + results[layer["config"]["name"]]["min"].append(min_value) + results[layer["config"]["name"]]["max"].append(max_value) if idx + 1 == calib_interation: break - q_layers = [] - inbound_reverse_map = {} + qdq_layer_nums = 0 + qdq_layers_dict = {} + quantized_layers_dict = {} for idx, layer in enumerate(model.layers): if layer.__class__.__name__ == "FakeQuant": min_value = min(results[layer.name]["min"]) max_value = max(results[layer.name]["max"]) + quantize_layer = Quantize( - name = "quantize_" + str(idx), + name = "quantize_" + str(qdq_layer_nums), min_range = min_value, max_range = max_value, T = layer.T, ) dequantize_layer = DeQuantize( - name = "dequantize_" + str(idx), + name = "dequantize_" + str(qdq_layer_nums), min_range = min_value, max_range = max_value, ) - if hasattr(layer, "inbound_nodes"): - quantize_layer.inbound_nodes = layer.inbound_nodes - dequantize_layer.inbound_nodes[0].inbound_layers = quantize_layer - # find the conv/dense layer from fake quant map and - # change the conv/dense node inbound to dequantize - layer_name = self.inbound_nodes_map[layer.name].name - inbound_reverse_map[layer_name] = dequantize_layer - - q_layers.append(quantize_layer) - q_layers.append(dequantize_layer) + qdq_layer_nums += 1 + output_layer_name = fq_output_layers[layer.name] + qdq_layers_dict[output_layer_name] = [quantize_layer, dequantize_layer] elif ( layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"] ): # index 0 is weight, index 1 is bias q_layer_class = "Q" + layer.__class__.__name__ - # this is for inbounds search - q_name = layer.name # for layers that have weights if layer.name in self.layer_weights: kernel = self.layer_weights[layer.name][0] @@ -416,6 +437,7 @@ def _calibrate(self, model, dataloader, calib_interation): t_dim.extend(dim) channel_size = kernel.shape[-1] kernel_channel = kernel.transpose(t_dim).reshape(channel_size, -1) + layer.min_value = np.min(kernel_channel, axis=1).tolist() layer.max_value = np.max(kernel_channel, axis=1).tolist() else: @@ -423,87 +445,20 @@ def _calibrate(self, model, dataloader, calib_interation): # cause no kernel weights for this layer layer.min_value = [-10000] layer.max_value = [10000] - layer.name = q_name - layer_config = layer.get_config() - layer_config["name"] = q_name - q_layer = self.custom_layers[q_layer_class].from_config(layer_config) - if hasattr(layer, "inbound_nodes"): - q_layer.inbound_nodes = inbound_reverse_map[layer.name] - q_layers.append(q_layer) - else: - q_layers.append(layer) - - quantized_model = self._rebuild_model_from_layers(model ,q_layers) - return quantized_model - - def convert_bf16(self): - """Execute the BF16 conversion.""" - tf.keras.mixed_precision.set_global_policy("mixed_bfloat16") - model = self.pre_optimized_object - - for layer in model.layers: - if layer.name in self.bf16_ops: - layer.dtype = "mixed_bfloat16" - - model.save(self.tmp_dir) - converted_model = tf.keras.saved_model.load(self.tmp_dir) - tf.keras.mixed_precision.set_global_policy("float32") - return converted_model - - # (TODO) choose the properly quantize mode - def _check_quantize_mode(self, model): - """Check what quantize mode to use.""" - for idx, layer in enumerate(model.layers): - if "ReLU" in layer.__class__.__name__: - return "MIN_FIRST" - return "SCALED" - - def _rebuild_model_from_layers(self, model, layers): - """Insert a layer before the target layer.""" - model_outputs = [] - input_layer_dict = {} - output_tensor_dict = {layers[0].name: model.input} - - for layer in layers: - for node in layer._outbound_nodes: - layer_name = node.outbound_layer.name - if layer_name not in input_layer_dict: - input_layer_dict[layer_name] = [layer.name] - else: - input_layer_dict[layer_name].append(layer.name) + from neural_compressor.tensorflow.keras.layers import layer_initializer_dict + + q_layer = layer_initializer_dict[q_layer_class](layer) + quantized_layers_dict[layer.name] = q_layer - for layer in layers[1:]: - input_tensors = [output_tensor_dict[input_layer] - for input_layer in input_layer_dict[layer.name]] - if len(input_tensors) == 1: - input_tensors = input_tensors[0] - x = layer(input_tensors) - output_tensor_dict[layer.name] = x - if layer_name in model.output_names: - model_outputs.append(x) + qdq_surgery = KerasSurgery(self.pre_optimized_model) + quantized_model = qdq_surgery.insert_quant_layers(qdq_layers_dict, quantized_layers_dict) + quantized_model = self._set_weights(quantized_model, self.layer_weights) - new_model = tf.keras.models.Model(inputs=model.inputs, outputs=model_outputs) - new_model = self._set_weights(new_model, self.layer_weights) - new_model.save(self.tmp_dir) - return tf.keras.saved_model.load(self.tmp_dir) - - # set fp32 weights to qmodel - def _set_weights(self, qmodel, layer_weights): - for qlayer in qmodel.layers: - if qlayer.get_weights(): - if qlayer.name in layer_weights: - qlayer.set_weights(layer_weights[qlayer.name]) - else: - hit_layer = False - for sub_layer in qlayer.submodules: - if sub_layer.name in layer_weights: - qlayer.set_weights(layer_weights[sub_layer.name]) - hit_layer = True - break - if not hit_layer: - raise ValueError("Can not match the module weights....") - return qmodel + quantized_model.save(self.tmp_dir) + quantized_model = tf.keras.models.load_model(self.tmp_dir) + + return quantized_model @dump_elapsed_time(customized_msg="Model inference") def evaluate( @@ -581,11 +536,11 @@ def query_fw_capability(self, model): other_config = copy.deepcopy(op_capability["int8"]["default"]) # # get fp32 layer weights - keras_object = model + self.fp32_model = model self.conv_weights = {} self.bn_weights = {} self.layer_weights = {} - for layer in keras_object.layers: + for layer in self.fp32_model.layers: if layer.get_weights(): if ( isinstance(layer, tf.keras.layers.Conv2D) @@ -596,13 +551,12 @@ def query_fw_capability(self, model): elif isinstance(layer, tf.keras.layers.BatchNormalization): self.bn_weights[layer.name] = copy.deepcopy(layer.get_weights()) self.layer_weights[layer.name] = copy.deepcopy(layer.get_weights()) - self.pre_optimized_object = self._pre_optimize(keras_object) - - self.fp32_layers = keras_object.layers + self._check_quantize_format(self.fp32_model) + self.pre_optimized_model = self._fuse_bn(self.fp32_model) quantizable_layer_details = OrderedDict() - for layer in self.fp32_layers: + for layer in self.fp32_model.layers: layer_class = layer.__class__.__name__ if layer_class == "Conv2D": quantizable_layer_details[(layer.name, layer_class)] = [conv_config, bf16_config, fp32_config] @@ -639,6 +593,54 @@ def get_optype_wise_ability(self, quantizable_op_details): res[op[1]]["weight"] = quantizable_op_details[op][0]["weight"] return res + def tuning_cfg_to_fw(self, tuning_cfg): + """Parse tune_config and set framework variables. + + Args: + tuning_cfg (dict): The dict of tunning config. + """ + self.quantize_config["calib_iteration"] = tuning_cfg["calib_iteration"] + self.quantize_config["device"] = self.device + self.quantize_config["advance"] = deep_get(tuning_cfg, "advance") + fp32_ops = [] + bf16_ops = [] + bf16_type = set(self.query_handler.get_op_types_by_precision(precision="bf16")) + dispatched_op_names = [j[0] for j in tuning_cfg["op"]] + invalid_op_names = [i for i in self.quantize_config["op_wise_config"] if i not in dispatched_op_names] + + for op_name in invalid_op_names: + self.quantize_config["op_wise_config"].pop(op_name) + + for each_op_info in tuning_cfg["op"]: + op_name = each_op_info[0] + + if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "bf16": + if each_op_info[1] in bf16_type: + bf16_ops.append(op_name) + continue + + if tuning_cfg["op"][each_op_info]["activation"]["dtype"] == "fp32": + if op_name in self.quantize_config["op_wise_config"]: + self.quantize_config["op_wise_config"].pop(op_name) + fp32_ops.append(op_name) + continue + + is_perchannel = False + bit = None + if "weight" in tuning_cfg["op"][each_op_info]: + is_perchannel = tuning_cfg["op"][each_op_info]["weight"]["granularity"] == "per_channel" + # bit = tuning_cfg['op'][each_op_info]['weight']['bit'] + weight_bit = bit if bit else 7.0 + algorithm = tuning_cfg["op"][each_op_info]["activation"]["algorithm"] + is_asymmetric = False + if "activation" in tuning_cfg["op"][each_op_info]: + is_asymmetric = tuning_cfg["op"][each_op_info]["activation"]["scheme"] == "asym" + self.quantize_config["op_wise_config"][op_name] = (is_perchannel, algorithm, is_asymmetric, weight_bit) + self.bf16_ops = bf16_ops + if self.bf16_ops: + self.bf16_ops.pop(-1) + self.fp32_ops = fp32_ops + class KerasQuery: """Class that queries configs from yaml settings.""" @@ -734,7 +736,7 @@ def get_op_types_by_precision(self, precision): class KerasConfigConverter: """Convert `StaticQuantConfig` to the format used by static quant algo.""" - support_int8_weight = {"Dense", "Conv2d", "DepthwiseConv2D", "SeparableConv2D"} + support_int8_weight = {"Dense", "Conv2D", "DepthwiseConv2D", "SeparableConv2D"} def __init__(self, quant_config: StaticQuantConfig, calib_iteration: int): """Init parser for keras static quant config. @@ -783,3 +785,110 @@ def parse_to_tune_cfg(self) -> Dict: tune_cfg["calib_iteration"] = self.calib_iteration return tune_cfg + + +class KerasSurgery: + """The class that inserts FakeQuant or QDQ layers before the target layers.""" + + def __init__(self, model): + """Init the KerasSurgery class. + + Args: + model: the model to be modified. + """ + self.model_outputs = [] + self.model = copy.deepcopy(model) + + def _create_input_dict(self, fuse_layers=None, conv_weights_keys=None): + """Create a input_layer_dict from model. + + Args: + fuse_layers: The layers in which fused BNs have been excluded, defualt to be None. + conv_weights_keys: The names of conv layers where BNs are going to be fused, defualt to be None. + + Returns: + input_layer_dict: The dict that mapping for layer names to their input layer names. + """ + input_layer_dict = {} + layers = fuse_layers if fuse_layers else self.model.layers + for layer in layers: + for node in layer._outbound_nodes: + out_layer = node.outbound_layer + layer_name = out_layer.name + if conv_weights_keys and out_layer.__class__.__name__ in ("BatchNormalization") and \ + out_layer.inbound_nodes[0].inbound_layers.name in conv_weights_keys: + layer_name = out_layer._outbound_nodes[0].outbound_layer.name + if layer_name not in input_layer_dict: + input_layer_dict[layer_name] = [layer.name] + else: + input_layer_dict[layer_name].append(layer.name) + + return input_layer_dict + + def fuse_bn_layers(self, fuse_layers, conv_weights_keys): + """Fuse BN layers and rebuild the model. + + Args: + fuse_layers: The layers in which fused BNs have been excluded. + conv_weights_keys: The names of conv layers where BNs are going to be fused. + """ + self.input_layer_dict = self._create_input_dict(fuse_layers, conv_weights_keys) + has_input_layer = fuse_layers[0].__class__.__name__ == "InputLayer" + output_tensor_dict = {fuse_layers[0].name: self.model.input} if has_input_layer \ + else {"keras.Input": self.model.input} + + for idx, layer in enumerate(fuse_layers): + if idx == 0 and has_input_layer: + continue + + input_tensors = output_tensor_dict["keras.Input"] if idx == 0 \ + else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + if len(input_tensors) == 1: + input_tensors = input_tensors[0] + + x = layer(input_tensors) + + output_tensor_dict[layer.name] = x + if layer.name in self.model.output_names: + self.model_outputs.append(x) + + return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) + + def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): + """Insert FakeQuant or QDQ layers before the target layers and replace + Keras layers to Quantized layers. + + Args: + qdq_layer_dict: The dict mapping from layers to be quantized to the FakeQuant layer or QDQ layers + that are going to be inserted before them. + q_layer_dict: The dict mapping from layers to be replacement to the quantized layers. + """ + self.input_layer_dict = self._create_input_dict() + layers = self.model.layers + has_input_layer = layers[0].__class__.__name__ == "InputLayer" + output_tensor_dict = {layers[0].name: layers[0](self.model.input)} if has_input_layer \ + else {"keras.Input": self.model.input} + + for idx, layer in enumerate(layers): + if idx == 0 and has_input_layer: + continue + + input_tensors = output_tensor_dict["keras.Input"] if idx == 0 \ + else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + if len(input_tensors) == 1: + input_tensors = input_tensors[0] + + if layer.name in qdq_layer_dict: + x = input_tensors + for inserted_layer in qdq_layer_dict[layer.name]: + x = inserted_layer(x) + cur_layer = layer if not q_layer_dict else q_layer_dict[layer.name] + x = cur_layer(x) + else: + x = layer(input_tensors) + + output_tensor_dict[layer.name] = x + if layer.name in self.model.output_names: + self.model_outputs.append(x) + + return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) diff --git a/neural_compressor/tensorflow/keras/layers/__init__.py b/neural_compressor/tensorflow/keras/layers/__init__.py index 2be4fd9417e..0b4fe9030ac 100644 --- a/neural_compressor/tensorflow/keras/layers/__init__.py +++ b/neural_compressor/tensorflow/keras/layers/__init__.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -21,3 +21,4 @@ from neural_compressor.tensorflow.keras.layers.pool2d import QAvgPool2D, QMaxPool2D from neural_compressor.tensorflow.keras.layers.quantizer import DeQuantize, FakeQuant, Quantize from neural_compressor.tensorflow.keras.layers.separable_conv2d import QSeparableConv2D +from neural_compressor.tensorflow.keras.layers.layer_initializer import layer_initializer_dict diff --git a/neural_compressor/tensorflow/keras/layers/conv2d.py b/neural_compressor/tensorflow/keras/layers/conv2d.py index 3bcf1b07b86..84e7b3b01ef 100644 --- a/neural_compressor/tensorflow/keras/layers/conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,9 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src.layers.convolutional.base_conv import BaseConv as Conv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401 else: from keras.layers.convolutional.base_conv import Conv # pylint: disable=E0401 @@ -32,6 +34,7 @@ class QConv2D(Conv): def __init__( self, + name, filters, kernel_size, strides=(1, 1), @@ -48,11 +51,12 @@ def __init__( activity_regularizer=None, kernel_constraint=None, bias_constraint=None, - min_value=-10000, - max_value=10000, + min_value=None, + max_value=None, **kwargs ): super(QConv2D, self).__init__( + name=name, rank=2, filters=filters, kernel_size=kernel_size, @@ -72,10 +76,17 @@ def __init__( bias_constraint=constraints.get(bias_constraint), **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value def call(self, inputs): + kernel_size = self.kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000]*kernel_size + if not self.max_value: + self.max_value = [10000]*kernel_size + # add the Q/DQ here kernel, _, _ = quantization.quantize( self.kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" @@ -107,3 +118,69 @@ def call(self, inputs): @classmethod def from_config(cls, config): return cls(**config) + + +def initialize_int8_conv2d(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "filters" in kwargs: + del kwargs["filters"] + if "kernel_size" in kwargs: + del kwargs["kernel_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "dilation_rate" in kwargs: + del kwargs["dilation_rate"] + if "groups" in kwargs: + del kwargs["groups"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if "kernel_initializer" in kwargs: + del kwargs["kernel_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "kernel_regularizer" in kwargs: + del kwargs["kernel_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "kernel_constraint" in kwargs: + del kwargs["kernel_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + return QConv2D( + name=fp32_layer.name, + filters=fp32_layer.filters, + kernel_size=fp32_layer.kernel_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + dilation_rate=fp32_layer.dilation_rate, + groups=fp32_layer.groups, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + kernel_initializer=fp32_layer.kernel_initializer, + bias_initializer=fp32_layer.bias_initializer, + kernel_regularizer=fp32_layer.kernel_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + kernel_constraint=fp32_layer.kernel_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) diff --git a/neural_compressor/tensorflow/keras/layers/dense.py b/neural_compressor/tensorflow/keras/layers/dense.py index b97e9759b70..007ad5dc999 100644 --- a/neural_compressor/tensorflow/keras/layers/dense.py +++ b/neural_compressor/tensorflow/keras/layers/dense.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ class QDense(Dense): def __init__( self, + name, units, activation=None, use_bias=True, @@ -36,11 +37,12 @@ def __init__( activity_regularizer=None, kernel_constraint=None, bias_constraint=None, - min_value=-10000, - max_value=10000, + min_value=None, + max_value=None, **kwargs ): super(QDense, self).__init__( + name=name, units=units, activation=activation, use_bias=use_bias, @@ -53,10 +55,17 @@ def __init__( bias_constraint=bias_constraint, **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value def call(self, inputs): + kernel_size = self.kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000]*kernel_size + if not self.max_value: + self.max_value = [10000]*kernel_size + # add the Q/DQ here kernel, _, _ = quantization.quantize( self.kernel, @@ -66,6 +75,7 @@ def call(self, inputs): axis=1, mode="SCALED", ) + kernel = quantization.dequantize( kernel, self.min_value, @@ -80,3 +90,53 @@ def call(self, inputs): if self.activation is not None: outputs = self.activation(outputs) return outputs + + +def initialize_int8_dense(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "units" in kwargs: + del kwargs["units"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if "kernel_initializer" in kwargs: + del kwargs["kernel_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "kernel_regularizer" in kwargs: + del kwargs["kernel_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "kernel_constraint" in kwargs: + del kwargs["kernel_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + q_layer = QDense( + name=fp32_layer.name, + units=fp32_layer.units, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + kernel_initializer=fp32_layer.kernel_initializer, + bias_initializer=fp32_layer.bias_initializer, + kernel_regularizer=fp32_layer.kernel_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + kernel_constraint=fp32_layer.kernel_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) + + return q_layer diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index 1de9d8bf792..72f736456f6 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,10 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src import ops + from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.src.utils import conv_utils, tf_utils # pylint: disable=E0401 else: @@ -31,109 +34,266 @@ from keras.utils import conv_utils, tf_utils # pylint: disable=E0401 -class QDepthwiseConv2D(DepthwiseConv): - def __init__( - self, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, +if version1_gte_version2(tf.__version__, "2.16.1"): + class QDepthwiseConv2D(BaseDepthwiseConv): + def __init__( + self, + kernel_size, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + min_value=None, + max_value=None, + **kwargs ): + super().__init__( + 2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + depth_multiplier=depth_multiplier, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, + **kwargs + ) + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + kernel_size = self.kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000]*kernel_size + if not self.max_value: + self.max_value = [10000]*kernel_size + + # add the Q/DQ here + kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = quantization.dequantize( + kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + + input_channel = self._get_input_channel(inputs.shape) + outputs = ops.depthwise_conv( + inputs, + self.kernel, + strides=self.strides, + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,) + else: + bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs += bias + + if self.activation is not None: + return self.activation(outputs) + return +else: + class QDepthwiseConv2D(DepthwiseConv): + def __init__( + self, + kernel_size, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + min_value=None, + max_value=None, **kwargs - ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - outputs = tf.keras.backend.depthwise_conv2d( - inputs, - kernel, - strides=self.strides, - padding=self.padding, - data_format=self.data_format, - dilation_rate=self.dilation_rate, - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) - - @tf_utils.shape_type_conversion - def compute_output_shape(self, input_shape): - if self.data_format == "channels_first": - rows = input_shape[2] - cols = input_shape[3] - out_filters = input_shape[1] * self.depth_multiplier - elif self.data_format == "channels_last": - rows = input_shape[1] - cols = input_shape[2] - out_filters = input_shape[3] * self.depth_multiplier - - rows = conv_utils.conv_output_length( - rows, - self.kernel_size[0], - self.padding, - self.strides[0], - self.dilation_rate[0], - ) - cols = conv_utils.conv_output_length( - cols, - self.kernel_size[1], - self.padding, - self.strides[1], - self.dilation_rate[1], - ) - if self.data_format == "channels_first": - return (input_shape[0], out_filters, rows, cols) - elif self.data_format == "channels_last": - return (input_shape[0], rows, cols, out_filters) + ): + super().__init__( + 2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + depth_multiplier=depth_multiplier, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, + **kwargs + ) + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + depthwise_kernel_size = self.depthwise_kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000]*depthwise_kernel_size + if not self.max_value: + self.max_value = [10000]*depthwise_kernel_size + + # add the Q/DQ here + kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = quantization.dequantize( + kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + outputs = tf.keras.backend.depthwise_conv2d( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + rows = input_shape[2] + cols = input_shape[3] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == "channels_last": + rows = input_shape[1] + cols = input_shape[2] + out_filters = input_shape[3] * self.depth_multiplier + + rows = conv_utils.conv_output_length( + rows, + self.kernel_size[0], + self.padding, + self.strides[0], + self.dilation_rate[0], + ) + cols = conv_utils.conv_output_length( + cols, + self.kernel_size[1], + self.padding, + self.strides[1], + self.dilation_rate[1], + ) + if self.data_format == "channels_first": + return (input_shape[0], out_filters, rows, cols) + elif self.data_format == "channels_last": + return (input_shape[0], rows, cols, out_filters) + + +def initialize_int8_depthwise_conv2d(fp32_layer): + kwargs = fp32_layer.get_config() + q_name = fp32_layer.name + + if "name" in kwargs: + del kwargs["name"] + if "kernel_size" in kwargs: + del kwargs["kernel_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "depth_multiplier" in kwargs: + del kwargs["depth_multiplier"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "dilation_rate" in kwargs: + del kwargs["dilation_rate"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if "depthwise_initializer" in kwargs: + del kwargs["depthwise_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "depthwise_regularizer" in kwargs: + del kwargs["depthwise_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "depthwise_constraint" in kwargs: + del kwargs["depthwise_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + return QDepthwiseConv2D( + name=q_name, + kernel_size=fp32_layer.kernel_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + depth_multiplier=fp32_layer.depth_multiplier, + data_format=fp32_layer.data_format, + dilation_rate=fp32_layer.dilation_rate, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + depthwise_initializer=fp32_layer.depthwise_initializer, + bias_initializer=fp32_layer.bias_initializer, + depthwise_regularizer=fp32_layer.depthwise_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + depthwise_constraint=fp32_layer.depthwise_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) diff --git a/neural_compressor/tensorflow/keras/layers/layer_initializer.py b/neural_compressor/tensorflow/keras/layers/layer_initializer.py new file mode 100644 index 00000000000..d1db0eb3504 --- /dev/null +++ b/neural_compressor/tensorflow/keras/layers/layer_initializer.py @@ -0,0 +1,33 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# +# Copyright (c) 2024 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from neural_compressor.tensorflow.keras.layers.conv2d import initialize_int8_conv2d +from neural_compressor.tensorflow.keras.layers.dense import initialize_int8_dense +from neural_compressor.tensorflow.keras.layers.depthwise_conv2d import initialize_int8_depthwise_conv2d +from neural_compressor.tensorflow.keras.layers.pool2d import initialize_int8_avgpool, initialize_int8_maxpool +from neural_compressor.tensorflow.keras.layers.separable_conv2d import initialize_int8_separable_conv2d + +layer_initializer_dict = { + "QAvgPool2D": initialize_int8_avgpool, + "QAveragePooling2D": initialize_int8_avgpool, + "QMaxPool2D": initialize_int8_maxpool, + "QMaxPooling2D": initialize_int8_maxpool, + "QSeparableConv2D": initialize_int8_separable_conv2d, + "QDepthwiseConv2D": initialize_int8_depthwise_conv2d, + "QConv2D": initialize_int8_conv2d, + "QDense": initialize_int8_dense, +} diff --git a/neural_compressor/tensorflow/keras/layers/pool2d.py b/neural_compressor/tensorflow/keras/layers/pool2d.py index 409c16b9305..8fb7889e3c7 100644 --- a/neural_compressor/tensorflow/keras/layers/pool2d.py +++ b/neural_compressor/tensorflow/keras/layers/pool2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,6 +26,7 @@ class QAvgPool2D(AveragePooling2D): def __init__( self, + name, pool_size=(2, 2), strides=None, padding="valid", @@ -35,15 +36,21 @@ def __init__( **kwargs ): super(QAvgPool2D, self).__init__( - pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs + name=name, + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value class QMaxPool2D(MaxPooling2D): def __init__( self, + name, pool_size=(2, 2), strides=None, padding="valid", @@ -53,7 +60,75 @@ def __init__( **kwargs ): super(QMaxPool2D, self).__init__( - pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs + name=name, + pool_size=pool_size, + strides=strides, + padding=padding, + data_format=data_format, + **kwargs ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) + self.min_value = min_value + self.max_value = max_value + + +def initialize_int8_avgpool(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "pool_size" in kwargs: + del kwargs["pool_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + q_layer = QAvgPool2D( + name=fp32_layer.name, + pool_size=fp32_layer.pool_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) + + return q_layer + +def initialize_int8_maxpool(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "pool_size" in kwargs: + del kwargs["pool_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + q_layer= QMaxPool2D( + name=fp32_layer.name, + pool_size=fp32_layer.pool_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) + + return q_layer diff --git a/neural_compressor/tensorflow/keras/layers/quantizer.py b/neural_compressor/tensorflow/keras/layers/quantizer.py index b395870b48f..a6e31fc6a5c 100644 --- a/neural_compressor/tensorflow/keras/layers/quantizer.py +++ b/neural_compressor/tensorflow/keras/layers/quantizer.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -26,8 +26,8 @@ def __init__(self, mode="per_tensor", T="s8", **kwargs): self.mode = mode self.T = T self.axis = 1 if mode == "per_channel" else 0 - self.min_value = tf.constant(np.finfo(np.float32).max, dtype=tf.float32) - self.max_value = tf.constant(np.finfo(np.float32).min, dtype=tf.float32) + self.min_value = tf.constant(np.finfo(np.float32).min, dtype=tf.float32) + self.max_value = tf.constant(np.finfo(np.float32).max, dtype=tf.float32) def call(self, inputs): if self.mode == "per_tensor": @@ -36,8 +36,13 @@ def call(self, inputs): else: self.min_value = tf.math.reduce_min(inputs, axis=self.axis) self.max_value = tf.math.reduce_max(inputs, axis=self.axis) + return inputs + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + return input_shape + @classmethod def from_config(cls, config): return cls(**config) @@ -87,6 +92,10 @@ def call(self, inputs): ) return outputs + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + return input_shape + def get_config(self): return { "min_range": self.min_range, @@ -122,6 +131,10 @@ def call(self, inputs): axis=self.axis, ) + def compute_output_shape(self, input_shape): + input_shape = tf.TensorShape(input_shape).as_list() + return input_shape + def get_config(self): return { "min_range": self.min_range, diff --git a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py index 5507d2f99d2..3c2338c2a7d 100644 --- a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- # -# Copyright (c) 2022 Intel Corporation +# Copyright (c) 2024 Intel Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -23,7 +23,10 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.16.1"): + from keras.src import ops + from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv # pylint: disable=E0401 +elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.src.utils import conv_utils # pylint: disable=E0401 else: @@ -31,94 +34,278 @@ from keras.utils import conv_utils # pylint: disable=E0401 -class QSeparableConv2D(SeparableConv): - def __init__( - self, - filters, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - pointwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activations.get(activation), - use_bias=use_bias, - depthwise_initializer=initializers.get(depthwise_initializer), - pointwise_initializer=initializers.get(pointwise_initializer), - bias_initializer=initializers.get(bias_initializer), - depthwise_regularizer=regularizers.get(depthwise_regularizer), - pointwise_regularizer=regularizers.get(pointwise_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - depthwise_constraint=constraints.get(depthwise_constraint), - pointwise_constraint=constraints.get(pointwise_constraint), - bias_constraint=constraints.get(bias_constraint), +if version1_gte_version2(tf.__version__, "2.16.1"): + class QSeparableConv2D(BaseSeparableConv): + def __init__( + self, + name, + filters, + kernel_size, + min_value, + max_value, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + **kwargs ): + super().__init__( + name=name, + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activations.get(activation), + use_bias=use_bias, + depthwise_initializer=initializers.get(depthwise_initializer), + pointwise_initializer=initializers.get(pointwise_initializer), + bias_initializer=initializers.get(bias_initializer), + depthwise_regularizer=regularizers.get(depthwise_regularizer), + pointwise_regularizer=regularizers.get(pointwise_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), + min_value=None, + max_value=None, + **kwargs + ) + + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + depthwise_kernel_size = self.depthwise_kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000]*depthwise_kernel_size + if not self.max_value: + self.max_value = [10000]*depthwise_kernel_size + + # TODO it's ugly that we can't get the point_wise min/max here + depthwise_kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = quantization.dequantize( + depthwise_kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + + outputs = ops.separable_conv( + inputs, + self.depthwise_kernel, + self.pointwise_kernel, + strides=self.strides, + padding=self.padding, + dilation_rate=self.dilation_rate, + data_format=self.data_format, + ) + + if self.use_bias: + if self.data_format == "channels_last": + bias_shape = (1,) * (self.rank + 1) + (self.filters,) + else: + bias_shape = (1, self.filters) + (1,) * self.rank + bias = ops.reshape(self.bias, bias_shape) + outputs += bias + + if self.activation is not None: + return self.activation(outputs) + return outputs +else: + class QSeparableConv2D(SeparableConv): + def __init__( + self, + name, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + min_value=None, + max_value=None, **kwargs - ) - - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - if self.data_format == "channels_last": - strides = (1,) + self.strides + (1,) - else: - strides = (1, 1) + self.strides - # (TODO) it's ugly that we can't get the point_wise min/max here - depthwise_kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - depthwise_kernel = quantization.dequantize( - depthwise_kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - outputs = tf.compat.v1.nn.separable_conv2d( - inputs, - depthwise_kernel, - self.pointwise_kernel, - strides=strides, - padding=self.padding.upper(), - rate=self.dilation_rate, - data_format=conv_utils.convert_data_format(self.data_format, ndim=4), - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) + ): + super().__init__( + name=name, + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activations.get(activation), + use_bias=use_bias, + depthwise_initializer=initializers.get(depthwise_initializer), + pointwise_initializer=initializers.get(pointwise_initializer), + bias_initializer=initializers.get(bias_initializer), + depthwise_regularizer=regularizers.get(depthwise_regularizer), + pointwise_regularizer=regularizers.get(pointwise_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), + **kwargs + ) + + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + depthwise_kernel_size = self.depthwise_kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000]*depthwise_kernel_size + if not self.max_value: + self.max_value = [10000]*depthwise_kernel_size + + # TODO it's ugly that we can't get the point_wise min/max here + depthwise_kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = quantization.dequantize( + depthwise_kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + + if self.data_format == "channels_last": + strides = (1,) + self.strides + (1,) + else: + strides = (1, 1) + self.strides + + outputs = tf.compat.v1.nn.separable_conv2d( + inputs, + depthwise_kernel, + self.pointwise_kernel, + strides=strides, + padding=self.padding.upper(), + rate=self.dilation_rate, + data_format=conv_utils.convert_data_format(self.data_format, ndim=4), + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + +def initialize_int8_separable_conv2d(fp32_layer): + kwargs = fp32_layer.get_config() + + if "name" in kwargs: + del kwargs["name"] + if "filters" in kwargs: + del kwargs["filters"] + if "kernel_size" in kwargs: + del kwargs["kernel_size"] + if "strides" in kwargs: + del kwargs["strides"] + if "padding" in kwargs: + del kwargs["padding"] + if "data_format" in kwargs: + del kwargs["data_format"] + if "dilation_rate" in kwargs: + del kwargs["dilation_rate"] + if "depth_multiplier" in kwargs: + del kwargs["depth_multiplier"] + if "activation" in kwargs: + del kwargs["activation"] + if "use_bias" in kwargs: + del kwargs["use_bias"] + if "depthwise_initializer" in kwargs: + del kwargs["depthwise_initializer"] + if "pointwise_initializer" in kwargs: + del kwargs["pointwise_initializer"] + if "bias_initializer" in kwargs: + del kwargs["bias_initializer"] + if "depthwise_regularizer" in kwargs: + del kwargs["depthwise_regularizer"] + if "pointwise_regularizer" in kwargs: + del kwargs["pointwise_regularizer"] + if "activity_regularizer" in kwargs: + del kwargs["activity_regularizer"] + if "bias_regularizer" in kwargs: + del kwargs["bias_regularizer"] + if "depthwise_constraint" in kwargs: + del kwargs["depthwise_constraint"] + if "pointwise_constraint" in kwargs: + del kwargs["pointwise_constraint"] + if "bias_constraint" in kwargs: + del kwargs["bias_constraint"] + if "min_value" in kwargs: + del kwargs["min_value"] + if "max_value" in kwargs: + del kwargs["max_value"] + + return QSeparableConv2D( + name=fp32_layer.name, + filters=fp32_layer.filters, + kernel_size=fp32_layer.kernel_size, + strides=fp32_layer.strides, + padding=fp32_layer.padding, + data_format=fp32_layer.data_format, + dilation_rate=fp32_layer.dilation_rate, + depth_multiplier=fp32_layer.depth_multiplier, + activation=fp32_layer.activation, + use_bias=fp32_layer.use_bias, + depthwise_initializer=fp32_layer.depthwise_initializer, + pointwise_initializer=fp32_layer.pointwise_initializer, + bias_initializer=fp32_layer.bias_initializer, + depthwise_regularizer=fp32_layer.depthwise_regularizer, + pointwise_regularizer=fp32_layer.pointwise_regularizer, + bias_regularizer=fp32_layer.bias_regularizer, + activity_regularizer=fp32_layer.activity_regularizer, + depthwise_constraint=fp32_layer.depthwise_constraint, + pointwise_constraint=fp32_layer.pointwise_constraint, + bias_constraint=fp32_layer.bias_constraint, + min_value=fp32_layer.min_value, + max_value=fp32_layer.max_value, + **kwargs + ) From 3ec430ee25a60c96844baea2c08418ac99102484 Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Tue, 26 Mar 2024 16:15:12 +0800 Subject: [PATCH 03/13] add check for conv2d in ut Signed-off-by: zehao-intel --- test/3x/tensorflow/keras/test_config.py | 62 +++++++++++++++++++++++-- 1 file changed, 58 insertions(+), 4 deletions(-) diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 0e6d70f75f1..4d7ce12f72c 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -48,10 +48,10 @@ def build_model(): [ keras.layers.InputLayer(input_shape=(28, 28)), keras.layers.Reshape(target_shape=(28, 28, 1)), - keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu"), + keras.layers.Conv2D(filters=12, kernel_size=(3, 3), activation="relu", name="conv2d"), keras.layers.MaxPooling2D(pool_size=(2, 2)), keras.layers.Flatten(), - keras.layers.Dense(10), + keras.layers.Dense(10, name="dense"), ] ) # Train the digit classification model @@ -128,9 +128,18 @@ def test_static_quant_from_dict_default(self): qmodel = quantize_model(fp32_model, get_default_static_quant_config(), calib_dataloader) self.assertIsNotNone(qmodel) + dense_checked = False + conv_checked = False for layer in qmodel.layers: if layer.name == "dense": + dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + conv_checked = True + self.assertEqual(layer.__class__.__name__, "QConv2D") + + self.assertEqual(dense_checked, True) + self.assertEqual(conv_checked, True) def test_static_quant_from_dict_beginner(self): logger.info("test_static_quant_from_dict_beginner") @@ -153,9 +162,18 @@ def test_static_quant_from_dict_beginner(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) + dense_checked = False + conv_checked = False for layer in qmodel.layers: if layer.name == "dense": + dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + conv_checked = True + self.assertEqual(layer.__class__.__name__, "QConv2D") + + self.assertEqual(dense_checked, True) + self.assertEqual(conv_checked, True) def test_static_quant_from_class_default(self): logger.info("test_static_quant_from_class_default") @@ -168,9 +186,18 @@ def test_static_quant_from_class_default(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) + dense_checked = False + conv_checked = False for layer in qmodel.layers: if layer.name == "dense": + dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + conv_checked = True + self.assertEqual(layer.__class__.__name__, "QConv2D") + + self.assertEqual(dense_checked, True) + self.assertEqual(conv_checked, True) def test_static_quant_from_class_beginner(self): logger.info("test_static_quant_from_class_beginner") @@ -190,9 +217,18 @@ def test_static_quant_from_class_beginner(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) + dense_checked = False + conv_checked = False for layer in qmodel.layers: if layer.name == "dense": + dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + conv_checked = True + self.assertEqual(layer.__class__.__name__, "QConv2D") + + self.assertEqual(dense_checked, True) + self.assertEqual(conv_checked, True) def test_static_quant_from_dict_advance(self): logger.info("test_static_quant_from_dict_advance") @@ -221,9 +257,18 @@ def test_static_quant_from_dict_advance(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) + dense_checked = False + conv_checked = False for layer in qmodel.layers: if layer.name == "dense": - self.assertNotEqual(layer.__class__.__name__, "QDense") + dense_checked = True + self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + conv_checked = True + self.assertEqual(layer.__class__.__name__, "QConv2D") + + self.assertEqual(dense_checked, True) + self.assertEqual(conv_checked, True) def test_static_quant_from_class_advance(self): logger.info("test_static_quant_from_class_advance") @@ -249,9 +294,18 @@ def test_static_quant_from_class_advance(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) + dense_checked = False + conv_checked = False for layer in qmodel.layers: if layer.name == "dense": - self.assertNotEqual(layer.__class__.__name__, "QDense") + dense_checked = True + self.assertEqual(layer.__class__.__name__, "QDense") + if layer.name == "conv2d": + conv_checked = True + self.assertEqual(layer.__class__.__name__, "QConv2D") + + self.assertEqual(dense_checked, True) + self.assertEqual(conv_checked, True) def test_config_from_dict(self): logger.info("test_config_from_dict") From de7d3ca05c2959fbff0d338127dc6743f6d5d329 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Mar 2024 08:17:20 +0000 Subject: [PATCH 04/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../algorithms/static_quant/keras.py | 140 +++++++++--------- .../tensorflow/keras/layers/dense.py | 4 +- .../keras/layers/depthwise_conv2d.py | 2 +- .../tensorflow/keras/layers/pool2d.py | 19 +-- .../keras/layers/separable_conv2d.py | 14 +- 5 files changed, 86 insertions(+), 93 deletions(-) diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index 1a76023f6f7..af86e8d640c 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -27,7 +27,6 @@ from neural_compressor.common import logger from neural_compressor.common.utils import DEFAULT_WORKSPACE - from neural_compressor.tensorflow.keras.layers import ( DeQuantize, FakeQuant, @@ -47,16 +46,16 @@ class KerasAdaptor: """The keras class of framework adaptor layer.""" supported_op = [ - "Conv2D", - "Dense", - "SeparableConv2D", - "DepthwiseConv2D", - "AveragePooling2D", - "MaxPooling2D", - "AvgPool2D", - "MaxPool2D", - ] - + "Conv2D", + "Dense", + "SeparableConv2D", + "DepthwiseConv2D", + "AveragePooling2D", + "MaxPooling2D", + "AvgPool2D", + "MaxPool2D", + ] + custom_layers = { "Quantize": Quantize, "DeQuantize": DeQuantize, @@ -92,8 +91,11 @@ def __init__(self, framework_specific_info): self.conv_format = {} self.fold_conv = [] - self.tmp_dir = (DEFAULT_WORKSPACE + "/tmp_model.keras") if version1_gte_version2(tf.__version__, "2.16.1") \ - else (DEFAULT_WORKSPACE + "/tmp_model") + self.tmp_dir = ( + (DEFAULT_WORKSPACE + "/tmp_model.keras") + if version1_gte_version2(tf.__version__, "2.16.1") + else (DEFAULT_WORKSPACE + "/tmp_model") + ) def _check_itex(self): """Check if the Intel® Extension for TensorFlow has been installed.""" @@ -129,7 +131,7 @@ def _check_quantize_mode(self, model): return "SCALED" def _set_weights(self, qmodel, layer_weights): - """Set fp32 weights to qmodel""" + """Set fp32 weights to qmodel.""" for qlayer in qmodel.layers: if qlayer.get_weights(): if qlayer.name in layer_weights: @@ -164,8 +166,7 @@ def _check_quantize_format(self, model): input_layer_names = input_layer_dict[layer.name] for input_layer_name in input_layer_names: check_layer = layer_name_mapping[input_layer_name] - if check_layer.__class__.__name__ == "Activation" \ - and check_layer.activation.__name__ in ["relu"]: + if check_layer.__class__.__name__ == "Activation" and check_layer.activation.__name__ in ["relu"]: self.conv_format[layer.name] = "u8" break @@ -271,9 +272,7 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): bn_weight = self.bn_weights[layer.name] conv_type = fp32_layers[idx - 1].__class__.__name__ - self.layer_weights[conv_name] = fuse_conv_bn( - conv_weight, bn_weight, conv_type, layer.epsilon - ) + self.layer_weights[conv_name] = fuse_conv_bn(conv_weight, bn_weight, conv_type, layer.epsilon) self.fold_conv.append(conv_name) else: fuse_layers.append(layer) @@ -337,17 +336,11 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): fq_layers_dict = {} fq_output_layers = {} for idx, layer in enumerate(self.pre_optimized_model.layers): - if ( - layer.__class__.__name__ in self.supported_op - and layer.name in self.quantize_config["op_wise_config"] - ): + if layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"]: op_config = self.quantize_config["op_wise_config"][layer.name] mode = "per_channel" if op_config[0] else "per_tensor" fake_q_name = "fake_quant_" + str(idx) - fake_q_layer = FakeQuant(name=fake_q_name, - T=self.conv_format[layer.name], - mode="per_tensor" - ) + fake_q_layer = FakeQuant(name=fake_q_name, T=self.conv_format[layer.name], mode="per_tensor") fq_layers_dict[layer.name] = [fake_q_layer] fq_output_layers[fake_q_layer.name] = layer.name self.pre_optimized_model.save(self.tmp_dir) @@ -356,26 +349,23 @@ def quantize(self, quant_config, model, dataloader, iteration, q_func=None): calibration_model = fq_surgery.insert_quant_layers(fq_layers_dict) calibration_model = self._set_weights(calibration_model, self.layer_weights) - quantized_model = self._calibrate(calibration_model, - dataloader, - self.quantize_config["calib_iteration"], - fq_output_layers, - ) - + quantized_model = self._calibrate( + calibration_model, + dataloader, + self.quantize_config["calib_iteration"], + fq_output_layers, + ) + return quantized_model - def _calibrate(self, - model, - dataloader, - calib_interation, - fq_output_layers): - """Apply calibration. - + def _calibrate(self, model, dataloader, calib_interation, fq_output_layers): + """Apply calibration. + Args: model (tf.keras.Model): The model inserted with FakeQuant layers for calibration. dataloader(object): The calibration dataloader used to load quantization dataset. iteration(int): The iteration of calibration. - fq_output_layers (dict): A dict mapping from names of FakeQuant layers to + fq_output_layers (dict): A dict mapping from names of FakeQuant layers to names of their output layers. """ # run eagerly to fetch the numpy min/max @@ -391,7 +381,7 @@ def _calibrate(self, min_value = layer["config"]["min_value"] max_value = layer["config"]["max_value"] assert min_value < max_value, "The min value must be lower than the max value in quantization." - + if layer["config"]["name"] not in results: results[layer["config"]["name"]] = {"min": [min_value], "max": [max_value]} else: @@ -409,24 +399,21 @@ def _calibrate(self, max_value = max(results[layer.name]["max"]) quantize_layer = Quantize( - name = "quantize_" + str(qdq_layer_nums), - min_range = min_value, - max_range = max_value, - T = layer.T, + name="quantize_" + str(qdq_layer_nums), + min_range=min_value, + max_range=max_value, + T=layer.T, ) dequantize_layer = DeQuantize( - name = "dequantize_" + str(qdq_layer_nums), - min_range = min_value, - max_range = max_value, + name="dequantize_" + str(qdq_layer_nums), + min_range=min_value, + max_range=max_value, ) qdq_layer_nums += 1 output_layer_name = fq_output_layers[layer.name] qdq_layers_dict[output_layer_name] = [quantize_layer, dequantize_layer] - elif ( - layer.__class__.__name__ in self.supported_op - and layer.name in self.quantize_config["op_wise_config"] - ): + elif layer.__class__.__name__ in self.supported_op and layer.name in self.quantize_config["op_wise_config"]: # index 0 is weight, index 1 is bias q_layer_class = "Q" + layer.__class__.__name__ # for layers that have weights @@ -447,17 +434,17 @@ def _calibrate(self, layer.max_value = [10000] from neural_compressor.tensorflow.keras.layers import layer_initializer_dict - + q_layer = layer_initializer_dict[q_layer_class](layer) quantized_layers_dict[layer.name] = q_layer qdq_surgery = KerasSurgery(self.pre_optimized_model) quantized_model = qdq_surgery.insert_quant_layers(qdq_layers_dict, quantized_layers_dict) quantized_model = self._set_weights(quantized_model, self.layer_weights) - + quantized_model.save(self.tmp_dir) quantized_model = tf.keras.models.load_model(self.tmp_dir) - + return quantized_model @dump_elapsed_time(customized_msg="Model inference") @@ -594,7 +581,7 @@ def get_optype_wise_ability(self, quantizable_op_details): return res def tuning_cfg_to_fw(self, tuning_cfg): - """Parse tune_config and set framework variables. + """Parse tune_config and set framework variables. Args: tuning_cfg (dict): The dict of tunning config. @@ -815,9 +802,12 @@ def _create_input_dict(self, fuse_layers=None, conv_weights_keys=None): for node in layer._outbound_nodes: out_layer = node.outbound_layer layer_name = out_layer.name - if conv_weights_keys and out_layer.__class__.__name__ in ("BatchNormalization") and \ - out_layer.inbound_nodes[0].inbound_layers.name in conv_weights_keys: - layer_name = out_layer._outbound_nodes[0].outbound_layer.name + if ( + conv_weights_keys + and out_layer.__class__.__name__ in ("BatchNormalization") + and out_layer.inbound_nodes[0].inbound_layers.name in conv_weights_keys + ): + layer_name = out_layer._outbound_nodes[0].outbound_layer.name if layer_name not in input_layer_dict: input_layer_dict[layer_name] = [layer.name] else: @@ -833,16 +823,20 @@ def fuse_bn_layers(self, fuse_layers, conv_weights_keys): conv_weights_keys: The names of conv layers where BNs are going to be fused. """ self.input_layer_dict = self._create_input_dict(fuse_layers, conv_weights_keys) - has_input_layer = fuse_layers[0].__class__.__name__ == "InputLayer" - output_tensor_dict = {fuse_layers[0].name: self.model.input} if has_input_layer \ - else {"keras.Input": self.model.input} + has_input_layer = fuse_layers[0].__class__.__name__ == "InputLayer" + output_tensor_dict = ( + {fuse_layers[0].name: self.model.input} if has_input_layer else {"keras.Input": self.model.input} + ) for idx, layer in enumerate(fuse_layers): if idx == 0 and has_input_layer: continue - input_tensors = output_tensor_dict["keras.Input"] if idx == 0 \ - else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + input_tensors = ( + output_tensor_dict["keras.Input"] + if idx == 0 + else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + ) if len(input_tensors) == 1: input_tensors = input_tensors[0] @@ -855,7 +849,7 @@ def fuse_bn_layers(self, fuse_layers, conv_weights_keys): return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): - """Insert FakeQuant or QDQ layers before the target layers and replace + """Insert FakeQuant or QDQ layers before the target layers and replace Keras layers to Quantized layers. Args: @@ -865,16 +859,20 @@ def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): """ self.input_layer_dict = self._create_input_dict() layers = self.model.layers - has_input_layer = layers[0].__class__.__name__ == "InputLayer" - output_tensor_dict = {layers[0].name: layers[0](self.model.input)} if has_input_layer \ - else {"keras.Input": self.model.input} + has_input_layer = layers[0].__class__.__name__ == "InputLayer" + output_tensor_dict = ( + {layers[0].name: layers[0](self.model.input)} if has_input_layer else {"keras.Input": self.model.input} + ) for idx, layer in enumerate(layers): if idx == 0 and has_input_layer: continue - input_tensors = output_tensor_dict["keras.Input"] if idx == 0 \ - else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + input_tensors = ( + output_tensor_dict["keras.Input"] + if idx == 0 + else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] + ) if len(input_tensors) == 1: input_tensors = input_tensors[0] diff --git a/neural_compressor/tensorflow/keras/layers/dense.py b/neural_compressor/tensorflow/keras/layers/dense.py index 007ad5dc999..61dfda2a2b8 100644 --- a/neural_compressor/tensorflow/keras/layers/dense.py +++ b/neural_compressor/tensorflow/keras/layers/dense.py @@ -62,9 +62,9 @@ def call(self, inputs): kernel_size = self.kernel.shape[-1] if not self.min_value: - self.min_value = [-10000]*kernel_size + self.min_value = [-10000] * kernel_size if not self.max_value: - self.max_value = [10000]*kernel_size + self.max_value = [10000] * kernel_size # add the Q/DQ here kernel, _, _ = quantization.quantize( diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index 9d86a375772..264b5551d70 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -128,7 +128,7 @@ def call(self, inputs): if self.activation is not None: return self.activation(outputs) - return + return else: class QDepthwiseConv2D(DepthwiseConv): def __init__( diff --git a/neural_compressor/tensorflow/keras/layers/pool2d.py b/neural_compressor/tensorflow/keras/layers/pool2d.py index 8fb7889e3c7..05a028ecc83 100644 --- a/neural_compressor/tensorflow/keras/layers/pool2d.py +++ b/neural_compressor/tensorflow/keras/layers/pool2d.py @@ -36,12 +36,7 @@ def __init__( **kwargs ): super(QAvgPool2D, self).__init__( - name=name, - pool_size=pool_size, - strides=strides, - padding=padding, - data_format=data_format, - **kwargs + name=name, pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs ) self.min_value = min_value self.max_value = max_value @@ -60,12 +55,7 @@ def __init__( **kwargs ): super(QMaxPool2D, self).__init__( - name=name, - pool_size=pool_size, - strides=strides, - padding=padding, - data_format=data_format, - **kwargs + name=name, pool_size=pool_size, strides=strides, padding=padding, data_format=data_format, **kwargs ) self.min_value = min_value self.max_value = max_value @@ -102,6 +92,7 @@ def initialize_int8_avgpool(fp32_layer): return q_layer + def initialize_int8_maxpool(fp32_layer): kwargs = fp32_layer.get_config() @@ -120,7 +111,7 @@ def initialize_int8_maxpool(fp32_layer): if "max_value" in kwargs: del kwargs["max_value"] - q_layer= QMaxPool2D( + q_layer = QMaxPool2D( name=fp32_layer.name, pool_size=fp32_layer.pool_size, strides=fp32_layer.strides, @@ -129,6 +120,6 @@ def initialize_int8_maxpool(fp32_layer): min_value=fp32_layer.min_value, max_value=fp32_layer.max_value, **kwargs - ) + ) return q_layer diff --git a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py index 7cf7e496746..6a9ec1d9e75 100644 --- a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py @@ -34,6 +34,7 @@ from keras.utils import conv_utils # pylint: disable=E0401 if version1_gte_version2(tf.__version__, "2.16.1"): + class QSeparableConv2D(BaseSeparableConv): def __init__( self, @@ -58,7 +59,8 @@ def __init__( depthwise_constraint=None, pointwise_constraint=None, bias_constraint=None, - **kwargs ): + **kwargs + ): super().__init__( name=name, rank=2, @@ -93,9 +95,9 @@ def call(self, inputs): depthwise_kernel_size = self.depthwise_kernel.shape[-1] if not self.min_value: - self.min_value = [-10000]*depthwise_kernel_size + self.min_value = [-10000] * depthwise_kernel_size if not self.max_value: - self.max_value = [10000]*depthwise_kernel_size + self.max_value = [10000] * depthwise_kernel_size # TODO it's ugly that we can't get the point_wise min/max here depthwise_kernel, _, _ = quantization.quantize( @@ -130,7 +132,9 @@ def call(self, inputs): if self.activation is not None: return self.activation(outputs) return outputs + else: + class QSeparableConv2D(SeparableConv): def __init__( self, @@ -190,9 +194,9 @@ def call(self, inputs): depthwise_kernel_size = self.depthwise_kernel.shape[-1] if not self.min_value: - self.min_value = [-10000]*depthwise_kernel_size + self.min_value = [-10000] * depthwise_kernel_size if not self.max_value: - self.max_value = [10000]*depthwise_kernel_size + self.max_value = [10000] * depthwise_kernel_size # TODO it's ugly that we can't get the point_wise min/max here depthwise_kernel, _, _ = quantization.quantize( From 47f6c0f3ac99936e2cd30465ff7929901f3ff9ef Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Tue, 26 Mar 2024 17:21:55 +0800 Subject: [PATCH 05/13] remove merge notes Signed-off-by: zehao-intel --- .../tensorflow/keras/layers/conv2d.py | 6 - .../keras/layers/depthwise_conv2d.py | 134 +----------------- .../tensorflow/keras/quantization/config.py | 2 +- 3 files changed, 2 insertions(+), 140 deletions(-) diff --git a/neural_compressor/tensorflow/keras/layers/conv2d.py b/neural_compressor/tensorflow/keras/layers/conv2d.py index d5eda05c804..84e7b3b01ef 100644 --- a/neural_compressor/tensorflow/keras/layers/conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/conv2d.py @@ -24,13 +24,7 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 if version1_gte_version2(tf.__version__, "2.16.1"): -<<<<<<< HEAD from keras.src.layers.convolutional.base_conv import BaseConv as Conv # pylint: disable=E0401 -======= - from keras.src.layers.convolutional.base_conv import BaseConv # pylint: disable=E0401 - - Conv = BaseConv ->>>>>>> master elif version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401 else: diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index 264b5551d70..98ee101ff77 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -34,20 +34,10 @@ from keras.utils import conv_utils, tf_utils # pylint: disable=E0401 if version1_gte_version2(tf.__version__, "2.16.1"): - -<<<<<<< HEAD -if version1_gte_version2(tf.__version__, "2.16.1"): -======= ->>>>>>> master class QDepthwiseConv2D(BaseDepthwiseConv): def __init__( self, kernel_size, -<<<<<<< HEAD -======= - min_value, - max_value, ->>>>>>> master strides=(1, 1), padding="valid", depth_multiplier=1, @@ -62,7 +52,6 @@ def __init__( activity_regularizer=None, depthwise_constraint=None, bias_constraint=None, -<<<<<<< HEAD min_value=None, max_value=None, **kwargs ): @@ -128,7 +117,7 @@ def call(self, inputs): if self.activation is not None: return self.activation(outputs) - return + return else: class QDepthwiseConv2D(DepthwiseConv): def __init__( @@ -150,8 +139,6 @@ def __init__( bias_constraint=None, min_value=None, max_value=None, -======= ->>>>>>> master **kwargs ): super().__init__( @@ -173,7 +160,6 @@ def __init__( bias_constraint=bias_constraint, **kwargs ) -<<<<<<< HEAD self.min_value = min_value self.max_value = max_value @@ -213,121 +199,6 @@ def call(self, inputs): return outputs -======= - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - input_channel = self._get_input_channel(inputs.shape) - outputs = ops.depthwise_conv( - inputs, - self.kernel, - strides=self.strides, - padding=self.padding, - dilation_rate=self.dilation_rate, - data_format=self.data_format, - ) - - if self.use_bias: - if self.data_format == "channels_last": - bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,) - else: - bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank - bias = ops.reshape(self.bias, bias_shape) - outputs += bias - - if self.activation is not None: - return self.activation(outputs) - return outputs - -else: - - class QDepthwiseConv2D(DepthwiseConv): - def __init__( - self, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, - **kwargs - ) - self.min_value = json.loads(min_value) - self.max_value = json.loads(max_value) - - def call(self, inputs): - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - outputs = tf.keras.backend.depthwise_conv2d( - inputs, - kernel, - strides=self.strides, - padding=self.padding, - data_format=self.data_format, - dilation_rate=self.dilation_rate, - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - ->>>>>>> master @classmethod def from_config(cls, config): return cls(**config) @@ -361,7 +232,6 @@ def compute_output_shape(self, input_shape): return (input_shape[0], out_filters, rows, cols) elif self.data_format == "channels_last": return (input_shape[0], rows, cols, out_filters) -<<<<<<< HEAD def initialize_int8_depthwise_conv2d(fp32_layer): @@ -426,5 +296,3 @@ def initialize_int8_depthwise_conv2d(fp32_layer): max_value=fp32_layer.max_value, **kwargs ) -======= ->>>>>>> master diff --git a/neural_compressor/tensorflow/keras/quantization/config.py b/neural_compressor/tensorflow/keras/quantization/config.py index a46a7375ca9..ae532dc63c4 100644 --- a/neural_compressor/tensorflow/keras/quantization/config.py +++ b/neural_compressor/tensorflow/keras/quantization/config.py @@ -114,7 +114,7 @@ def register_supported_configs(cls) -> List[OperatorConfig]: def get_model_info(model) -> List[Tuple[str, Callable]]: white_list = [ "Dense", - "Conv2d", + "Conv2D", "DepthwiseConv2D", "SeparableConv2D", "AvgPool2D", From 44a1fcec98964c039643e891815ff33a02c4a560 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Mar 2024 09:23:33 +0000 Subject: [PATCH 06/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tensorflow/keras/layers/conv2d.py | 4 ++-- .../tensorflow/keras/layers/depthwise_conv2d.py | 16 ++++++++++------ 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/neural_compressor/tensorflow/keras/layers/conv2d.py b/neural_compressor/tensorflow/keras/layers/conv2d.py index 84e7b3b01ef..d7f5c8bc698 100644 --- a/neural_compressor/tensorflow/keras/layers/conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/conv2d.py @@ -83,9 +83,9 @@ def call(self, inputs): kernel_size = self.kernel.shape[-1] if not self.min_value: - self.min_value = [-10000]*kernel_size + self.min_value = [-10000] * kernel_size if not self.max_value: - self.max_value = [10000]*kernel_size + self.max_value = [10000] * kernel_size # add the Q/DQ here kernel, _, _ = quantization.quantize( diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index 98ee101ff77..33ad4802e52 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -34,6 +34,7 @@ from keras.utils import conv_utils, tf_utils # pylint: disable=E0401 if version1_gte_version2(tf.__version__, "2.16.1"): + class QDepthwiseConv2D(BaseDepthwiseConv): def __init__( self, @@ -54,7 +55,8 @@ def __init__( bias_constraint=None, min_value=None, max_value=None, - **kwargs ): + **kwargs + ): super().__init__( 2, kernel_size=kernel_size, @@ -81,9 +83,9 @@ def call(self, inputs): kernel_size = self.kernel.shape[-1] if not self.min_value: - self.min_value = [-10000]*kernel_size + self.min_value = [-10000] * kernel_size if not self.max_value: - self.max_value = [10000]*kernel_size + self.max_value = [10000] * kernel_size # add the Q/DQ here kernel, _, _ = quantization.quantize( @@ -117,8 +119,10 @@ def call(self, inputs): if self.activation is not None: return self.activation(outputs) - return + return + else: + class QDepthwiseConv2D(DepthwiseConv): def __init__( self, @@ -167,9 +171,9 @@ def call(self, inputs): depthwise_kernel_size = self.depthwise_kernel.shape[-1] if not self.min_value: - self.min_value = [-10000]*depthwise_kernel_size + self.min_value = [-10000] * depthwise_kernel_size if not self.max_value: - self.max_value = [10000]*depthwise_kernel_size + self.max_value = [10000] * depthwise_kernel_size # add the Q/DQ here kernel, _, _ = quantization.quantize( From 4c49f7a0be6db5d64b9e9e10e8267c3ed5894365 Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Thu, 28 Mar 2024 17:41:32 +0800 Subject: [PATCH 07/13] remove support for keras3 Signed-off-by: zehao-intel --- .../algorithms/static_quant/keras.py | 114 +++---- .../tensorflow/keras/layers/conv2d.py | 4 +- .../keras/layers/depthwise_conv2d.py | 314 +++++++----------- .../keras/layers/separable_conv2d.py | 306 ++++++----------- test/3x/tensorflow/keras/test_config.py | 14 +- .../quantization/test_smooth_quant.py | 6 +- test/3x/tensorflow/test_autotune.py | 2 +- 7 files changed, 282 insertions(+), 478 deletions(-) diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index af86e8d640c..f4d7b6f5263 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -39,7 +39,7 @@ Quantize, ) from neural_compressor.tensorflow.quantization.config import StaticQuantConfig -from neural_compressor.tensorflow.utils import deep_get, dump_elapsed_time, version1_gte_version2 +from neural_compressor.tensorflow.utils import deep_get, dump_elapsed_time class KerasAdaptor: @@ -91,11 +91,9 @@ def __init__(self, framework_specific_info): self.conv_format = {} self.fold_conv = [] - self.tmp_dir = ( - (DEFAULT_WORKSPACE + "/tmp_model.keras") - if version1_gte_version2(tf.__version__, "2.16.1") - else (DEFAULT_WORKSPACE + "/tmp_model") - ) + if not os.path.exists(DEFAULT_WORKSPACE): + os.mkdir(DEFAULT_WORKSPACE) + self.tmp_dir = DEFAULT_WORKSPACE + "tmp_model" def _check_itex(self): """Check if the Intel® Extension for TensorFlow has been installed.""" @@ -151,6 +149,7 @@ def _check_quantize_format(self, model): """The function that checks format for conv ops.""" input_layer_dict = {} layer_name_mapping = {} + for layer in model.layers: layer_name_mapping[layer.name] = layer for node in layer._outbound_nodes: @@ -160,6 +159,7 @@ def _check_quantize_format(self, model): else: input_layer_dict[layer_name].append(layer.name) + for layer in model.layers: if layer.__class__.__name__ in self.supported_op: self.conv_format[layer.name] = "s8" @@ -221,43 +221,38 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): fuse_layers = [] for idx, layer in enumerate(fp32_layers): - if hasattr(layer, "inbound_nodes"): + if hasattr(layer, "_inbound_nodes"): if layer.__class__.__name__ in ("BatchNormalization"): - bn_inbound_node = layer.inbound_nodes[0] - inbound_layer = bn_inbound_node.inbound_layers - if inbound_layer.name in self.conv_weights.keys(): - conv_layer = inbound_layer - conv_weight = self.conv_weights[conv_layer.name] - bn_weight = self.bn_weights[layer.name] - - self.layer_weights[conv_layer.name] = fuse_conv_bn( - conv_weight, bn_weight, conv_layer.__class__.__name__, layer.epsilon - ) - self.fold_conv.append(conv_layer.name) - else: - fuse_layers.append(layer) - elif len(layer.inbound_nodes): + for bn_inbound_node in layer._inbound_nodes: + inbound_layer = bn_inbound_node.inbound_layers + if inbound_layer.name in self.conv_weights.keys(): + conv_layer = inbound_layer + conv_weight = self.conv_weights[conv_layer.name] + bn_weight = self.bn_weights[layer.name] + + self.layer_weights[conv_layer.name] = fuse_conv_bn( + conv_weight, bn_weight, conv_layer.__class__.__name__, layer.epsilon + ) + self.fold_conv.append(conv_layer.name) + else: + fuse_layers.append(layer) + elif len(layer._inbound_nodes): new_bound_nodes = [] # OpLambda node will have different bound node if layer.__class__.__name__ in ("TFOpLambda", "SlicingOpLambda"): fuse_layers.append(layer) else: - for bound_node in layer.inbound_nodes: + for bound_node in layer._inbound_nodes: inbound_layer = bound_node.inbound_layers - - if not isinstance(inbound_layer, list) and inbound_layer in self.bn_weights.keys(): - bn_inbound_node = inbound_layer.inbound_nodes[0] - bn_inbound_layer = bn_inbound_node.inbound_layers - if bn_inbound_layer.name in self.conv_weights.keys(): + if not isinstance(inbound_layer, list) and inbound_layer.name in self.bn_weights.keys() \ + and inbound_layer._inbound_nodes[0].inbound_layers.name in self.conv_weights.keys(): new_bound_nodes.append(bn_inbound_node) - else: - new_bound_nodes.append(bound_node) else: new_bound_nodes.append(bound_node) - for idx, bound_node in enumerate(new_bound_nodes): - layer.inbound_nodes[idx] = new_bound_nodes[idx] - + layer._inbound_nodes.clear() + for bound_node in new_bound_nodes: + layer._inbound_nodes.append(bound_node) fuse_layers.append(layer) else: fuse_layers.append(layer) @@ -285,15 +280,18 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): conv_config = layer.get_config() conv_config["use_bias"] = True conv_layer = type(layer).from_config(conv_config) - conv_layer._outbound_nodes.append(layer._outbound_nodes[0]) + for node in layer._outbound_nodes: + conv_layer._outbound_nodes.append(node) fuse_layers[idx] = conv_layer bn_surgery = KerasSurgery(model) - fused_model = bn_surgery.fuse_bn_layers(fuse_layers, self.conv_weights.keys()) - fused_model = self._set_weights(fused_model, self.layer_weights) - fused_model.save(self.tmp_dir) + bn_fused_model = bn_surgery.fuse_bn_layers(fuse_layers, self.conv_weights.keys()) + bn_fused_model = self._set_weights(bn_fused_model, self.layer_weights) - return tf.keras.models.load_model(self.tmp_dir) + bn_fused_model.save(self.tmp_dir) + bn_fused_model = tf.keras.models.load_model(self.tmp_dir) + + return bn_fused_model @dump_elapsed_time("Pass quantize model") def quantize(self, quant_config, model, dataloader, iteration, q_func=None): @@ -801,17 +799,19 @@ def _create_input_dict(self, fuse_layers=None, conv_weights_keys=None): for layer in layers: for node in layer._outbound_nodes: out_layer = node.outbound_layer - layer_name = out_layer.name + out_layer_names = [out_layer.name] if ( conv_weights_keys and out_layer.__class__.__name__ in ("BatchNormalization") - and out_layer.inbound_nodes[0].inbound_layers.name in conv_weights_keys + and layer.name in conv_weights_keys ): - layer_name = out_layer._outbound_nodes[0].outbound_layer.name - if layer_name not in input_layer_dict: - input_layer_dict[layer_name] = [layer.name] - else: - input_layer_dict[layer_name].append(layer.name) + out_layer_names = [node.outbound_layer.name for node in out_layer._outbound_nodes] + + for out_layer_name in out_layer_names: + if out_layer_name not in input_layer_dict: + input_layer_dict[out_layer_name] = [layer.name] + else: + input_layer_dict[out_layer_name].append(layer.name) return input_layer_dict @@ -823,13 +823,11 @@ def fuse_bn_layers(self, fuse_layers, conv_weights_keys): conv_weights_keys: The names of conv layers where BNs are going to be fused. """ self.input_layer_dict = self._create_input_dict(fuse_layers, conv_weights_keys) - has_input_layer = fuse_layers[0].__class__.__name__ == "InputLayer" - output_tensor_dict = ( - {fuse_layers[0].name: self.model.input} if has_input_layer else {"keras.Input": self.model.input} - ) + output_tensor_dict = {"keras.Input": self.model.input} for idx, layer in enumerate(fuse_layers): - if idx == 0 and has_input_layer: + if layer.__class__.__name__ == "InputLayer": + output_tensor_dict[layer.name] = output_tensor_dict["keras.Input"] continue input_tensors = ( @@ -837,7 +835,8 @@ def fuse_bn_layers(self, fuse_layers, conv_weights_keys): if idx == 0 else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] ) - if len(input_tensors) == 1: + + while isinstance(input_tensors, list) and len(input_tensors) == 1: input_tensors = input_tensors[0] x = layer(input_tensors) @@ -845,7 +844,7 @@ def fuse_bn_layers(self, fuse_layers, conv_weights_keys): output_tensor_dict[layer.name] = x if layer.name in self.model.output_names: self.model_outputs.append(x) - + return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): @@ -858,14 +857,11 @@ def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): q_layer_dict: The dict mapping from layers to be replacement to the quantized layers. """ self.input_layer_dict = self._create_input_dict() - layers = self.model.layers - has_input_layer = layers[0].__class__.__name__ == "InputLayer" - output_tensor_dict = ( - {layers[0].name: layers[0](self.model.input)} if has_input_layer else {"keras.Input": self.model.input} - ) + output_tensor_dict = {"keras.Input": self.model.input} - for idx, layer in enumerate(layers): - if idx == 0 and has_input_layer: + for idx, layer in enumerate(self.model.layers): + if layer.__class__.__name__ == "InputLayer": + output_tensor_dict[layer.name] = output_tensor_dict["keras.Input"] continue input_tensors = ( @@ -873,7 +869,7 @@ def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): if idx == 0 else [output_tensor_dict[input_layer] for input_layer in self.input_layer_dict[layer.name]] ) - if len(input_tensors) == 1: + while isinstance(input_tensors, list) and len(input_tensors) == 1: input_tensors = input_tensors[0] if layer.name in qdq_layer_dict: diff --git a/neural_compressor/tensorflow/keras/layers/conv2d.py b/neural_compressor/tensorflow/keras/layers/conv2d.py index 84e7b3b01ef..b815f4f397e 100644 --- a/neural_compressor/tensorflow/keras/layers/conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/conv2d.py @@ -23,9 +23,7 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.16.1"): - from keras.src.layers.convolutional.base_conv import BaseConv as Conv # pylint: disable=E0401 -elif version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_conv import Conv # pylint: disable=E0401 else: from keras.layers.convolutional.base_conv import Conv # pylint: disable=E0401 diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index 98ee101ff77..431ce68f762 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -23,215 +23,127 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.16.1"): - from keras.src import ops - from keras.src.layers.convolutional.base_depthwise_conv import BaseDepthwiseConv # pylint: disable=E0401 -elif version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.src.utils import conv_utils, tf_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_depthwise_conv import DepthwiseConv # pylint: disable=E0401 from keras.utils import conv_utils, tf_utils # pylint: disable=E0401 -if version1_gte_version2(tf.__version__, "2.16.1"): - class QDepthwiseConv2D(BaseDepthwiseConv): - def __init__( - self, - kernel_size, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - min_value=None, - max_value=None, - **kwargs ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, - **kwargs - ) - self.min_value = min_value - self.max_value = max_value - def call(self, inputs): - kernel_size = self.kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000]*kernel_size - if not self.max_value: - self.max_value = [10000]*kernel_size - - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - input_channel = self._get_input_channel(inputs.shape) - outputs = ops.depthwise_conv( - inputs, - self.kernel, - strides=self.strides, - padding=self.padding, - dilation_rate=self.dilation_rate, - data_format=self.data_format, - ) - - if self.use_bias: - if self.data_format == "channels_last": - bias_shape = (1,) * (self.rank + 1) + (self.depth_multiplier * input_channel,) - else: - bias_shape = (1, self.depth_multiplier * input_channel) + (1,) * self.rank - bias = ops.reshape(self.bias, bias_shape) - outputs += bias - - if self.activation is not None: - return self.activation(outputs) - return -else: - class QDepthwiseConv2D(DepthwiseConv): - def __init__( - self, - kernel_size, - strides=(1, 1), - padding="valid", - depth_multiplier=1, - data_format=None, - dilation_rate=(1, 1), - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - bias_constraint=None, - min_value=None, - max_value=None, +class QDepthwiseConv2D(DepthwiseConv): + def __init__( + self, + kernel_size, + strides=(1, 1), + padding="valid", + depth_multiplier=1, + data_format=None, + dilation_rate=(1, 1), + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + bias_constraint=None, + min_value=None, + max_value=None, + **kwargs + ): + super().__init__( + 2, + kernel_size=kernel_size, + strides=strides, + padding=padding, + depth_multiplier=depth_multiplier, + data_format=data_format, + dilation_rate=dilation_rate, + activation=activation, + use_bias=use_bias, + depthwise_initializer=depthwise_initializer, + bias_initializer=bias_initializer, + depthwise_regularizer=depthwise_regularizer, + bias_regularizer=bias_regularizer, + activity_regularizer=activity_regularizer, + depthwise_constraint=depthwise_constraint, + bias_constraint=bias_constraint, **kwargs - ): - super().__init__( - 2, - kernel_size=kernel_size, - strides=strides, - padding=padding, - depth_multiplier=depth_multiplier, - data_format=data_format, - dilation_rate=dilation_rate, - activation=activation, - use_bias=use_bias, - depthwise_initializer=depthwise_initializer, - bias_initializer=bias_initializer, - depthwise_regularizer=depthwise_regularizer, - bias_regularizer=bias_regularizer, - activity_regularizer=activity_regularizer, - depthwise_constraint=depthwise_constraint, - bias_constraint=bias_constraint, - **kwargs - ) - self.min_value = min_value - self.max_value = max_value - - def call(self, inputs): - depthwise_kernel_size = self.depthwise_kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000]*depthwise_kernel_size - if not self.max_value: - self.max_value = [10000]*depthwise_kernel_size - - # add the Q/DQ here - kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - kernel = quantization.dequantize( - kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - outputs = tf.keras.backend.depthwise_conv2d( - inputs, - kernel, - strides=self.strides, - padding=self.padding, - data_format=self.data_format, - dilation_rate=self.dilation_rate, - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) - - @tf_utils.shape_type_conversion - def compute_output_shape(self, input_shape): - if self.data_format == "channels_first": - rows = input_shape[2] - cols = input_shape[3] - out_filters = input_shape[1] * self.depth_multiplier - elif self.data_format == "channels_last": - rows = input_shape[1] - cols = input_shape[2] - out_filters = input_shape[3] * self.depth_multiplier - - rows = conv_utils.conv_output_length( - rows, - self.kernel_size[0], - self.padding, - self.strides[0], - self.dilation_rate[0], - ) - cols = conv_utils.conv_output_length( - cols, - self.kernel_size[1], - self.padding, - self.strides[1], - self.dilation_rate[1], - ) - if self.data_format == "channels_first": - return (input_shape[0], out_filters, rows, cols) - elif self.data_format == "channels_last": - return (input_shape[0], rows, cols, out_filters) + ) + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + depthwise_kernel_size = self.depthwise_kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000]*depthwise_kernel_size + if not self.max_value: + self.max_value = [10000]*depthwise_kernel_size + + # add the Q/DQ here + kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + kernel = quantization.dequantize( + kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + outputs = tf.keras.backend.depthwise_conv2d( + inputs, + kernel, + strides=self.strides, + padding=self.padding, + data_format=self.data_format, + dilation_rate=self.dilation_rate, + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) + + @tf_utils.shape_type_conversion + def compute_output_shape(self, input_shape): + if self.data_format == "channels_first": + rows = input_shape[2] + cols = input_shape[3] + out_filters = input_shape[1] * self.depth_multiplier + elif self.data_format == "channels_last": + rows = input_shape[1] + cols = input_shape[2] + out_filters = input_shape[3] * self.depth_multiplier + + rows = conv_utils.conv_output_length( + rows, + self.kernel_size[0], + self.padding, + self.strides[0], + self.dilation_rate[0], + ) + cols = conv_utils.conv_output_length( + cols, + self.kernel_size[1], + self.padding, + self.strides[1], + self.dilation_rate[1], + ) + if self.data_format == "channels_first": + return (input_shape[0], out_filters, rows, cols) + elif self.data_format == "channels_last": + return (input_shape[0], rows, cols, out_filters) def initialize_int8_depthwise_conv2d(fp32_layer): diff --git a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py index 6a9ec1d9e75..7df66d9db49 100644 --- a/neural_compressor/tensorflow/keras/layers/separable_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/separable_conv2d.py @@ -23,219 +23,115 @@ from neural_compressor.tensorflow.utils import version1_gte_version2 -if version1_gte_version2(tf.__version__, "2.16.1"): - from keras.src import ops - from keras.src.layers.convolutional.base_separable_conv import BaseSeparableConv # pylint: disable=E0401 -elif version1_gte_version2(tf.__version__, "2.13.0"): +if version1_gte_version2(tf.__version__, "2.13.0"): from keras.src.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.src.utils import conv_utils # pylint: disable=E0401 else: from keras.layers.convolutional.base_separable_conv import SeparableConv # pylint: disable=E0401 from keras.utils import conv_utils # pylint: disable=E0401 -if version1_gte_version2(tf.__version__, "2.16.1"): - class QSeparableConv2D(BaseSeparableConv): - def __init__( - self, - filters, - kernel_size, - min_value, - max_value, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - pointwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - **kwargs - ): - super().__init__( - name=name, - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activations.get(activation), - use_bias=use_bias, - depthwise_initializer=initializers.get(depthwise_initializer), - pointwise_initializer=initializers.get(pointwise_initializer), - bias_initializer=initializers.get(bias_initializer), - depthwise_regularizer=regularizers.get(depthwise_regularizer), - pointwise_regularizer=regularizers.get(pointwise_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - depthwise_constraint=constraints.get(depthwise_constraint), - pointwise_constraint=constraints.get(pointwise_constraint), - bias_constraint=constraints.get(bias_constraint), - min_value=None, - max_value=None, - **kwargs - ) - - self.min_value = min_value - self.max_value = max_value - - def call(self, inputs): - depthwise_kernel_size = self.depthwise_kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000] * depthwise_kernel_size - if not self.max_value: - self.max_value = [10000] * depthwise_kernel_size - - # TODO it's ugly that we can't get the point_wise min/max here - depthwise_kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - depthwise_kernel = quantization.dequantize( - depthwise_kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - outputs = ops.separable_conv( - inputs, - self.depthwise_kernel, - self.pointwise_kernel, - strides=self.strides, - padding=self.padding, - dilation_rate=self.dilation_rate, - data_format=self.data_format, - ) - - if self.use_bias: - if self.data_format == "channels_last": - bias_shape = (1,) * (self.rank + 1) + (self.filters,) - else: - bias_shape = (1, self.filters) + (1,) * self.rank - bias = ops.reshape(self.bias, bias_shape) - outputs += bias - - if self.activation is not None: - return self.activation(outputs) - return outputs - -else: - - class QSeparableConv2D(SeparableConv): - def __init__( - self, - name, - filters, - kernel_size, - strides=(1, 1), - padding="valid", - data_format=None, - dilation_rate=(1, 1), - depth_multiplier=1, - activation=None, - use_bias=True, - depthwise_initializer="glorot_uniform", - pointwise_initializer="glorot_uniform", - bias_initializer="zeros", - depthwise_regularizer=None, - pointwise_regularizer=None, - bias_regularizer=None, - activity_regularizer=None, - depthwise_constraint=None, - pointwise_constraint=None, - bias_constraint=None, - min_value=None, - max_value=None, +class QSeparableConv2D(SeparableConv): + def __init__( + self, + name, + filters, + kernel_size, + strides=(1, 1), + padding="valid", + data_format=None, + dilation_rate=(1, 1), + depth_multiplier=1, + activation=None, + use_bias=True, + depthwise_initializer="glorot_uniform", + pointwise_initializer="glorot_uniform", + bias_initializer="zeros", + depthwise_regularizer=None, + pointwise_regularizer=None, + bias_regularizer=None, + activity_regularizer=None, + depthwise_constraint=None, + pointwise_constraint=None, + bias_constraint=None, + min_value=None, + max_value=None, + **kwargs + ): + super().__init__( + name=name, + rank=2, + filters=filters, + kernel_size=kernel_size, + strides=strides, + padding=padding, + data_format=data_format, + dilation_rate=dilation_rate, + depth_multiplier=depth_multiplier, + activation=activations.get(activation), + use_bias=use_bias, + depthwise_initializer=initializers.get(depthwise_initializer), + pointwise_initializer=initializers.get(pointwise_initializer), + bias_initializer=initializers.get(bias_initializer), + depthwise_regularizer=regularizers.get(depthwise_regularizer), + pointwise_regularizer=regularizers.get(pointwise_regularizer), + bias_regularizer=regularizers.get(bias_regularizer), + activity_regularizer=regularizers.get(activity_regularizer), + depthwise_constraint=constraints.get(depthwise_constraint), + pointwise_constraint=constraints.get(pointwise_constraint), + bias_constraint=constraints.get(bias_constraint), **kwargs - ): - super().__init__( - name=name, - rank=2, - filters=filters, - kernel_size=kernel_size, - strides=strides, - padding=padding, - data_format=data_format, - dilation_rate=dilation_rate, - depth_multiplier=depth_multiplier, - activation=activations.get(activation), - use_bias=use_bias, - depthwise_initializer=initializers.get(depthwise_initializer), - pointwise_initializer=initializers.get(pointwise_initializer), - bias_initializer=initializers.get(bias_initializer), - depthwise_regularizer=regularizers.get(depthwise_regularizer), - pointwise_regularizer=regularizers.get(pointwise_regularizer), - bias_regularizer=regularizers.get(bias_regularizer), - activity_regularizer=regularizers.get(activity_regularizer), - depthwise_constraint=constraints.get(depthwise_constraint), - pointwise_constraint=constraints.get(pointwise_constraint), - bias_constraint=constraints.get(bias_constraint), - **kwargs - ) - - self.min_value = min_value - self.max_value = max_value - - def call(self, inputs): - depthwise_kernel_size = self.depthwise_kernel.shape[-1] - - if not self.min_value: - self.min_value = [-10000] * depthwise_kernel_size - if not self.max_value: - self.max_value = [10000] * depthwise_kernel_size - - # TODO it's ugly that we can't get the point_wise min/max here - depthwise_kernel, _, _ = quantization.quantize( - self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" - ) - depthwise_kernel = quantization.dequantize( - depthwise_kernel, - self.min_value, - self.max_value, - axis=3, - mode="SCALED", - ) - - if self.data_format == "channels_last": - strides = (1,) + self.strides + (1,) - else: - strides = (1, 1) + self.strides - - outputs = tf.compat.v1.nn.separable_conv2d( - inputs, - depthwise_kernel, - self.pointwise_kernel, - strides=strides, - padding=self.padding.upper(), - rate=self.dilation_rate, - data_format=conv_utils.convert_data_format(self.data_format, ndim=4), - ) - - if self.use_bias: - outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) - - if self.activation is not None: - return self.activation(outputs) - - return outputs - - @classmethod - def from_config(cls, config): - return cls(**config) + ) + + self.min_value = min_value + self.max_value = max_value + + def call(self, inputs): + depthwise_kernel_size = self.depthwise_kernel.shape[-1] + + if not self.min_value: + self.min_value = [-10000] * depthwise_kernel_size + if not self.max_value: + self.max_value = [10000] * depthwise_kernel_size + + # TODO it's ugly that we can't get the point_wise min/max here + depthwise_kernel, _, _ = quantization.quantize( + self.depthwise_kernel, self.min_value, self.max_value, tf.qint8, axis=3, mode="SCALED" + ) + depthwise_kernel = quantization.dequantize( + depthwise_kernel, + self.min_value, + self.max_value, + axis=3, + mode="SCALED", + ) + + if self.data_format == "channels_last": + strides = (1,) + self.strides + (1,) + else: + strides = (1, 1) + self.strides + + outputs = tf.compat.v1.nn.separable_conv2d( + inputs, + depthwise_kernel, + self.pointwise_kernel, + strides=strides, + padding=self.padding.upper(), + rate=self.dilation_rate, + data_format=conv_utils.convert_data_format(self.data_format, ndim=4), + ) + + if self.use_bias: + outputs = tf.keras.backend.bias_add(outputs, self.bias, data_format=self.data_format) + + if self.activation is not None: + return self.activation(outputs) + + return outputs + + @classmethod + def from_config(cls, config): + return cls(**config) def initialize_int8_separable_conv2d(fp32_layer): diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 4d7ce12f72c..e79977092de 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -69,7 +69,7 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - model.save("baseline_model.keras") + model.save("baseline_model") class Dataset(object): @@ -124,7 +124,7 @@ def test_static_quant_from_dict_default(self): from neural_compressor.tensorflow.keras import get_default_static_quant_config calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") qmodel = quantize_model(fp32_model, get_default_static_quant_config(), calib_dataloader) self.assertIsNotNone(qmodel) @@ -158,7 +158,7 @@ def test_static_quant_from_dict_beginner(self): } } calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) @@ -181,7 +181,7 @@ def test_static_quant_from_class_default(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") quant_config = StaticQuantConfig() qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) @@ -205,7 +205,7 @@ def test_static_quant_from_class_beginner(self): from neural_compressor.tensorflow.keras import StaticQuantConfig calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") quant_config = StaticQuantConfig( weight_dtype="int8", weight_sym=True, @@ -235,7 +235,7 @@ def test_static_quant_from_dict_advance(self): from neural_compressor.tensorflow import quantize_model calib_dataloader = MyDataloader(dataset=Dataset()) - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") quant_config = { "static_quant": { "global": { @@ -290,7 +290,7 @@ def test_static_quant_from_class_advance(self): ) quant_config.set_local("dense", dense_config) # get model and quantize - fp32_model = keras.models.load_model("baseline_model.keras") + fp32_model = keras.models.load_model("baseline_model") qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) diff --git a/test/3x/tensorflow/quantization/test_smooth_quant.py b/test/3x/tensorflow/quantization/test_smooth_quant.py index ee8f5407d3a..1a74ce0afb3 100644 --- a/test/3x/tensorflow/quantization/test_smooth_quant.py +++ b/test/3x/tensorflow/quantization/test_smooth_quant.py @@ -20,13 +20,15 @@ def build_conv_graph(): "weight", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") - + normed = tf.compat.v1.layers.batch_normalization(conv) + conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) conv2 = tf.nn.conv2d(top_relu, conv_weights2, strides=[1, 2, 2, 1], padding="SAME") + normed2 = tf.compat.v1.layers.batch_normalization(conv2) + add = tf.raw_ops.Add(x=normed, y=normed2, name="addv2") - add = tf.raw_ops.Add(x=conv, y=conv2, name="addv2") relu = tf.nn.relu(add) relu6 = tf.nn.relu6(relu, name="op_to_store") diff --git a/test/3x/tensorflow/test_autotune.py b/test/3x/tensorflow/test_autotune.py index d5f83e85c7d..9c89f8cd5fc 100644 --- a/test/3x/tensorflow/test_autotune.py +++ b/test/3x/tensorflow/test_autotune.py @@ -59,7 +59,7 @@ def build_model(): _, baseline_model_accuracy = model.evaluate(test_images, test_labels, verbose=0) print("Baseline test accuracy:", baseline_model_accuracy) - tf.saved_model.save(model, "baseline_model") + model.save("baseline_model") class Dataset(object): From c6a255e50a503627c7fca385bb237be0f9d5bffd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Mar 2024 09:45:01 +0000 Subject: [PATCH 08/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../tensorflow/algorithms/static_quant/keras.py | 14 ++++++++------ .../tensorflow/keras/layers/depthwise_conv2d.py | 4 ++-- .../tensorflow/quantization/test_smooth_quant.py | 2 +- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index f4d7b6f5263..3cbfea401c2 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -159,7 +159,6 @@ def _check_quantize_format(self, model): else: input_layer_dict[layer_name].append(layer.name) - for layer in model.layers: if layer.__class__.__name__ in self.supported_op: self.conv_format[layer.name] = "s8" @@ -244,9 +243,12 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): else: for bound_node in layer._inbound_nodes: inbound_layer = bound_node.inbound_layers - if not isinstance(inbound_layer, list) and inbound_layer.name in self.bn_weights.keys() \ - and inbound_layer._inbound_nodes[0].inbound_layers.name in self.conv_weights.keys(): - new_bound_nodes.append(bn_inbound_node) + if ( + not isinstance(inbound_layer, list) + and inbound_layer.name in self.bn_weights.keys() + and inbound_layer._inbound_nodes[0].inbound_layers.name in self.conv_weights.keys() + ): + new_bound_nodes.append(bn_inbound_node) else: new_bound_nodes.append(bound_node) @@ -290,7 +292,7 @@ def fuse_conv_bn(conv_weight, bn_weight, conv_type="Conv2D", eps=1.0e-5): bn_fused_model.save(self.tmp_dir) bn_fused_model = tf.keras.models.load_model(self.tmp_dir) - + return bn_fused_model @dump_elapsed_time("Pass quantize model") @@ -844,7 +846,7 @@ def fuse_bn_layers(self, fuse_layers, conv_weights_keys): output_tensor_dict[layer.name] = x if layer.name in self.model.output_names: self.model_outputs.append(x) - + return tf.keras.models.Model(inputs=self.model.inputs, outputs=self.model_outputs) def insert_quant_layers(self, qdq_layer_dict, q_layer_dict=None): diff --git a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py index 431ce68f762..a3e6dd9b2f4 100644 --- a/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py +++ b/neural_compressor/tensorflow/keras/layers/depthwise_conv2d.py @@ -79,9 +79,9 @@ def call(self, inputs): depthwise_kernel_size = self.depthwise_kernel.shape[-1] if not self.min_value: - self.min_value = [-10000]*depthwise_kernel_size + self.min_value = [-10000] * depthwise_kernel_size if not self.max_value: - self.max_value = [10000]*depthwise_kernel_size + self.max_value = [10000] * depthwise_kernel_size # add the Q/DQ here kernel, _, _ = quantization.quantize( diff --git a/test/3x/tensorflow/quantization/test_smooth_quant.py b/test/3x/tensorflow/quantization/test_smooth_quant.py index 1a74ce0afb3..5c76eadb9cd 100644 --- a/test/3x/tensorflow/quantization/test_smooth_quant.py +++ b/test/3x/tensorflow/quantization/test_smooth_quant.py @@ -21,7 +21,7 @@ def build_conv_graph(): ) conv = tf.nn.conv2d(x_pad, conv_weights, strides=[1, 2, 2, 1], padding="VALID") normed = tf.compat.v1.layers.batch_normalization(conv) - + conv_weights2 = tf.compat.v1.get_variable( "weight2", [3, 3, 16, 16], initializer=tf.compat.v1.random_normal_initializer() ) From 19deabdd169b4d880544b1fa8af9b8892792d515 Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Thu, 28 Mar 2024 17:45:55 +0800 Subject: [PATCH 09/13] limit tf version for 3x Signed-off-by: zehao-intel --- requirements_tf.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/requirements_tf.txt b/requirements_tf.txt index f8075c2a068..da1544d2939 100644 --- a/requirements_tf.txt +++ b/requirements_tf.txt @@ -3,4 +3,4 @@ psutil py-cpuinfo pydantic pyyaml -tensorflow +tensorflow<=2.15.1 From 100d937a65d78df430363e71b83f70316e403539 Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Thu, 28 Mar 2024 21:12:03 +0800 Subject: [PATCH 10/13] fix ut Signed-off-by: zehao-intel --- .../tensorflow/algorithms/static_quant/keras.py | 6 +++--- test/3x/tensorflow/keras/test_config.py | 12 +++--------- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/neural_compressor/tensorflow/algorithms/static_quant/keras.py b/neural_compressor/tensorflow/algorithms/static_quant/keras.py index 3cbfea401c2..79ed5464a1f 100644 --- a/neural_compressor/tensorflow/algorithms/static_quant/keras.py +++ b/neural_compressor/tensorflow/algorithms/static_quant/keras.py @@ -584,7 +584,7 @@ def tuning_cfg_to_fw(self, tuning_cfg): """Parse tune_config and set framework variables. Args: - tuning_cfg (dict): The dict of tunning config. + tuning_cfg (dict): The dict of tuning config. """ self.quantize_config["calib_iteration"] = tuning_cfg["calib_iteration"] self.quantize_config["device"] = self.device @@ -790,8 +790,8 @@ def _create_input_dict(self, fuse_layers=None, conv_weights_keys=None): """Create a input_layer_dict from model. Args: - fuse_layers: The layers in which fused BNs have been excluded, defualt to be None. - conv_weights_keys: The names of conv layers where BNs are going to be fused, defualt to be None. + fuse_layers: The layers in which fused BNs have been excluded, default to be None. + conv_weights_keys: The names of conv layers where BNs are going to be fused, default to be None. Returns: input_layer_dict: The dict that mapping for layer names to their input layer names. diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index e79977092de..2835f2956e5 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -23,7 +23,7 @@ import numpy as np import tensorflow as tf -from tensorflow import keras +import keras from neural_compressor.common import Logger @@ -56,7 +56,7 @@ def build_model(): ) # Train the digit classification model model.compile( - optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] + optimizer="adam", loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), metrics=["accuracy"] ) model.fit( @@ -260,9 +260,6 @@ def test_static_quant_from_dict_advance(self): dense_checked = False conv_checked = False for layer in qmodel.layers: - if layer.name == "dense": - dense_checked = True - self.assertEqual(layer.__class__.__name__, "QDense") if layer.name == "conv2d": conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") @@ -297,13 +294,10 @@ def test_static_quant_from_class_advance(self): dense_checked = False conv_checked = False for layer in qmodel.layers: - if layer.name == "dense": - dense_checked = True - self.assertEqual(layer.__class__.__name__, "QDense") if layer.name == "conv2d": conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - + self.assertEqual(dense_checked, True) self.assertEqual(conv_checked, True) From 31ea76db6672c7283cbdaf21e5cafd557b305386 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Mar 2024 13:13:32 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/3x/tensorflow/keras/test_config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 2835f2956e5..7c0ee4c1b3f 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -21,9 +21,9 @@ import time import unittest +import keras import numpy as np import tensorflow as tf -import keras from neural_compressor.common import Logger @@ -297,7 +297,7 @@ def test_static_quant_from_class_advance(self): if layer.name == "conv2d": conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - + self.assertEqual(dense_checked, True) self.assertEqual(conv_checked, True) From 0d0c8c0282e37c29579fb236977fc3bac4e0a675 Mon Sep 17 00:00:00 2001 From: zehao-intel Date: Thu, 28 Mar 2024 21:35:08 +0800 Subject: [PATCH 12/13] remove some checks Signed-off-by: zehao-intel --- test/3x/tensorflow/keras/test_config.py | 37 ------------------------- 1 file changed, 37 deletions(-) diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 2835f2956e5..36fe5319fa7 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -128,18 +128,12 @@ def test_static_quant_from_dict_default(self): qmodel = quantize_model(fp32_model, get_default_static_quant_config(), calib_dataloader) self.assertIsNotNone(qmodel) - dense_checked = False - conv_checked = False for layer in qmodel.layers: if layer.name == "dense": - dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") if layer.name == "conv2d": - conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - self.assertEqual(dense_checked, True) - self.assertEqual(conv_checked, True) def test_static_quant_from_dict_beginner(self): logger.info("test_static_quant_from_dict_beginner") @@ -162,18 +156,12 @@ def test_static_quant_from_dict_beginner(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) - dense_checked = False - conv_checked = False for layer in qmodel.layers: if layer.name == "dense": - dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") if layer.name == "conv2d": - conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - self.assertEqual(dense_checked, True) - self.assertEqual(conv_checked, True) def test_static_quant_from_class_default(self): logger.info("test_static_quant_from_class_default") @@ -186,19 +174,12 @@ def test_static_quant_from_class_default(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) - dense_checked = False - conv_checked = False for layer in qmodel.layers: if layer.name == "dense": - dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") if layer.name == "conv2d": - conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - self.assertEqual(dense_checked, True) - self.assertEqual(conv_checked, True) - def test_static_quant_from_class_beginner(self): logger.info("test_static_quant_from_class_beginner") from neural_compressor.tensorflow import quantize_model @@ -217,18 +198,12 @@ def test_static_quant_from_class_beginner(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) - dense_checked = False - conv_checked = False for layer in qmodel.layers: if layer.name == "dense": - dense_checked = True self.assertEqual(layer.__class__.__name__, "QDense") if layer.name == "conv2d": - conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - self.assertEqual(dense_checked, True) - self.assertEqual(conv_checked, True) def test_static_quant_from_dict_advance(self): logger.info("test_static_quant_from_dict_advance") @@ -257,16 +232,10 @@ def test_static_quant_from_dict_advance(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) - dense_checked = False - conv_checked = False for layer in qmodel.layers: if layer.name == "conv2d": - conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - self.assertEqual(dense_checked, True) - self.assertEqual(conv_checked, True) - def test_static_quant_from_class_advance(self): logger.info("test_static_quant_from_class_advance") from neural_compressor.tensorflow import quantize_model @@ -291,15 +260,9 @@ def test_static_quant_from_class_advance(self): qmodel = quantize_model(fp32_model, quant_config, calib_dataloader) self.assertIsNotNone(qmodel) - dense_checked = False - conv_checked = False for layer in qmodel.layers: if layer.name == "conv2d": - conv_checked = True self.assertEqual(layer.__class__.__name__, "QConv2D") - - self.assertEqual(dense_checked, True) - self.assertEqual(conv_checked, True) def test_config_from_dict(self): logger.info("test_config_from_dict") From 90fd776048f9a0e858c1cf2ee6e219be0d1fa065 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 28 Mar 2024 13:40:51 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/3x/tensorflow/keras/test_config.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/3x/tensorflow/keras/test_config.py b/test/3x/tensorflow/keras/test_config.py index 473a890da7f..c204d52b330 100644 --- a/test/3x/tensorflow/keras/test_config.py +++ b/test/3x/tensorflow/keras/test_config.py @@ -134,7 +134,6 @@ def test_static_quant_from_dict_default(self): if layer.name == "conv2d": self.assertEqual(layer.__class__.__name__, "QConv2D") - def test_static_quant_from_dict_beginner(self): logger.info("test_static_quant_from_dict_beginner") from neural_compressor.tensorflow import quantize_model @@ -162,7 +161,6 @@ def test_static_quant_from_dict_beginner(self): if layer.name == "conv2d": self.assertEqual(layer.__class__.__name__, "QConv2D") - def test_static_quant_from_class_default(self): logger.info("test_static_quant_from_class_default") from neural_compressor.tensorflow import quantize_model @@ -204,7 +202,6 @@ def test_static_quant_from_class_beginner(self): if layer.name == "conv2d": self.assertEqual(layer.__class__.__name__, "QConv2D") - def test_static_quant_from_dict_advance(self): logger.info("test_static_quant_from_dict_advance") from neural_compressor.tensorflow import quantize_model