diff --git a/model-optimizer/extensions/back/compress_quantized_weights.py b/model-optimizer/extensions/back/compress_quantized_weights.py
index 62799acc1d1c39..98fbd57f4fd7b2 100644
--- a/model-optimizer/extensions/back/compress_quantized_weights.py
+++ b/model-optimizer/extensions/back/compress_quantized_weights.py
@@ -6,7 +6,8 @@
 import numpy as np
 
 from extensions.ops.Cast import Cast
-from extensions.ops.elementwise import Sub, Div, Mul, Negative
+from extensions.ops.elementwise import Sub, Div, Mul, Negative, Equal
+from extensions.ops.select import Select
 from mo.back.replacement import BackReplacementPattern
 from mo.graph.graph import Graph, Node
 from mo.middle.passes.convert_data_type import data_type_str_to_np, np_data_type_to_destination_type, packed_I4
@@ -70,15 +71,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
         scale = (output_high - output_low) / (input_high - input_low)
             WARNING: division by zero imposes restriction -- input_high can not be equal to input_low
         zero_point = input_low - output_low / scale
-
-    TODO: steps 5 and 6 are NOT IMPLEMENTED YET
-    TODO: DOES LPT NEED IT???
-    Step 5: Having zero_point == 0 is really beneficial for performance, so we try to fuse Subtract up to the Constant.
-        It is not always possible because of the quantized_dtype possible range of values.
-
-    Step 6: (Optional) From the nature of Subtract and Multiply operations they may be optimized out in cases:
-        zero_point == 0
-        scale == 1
+        NOTE: if scale == 0 then zero_point is equal to zero too (achieved through the Select operation)
 
     BENEFITS:
         Such constant data packing reduces IR size (.bin file size)
@@ -186,14 +179,24 @@ def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -
     descaled_output_low.in_port(0).connect(out_low)
     descaled_output_low.in_port(1).connect(scale.out_port(0))
 
-    shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
+    shift = Sub(graph, {'name': name + '/shift'}).create_node()
     shift.in_port(0).connect(in_low)
     shift.in_port(1).connect(descaled_output_low.out_port(0))
 
+    zero = Const(graph, {'name': name + '/zero', 'value': np.array(0, dtype=dst_type)}).create_node()
+    scale_eq_zero = Equal(graph, {'name': name + '/scale_eq_zero'}).create_node()
+    scale_eq_zero.in_port(0).connect(scale.out_port(0))
+    scale_eq_zero.in_port(1).connect(zero.out_port(0))
+
+    zero_point = Select(graph, {'name': name + '/zero_point'}).create_node()
+    zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
+    zero_point.in_port(1).connect(zero.out_port(0))
+    zero_point.in_port(2).connect(shift.out_port(0))
+
     # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
     sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
     sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
-    sub_zp.in_port(1).connect(shift.out_port(0))
+    sub_zp.in_port(1).connect(zero_point.out_port(0))
 
     mul_scale = Mul(graph, {'name': name + '/mulpiply_by_scale'}).create_node()
     mul_scale.in_port(0).connect(sub_zp.out_port(0))
@@ -221,6 +224,12 @@ def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
 
 
 class ZeroPointOptimizer(BackReplacementPattern):
+    r"""
+    Step 1: Having zero_point == 0 is really beneficial for performance, so we try to fuse Subtract up to the Constant.
+        It is not always possible because of the quantized_dtype possible range of values.
+
+    Step 2: From the nature of Subtract operation it may be optimized out if zero_point == 0
+    """
     enabled = True
     force_clean_up = True
 
@@ -249,16 +258,18 @@ def pattern(self):
         )
 
     def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
+        zero_point = match['const_zp'].out_port(0).data.get_value()
+        assert zero_point is not None
+        convert = match['convert']
         sub = match['sub']
-        zero_point = sub.in_port(1).data.get_value()
-        if zero_point is None or np.allclose(zero_point, 0):
+        if np.allclose(zero_point, 0):
+            sub.out_port(0).get_connection().set_source(convert.out_port(0))
             return
 
-        convert = match['convert']
-        dst_type = convert.dst_type
-        weights = convert.in_port(0).data.get_value()
+        weights = match['const'].out_port(0).data.get_value()
         if weights is None or weights.dtype != np.int8:
             return
+        dst_type = convert.dst_type
 
         int8_zero_point = np.round(zero_point).astype(np.int8)
         adj_zero_point = (zero_point - int8_zero_point).astype(dst_type)
@@ -266,8 +277,8 @@ def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
         original = weights.astype(dst_type) - zero_point
         transformed = (weights - int8_zero_point).astype(np.int8) - adj_zero_point
-        if not np.allclose(original, transformed) or not np.allclose(adj_zero_point, 0):
+        if not np.allclose(original, transformed) or not np.allclose(adj_zero_point, 0, atol=1.e-04):
             return
 
         match['const_d']['value'] = (weights - int8_zero_point).astype(np.int8)
-        match['const_zp_d']['value'] = np.zeros(adj_zero_point.shape, dst_type)
+        sub.out_port(0).get_connection().set_source(convert.out_port(0))
 
diff --git a/model-optimizer/unit_tests/extensions/back/compress_quantized_weights_test.py b/model-optimizer/unit_tests/extensions/back/compress_quantized_weights_test.py
index 5e4aa87b525883..45d977beb55da0 100644
--- a/model-optimizer/unit_tests/extensions/back/compress_quantized_weights_test.py
+++ b/model-optimizer/unit_tests/extensions/back/compress_quantized_weights_test.py
@@ -254,10 +254,42 @@ class ZeroPointOptimizerTestClass(unittest.TestCase):
     @generate(*[
         ([-10, 7], [-1], [-9, 8], [0]),
         ([-10, 7], [-0.99999999], [-9, 8], [0]),
+    ])
+    def test_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
+        nodes = lambda w, zp: {
+            **valued_const_with_data('weights', np.array(w, dtype=np.int8)),
+            **regular_op_with_shaped_data(
+                'cast', len(w), {'type': 'Convert', 'op': 'Cast', 'infer': Cast.infer, 'dst_type': np.float32}),
+            **valued_const_with_data('zp', np.array(zp, dtype=np.float32)),
+            **regular_op_with_shaped_data(
+                'sub', len(w),
+                {'type': 'Subtract', 'op': 'Sub', 'infer': lambda node: eltwise_infer(node, Sub.operation)}),
+            **result()
+        }
+        edges = [
+            *connect("weights:0", "0:cast"),
+            *connect("cast:0", "0:sub"),
+            *connect("zp:0", "1:sub"),
+            *connect("sub:0", "0:output"),
+        ]
+        graph = build_graph(nodes(weights, zero_point), edges, nodes_with_edges_only=True)
+        ZeroPointOptimizer().find_and_replace_pattern(graph)
+        graph.clean_up()
+
+        graph_ref = build_graph(nodes(adj_weights, adj_zero_point), [
+            *connect("weights:0", "0:cast"),
+            *connect("cast:0", "0:output"),
+        ], nodes_with_edges_only=True)
+        graph_ref.clean_up()
+
+        (flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
+        self.assertTrue(flag, resp)
+
+    @generate(*[
         ([-128, 7], [1], [-128, 7], [1]),
         ([127, 7], [-1], [127, 7], [-1]),
     ])
-    def test_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
+    def test_negative_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
         nodes = lambda w, zp: {
             **valued_const_with_data('weights', np.array(w, dtype=np.int8)),
             **regular_op_with_shaped_data(
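
For reference, the dequantization arithmetic that the patch builds as graph nodes (scale, zero_point, and the new Equal + Select guard for scale == 0) can be sketched in plain NumPy. This is only an illustration of the formulas from the docstring, not code from the patch; the function names are made up for the example and the inputs are assumed to be the four FakeQuantize limits as NumPy-compatible values.

import numpy as np

def dequantization_params(input_low, input_high, output_low, output_high):
    input_low = np.asarray(input_low, dtype=np.float32)
    input_high = np.asarray(input_high, dtype=np.float32)
    output_low = np.asarray(output_low, dtype=np.float32)
    output_high = np.asarray(output_high, dtype=np.float32)

    # scale = (output_high - output_low) / (input_high - input_low)
    scale = (output_high - output_low) / (input_high - input_low)

    # zero_point = input_low - output_low / scale, guarded like the Equal + Select
    # pair in dequantize_data: wherever scale == 0 the division is undefined,
    # so zero_point is forced to 0 there.
    with np.errstate(divide='ignore', invalid='ignore'):
        shift = input_low - output_low / scale
    zero_point = np.where(scale == 0, np.zeros_like(shift), shift)
    return scale, zero_point

def dequantize(quantized, scale, zero_point):
    # DeQuantize(x) == Mul(Sub(x, zero_point), scale)
    return (quantized - zero_point) * scale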
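
The ZeroPointOptimizer change can be checked numerically as well: it folds a non-zero zero_point into the int8 weight constant and reroutes the Subtract's consumers to the Convert output only when the folded weights stay numerically equivalent and the residual zero point is ~0 (within atol). Below is a sketch of that check with a hypothetical helper name, not the transformation itself; the int8 subtraction deliberately wraps on overflow, which is exactly what the np.allclose comparison catches for values like -128.

import numpy as np

def try_fold_zero_point(weights, zero_point, dst_type=np.float32, atol=1.e-4):
    """Return the folded int8 weights if the Subtract can be dropped, otherwise None."""
    weights = np.asarray(weights, dtype=np.int8)
    zero_point = np.asarray(zero_point, dtype=dst_type)

    if np.allclose(zero_point, 0):
        return weights  # Subtract is already a no-op, just reroute around it

    int8_zero_point = np.round(zero_point).astype(np.int8)
    adj_zero_point = (zero_point - int8_zero_point).astype(dst_type)

    original = weights.astype(dst_type) - zero_point
    transformed = (weights - int8_zero_point).astype(np.int8) - adj_zero_point

    if not np.allclose(original, transformed) or not np.allclose(adj_zero_point, 0, atol=atol):
        return None  # folding would change the result (e.g. int8 overflow), keep the Sub
    return (weights - int8_zero_point).astype(np.int8)

# Mirrors the unit tests above: [-10, 7] with zero_point -1 folds to [-9, 8],
# while [-128, 7] with zero_point 1 would overflow int8 and is left untouched.
assert np.array_equal(try_fold_zero_point([-10, 7], [-1]), [-9, 8])
assert try_fold_zero_point([-128, 7], [1]) is None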