Support fuse small ops to FusedInstanceNorm (intel#1273)
lvliang-intel authored Sep 23, 2022
1 parent c98c428 commit 428fa76
Showing 18 changed files with 848 additions and 176 deletions.
neural_compressor/adaptor/inteltensorflow.yaml (18 changes: 13 additions & 5 deletions)
@@ -28,12 +28,12 @@
uint8: ['Conv2D', 'Conv3D', 'DepthwiseConv2dNative', 'MatMul', 'BatchMatMul', 'BatchMatMulV2', 'ConcatV2', 'MaxPool', 'MaxPool3D', 'AvgPool']
bf16: ["Conv2D", "Conv2DBackpropFilter", "Conv2DBackpropInput", "Conv3D", "Conv3DBackpropFilterV2", "Conv3DBackpropInputV2",
"DepthwiseConv2dNative", "DepthwiseConv2dNativeBackpropFilter", "DepthwiseConv2dNativeBackpropInput", "GRUBlockCell",
"AUGRUBlockCell", "MklGRU", "MklAUGRU", "MatMul", "BatchMatMul", "BatchMatMulV2", # allow_list
"AUGRUBlockCell", "MklGRU", "MklAUGRU", "MatMul", "BatchMatMul", "BatchMatMulV2", "Einsum", # allow_list
"Add", "AddN", "AddV2", "AvgPool", "AvgPool3D", "AvgPool3DGrad", "AvgPoolGrad", "BiasAdd", "BiasAddGrad", "BiasAddV1",
"Erf", "FusedBatchNormV2", "FusedBatchNormGradV2", "FusedBatchNormV3", "FusedBatchNormGradV3", "LeakyRelu", "LeakyReluGrad",
"Mul", "Sub", "Elu", "EluGrad", "FloorDiv", "_FusedBatchNormEx", "Log", "Log1p", "LogSoftmax", "Prod", "RealDiv", "Reciprocal",
"Selu", "SeluGrad", "Sigmoid", "SigmoidGrad", "Softmax", "Softplus", "SoftplusGrad", "Softsign", "SoftsignGrad", "Sqrt",
"Tanh", "TanhGrad", #infer_list
"Mean", "Mul", "Sub", "Elu", "EluGrad", "FloorDiv", "_FusedBatchNormEx", "Log", "Log1p", "LogSoftmax", "Prod", "RealDiv",
"Reciprocal", "Rsqrt", "Selu", "SeluGrad", "Sigmoid", "SigmoidGrad", "Softmax", "Softplus", "SoftplusGrad", "Softsign",
"SoftsignGrad", "Sqrt", "SquaredDifference", "Tanh", "TanhGrad", #infer_list
"Abs", "ArgMax","ArgMin","BatchToSpace","BatchToSpaceND","BroadcastTo","Ceil","CheckNumerics","ClipByValue","Concat","ConcatV2",
"DepthToSpace","DynamicPartition","DynamicStitch","EnsureShape","Enter","Equal","Exit","ExpandDims","Fill","Floor","Gather",
"GatherNd","GatherV2","Greater","GreaterEqual","Identity","IsFinite","IsInf","IsNan","Less","LessEqual","Max","Maximum","MaxPool",
@@ -337,7 +337,15 @@
'Dequantize + Conv3D + Add + Relu6 + QuantizeV2',
'Dequantize + Conv3D + AddV2 + Relu6 + QuantizeV2',
'Dequantize + Conv3D + Eelu + QuantizeV2',
-'Dequantize + Conv3D + LeakyRelu + QuantizeV2'
+'Dequantize + Conv3D + LeakyRelu + QuantizeV2',
+'Dequantize + Conv3D + BiasAdd + Relu + QuantizeV2',
+'Dequantize + Conv3D + BiasAdd + Relu6 + QuantizeV2',
+'Dequantize + Conv3D + BiasAdd + Eelu + QuantizeV2',
+'Dequantize + Conv3D + BiasAdd + LeakyRelu + QuantizeV2',
+'Dequantize + Conv3D + Add + Relu + QuantizeV2',
+'Dequantize + Conv3D + Add + Relu6 + QuantizeV2',
+'Dequantize + Conv3D + Add + Eelu + QuantizeV2',
+'Dequantize + Conv3D + Add + LeakyRelu + QuantizeV2'

]

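The entries added above extend the Conv3D quantization fusion patterns with a BiasAdd or Add followed by an activation between Dequantize and QuantizeV2. As a rough, hypothetical illustration of how such a straight-line op sequence can be located in a GraphDef (the real matcher in neural_compressor is far more general; find_chain and its single-consumer restriction are simplifying assumptions):

```python
from tensorflow.core.framework import graph_pb2

def find_chain(graph_def: graph_pb2.GraphDef,
               pattern=("Dequantize", "Conv3D", "BiasAdd", "Relu", "QuantizeV2")):
    """Return op-name chains whose ops match `pattern` along single-consumer edges."""
    consumers = {}
    for node in graph_def.node:
        for inp in node.input:
            producer = inp.split(":")[0].lstrip("^")  # strip output index / control marker
            consumers.setdefault(producer, []).append(node)

    matches = []
    for node in graph_def.node:
        if node.op != pattern[0]:
            continue
        chain, cur = [node], node
        for want in pattern[1:]:
            nexts = [c for c in consumers.get(cur.name, []) if c.op == want]
            if len(nexts) != 1:        # demand an unambiguous, single consumer
                chain = None
                break
            cur = nexts[0]
            chain.append(cur)
        if chain:
            matches.append([n.name for n in chain])
    return matches
```

For example, find_chain(graph_def) would return one name chain per fusable Dequantize + Conv3D + BiasAdd + Relu + QuantizeV2 sequence.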
neural_compressor/adaptor/tf_utils/graph_converter.py (33 changes: 19 additions & 14 deletions)
@@ -59,6 +59,8 @@
from .graph_rewriter.int8.fuse_matmul_requantize import FuseMatMulRequantizeDequantizeTransformer
from .graph_rewriter.int8.fuse_matmul_requantize import FuseMatMulRequantizeNewAPITransformer
from .graph_rewriter.int8.fuse_matmul_requantize import FuseMatMulRequantizeDequantizeNewAPITransformer
+from .graph_rewriter.int8.fuse_conv_redundant_dequantize import FuseConvRedundantDequantizeTransformer
+from .graph_rewriter.int8.fuse_matmul_redundant_dequantize import FuseMatMulRedundantDequantizeTransformer
from .graph_rewriter.int8.scale_propagation import ScaleProPagationTransformer
from .graph_rewriter.bf16.bf16_convert import BF16Convert
from .graph_rewriter.int8.post_quantized_op_cse import PostCseOptimizer
@@ -333,6 +335,11 @@ def convert(self):
if (len(self.bf16_ops) > 0 and self.performance_only) or \
(os.getenv('MIX_PRECISION_TEST') == '1'):
model = self.bf16_convert()

+if self.new_api:
+model.graph_def = FuseConvRedundantDequantizeTransformer(model.graph_def).do_transformation()
+model.graph_def = FuseMatMulRedundantDequantizeTransformer(model.graph_def).do_transformation()

post_cse_graph_def = PostCseOptimizer(model.graph_def).do_transformation()
post_hostconst_graph_def = PostHostConstConverter(post_cse_graph_def).do_transformation()
post_hostconst_graph_def.library.CopyFrom(self.model.graph_def.library)
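
The two transformers added here run only when self.new_api is set and, judging by their names, fold a Dequantize that directly follows a quantized Conv or MatMul into the preceding fused op, so it can emit float output without a separate dequantize step. A simplified, hypothetical sketch of just the detection half (the helper name and the op strings in quantized_ops are assumptions; the real passes also rewrite node attributes and outputs):

```python
from tensorflow.core.framework import graph_pb2

def trailing_dequantize_candidates(graph_def: graph_pb2.GraphDef,
                                   quantized_ops=("_FusedQuantizedConv2D",
                                                  "_FusedQuantizedMatMul")):
    """Find (quantized_node, dequantize_node) pairs a rewriter could merge."""
    producers = {node.name: node for node in graph_def.node}
    pairs = []
    for node in graph_def.node:
        if node.op != "Dequantize" or not node.input:
            continue
        parent = producers.get(node.input[0].split(":")[0].lstrip("^"))
        if parent is not None and parent.op in quantized_ops:
            pairs.append((parent.name, node.name))
    return pairs
```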
@@ -448,18 +455,18 @@ def quantize(self):
sampling_graph_def = copy.deepcopy(self._fp32_model.graph_def)
# TODO: this is a workaround to make Min/Max node be completly eliminated in int8 graph
# after enabling pad+conv2d in new API.
-if self.new_api:
-non_pad_ops = list(list(set(self.fp32_ops).union(set(self.bf16_ops))))
-sampling_graph_def = FusePadWithFP32Conv2DOptimizer(
-sampling_graph_def,
-non_pad_ops,
-self._tmp_model.input_node_names,
-self.op_wise_config,
-self.new_api).do_transformation()

+non_pad_ops = list(list(set(self.fp32_ops).union(set(self.bf16_ops))))
+sampling_graph_def = FusePadWithFP32Conv2DOptimizer(
+sampling_graph_def,
+non_pad_ops,
+self._tmp_model.input_node_names,
+self.op_wise_config,
+self.new_api).do_transformation()

for i in self.quantized_node_info:
sampling_graph_def, output_names = InsertPrintMinMaxNode(
-sampling_graph_def, i[0], i[-1]).do_transformation()
+sampling_graph_def, i[0], i[-1], self.new_api).do_transformation()
output_tensor_names.extend(output_names)
if self.quantized_node_info:
sampling_graph_def.library.CopyFrom(self.model.graph_def.library)
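
Judging by its name, InsertPrintMinMaxNode splices logging nodes into the sampling graph so that running calibration data records each quantizable tensor's observed range; the extra self.new_api argument simply threads the API mode through to that transformer. A purely illustrative, eager-mode sketch of the kind of range collection it approximates (record_range and the ranges dict are hypothetical; the real pass writes to a dump file that is parsed later):

```python
import tensorflow as tf

def record_range(name, tensor, ranges):
    """Track the running min/max seen for a tensor across calibration batches."""
    lo = float(tf.reduce_min(tensor))
    hi = float(tf.reduce_max(tensor))
    prev_lo, prev_hi = ranges.get(name, (lo, hi))
    ranges[name] = (min(prev_lo, lo), max(prev_hi, hi))
    return tensor  # pass-through, like the spliced print nodes
```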
@@ -750,7 +757,7 @@ def _insert_qdq_pairs(self):

for i in self.quantized_node_info:
sampling_graph_def, output_names = InsertPrintMinMaxNode(
-sampling_graph_def, i[0], i[-1]).do_transformation()
+sampling_graph_def, i[0], i[-1], self.new_api).do_transformation()
output_tensor_names.extend(output_names)


@@ -763,12 +770,9 @@ def _insert_qdq_pairs(self):
self._inference(self._sampling_model)
self._calibration_data = Helper.gen_valid_sampling_log(tmp_dump_file)

-self._tmp_graph_def.library.CopyFrom(self.model.graph_def.library)
-self._tmp_model.graph_def = self._tmp_graph_def

# Insert QDQ pattern
self._tmp_graph_def = GenerateGraphWithQDQPattern(
-self._tmp_model, self._calibration_data, self.op_wise_config,
+self._tmp_graph_def, self._calibration_data, self.op_wise_config,
self.fake_quant, self.fp32_ops, self.bf16_ops, self.quantized_node_info,
self.device, self.performance_only, self.itex_mode).do_transformation()

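With this change GenerateGraphWithQDQPattern consumes the working GraphDef (self._tmp_graph_def) directly rather than the self._tmp_model wrapper, which also makes the two removed lines that synced the model's graph_def unnecessary. Conceptually, the pass wraps each quantizable tensor in a QuantizeV2/Dequantize pair parameterized by the calibrated ranges; an eager-mode sketch of one such pair (the qdq helper and the qint8/SCALED settings are illustrative assumptions, not the transformer's exact configuration):

```python
import tensorflow as tf

def qdq(tensor, calib_min, calib_max):
    """Quantize-then-dequantize one tensor with calibrated ranges (per-tensor int8)."""
    q, out_min, out_max = tf.quantization.quantize(
        tensor, calib_min, calib_max, tf.qint8, mode="SCALED")
    return tf.quantization.dequantize(q, out_min, out_max, mode="SCALED")
```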
@@ -819,6 +823,7 @@ def _convert_qdq(self):
self.performance_only,
self.itex_mode).do_transform()
self.exclude_node_names=exclude_node_names

if len(self._calibration_data) > 0:
self._freeze_requantization_ranges(self._kl_op_dict)
self._fuse_requantize_with_fused_quantized_node()