Skip to content

Commit

Permalink
Zero point optimization (#6683)
Browse files Browse the repository at this point in the history
* Zero point optimization

* Expand the equality to zero criteria
  • Loading branch information
Evgenya Stepyreva authored Jul 21, 2021
1 parent 92fdda5 commit a3825ba
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 20 deletions.
49 changes: 30 additions & 19 deletions model-optimizer/extensions/back/compress_quantized_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import numpy as np

from extensions.ops.Cast import Cast
from extensions.ops.elementwise import Sub, Div, Mul, Negative
from extensions.ops.elementwise import Sub, Div, Mul, Negative, Equal
from extensions.ops.select import Select
from mo.back.replacement import BackReplacementPattern
from mo.graph.graph import Graph, Node
from mo.middle.passes.convert_data_type import data_type_str_to_np, np_data_type_to_destination_type, packed_I4
Expand Down Expand Up @@ -70,15 +71,7 @@ class CompressQuantizeWeights(BackReplacementPattern):
scale = (output_high - output_low) / (input_high - input_low)
WARNING: division by zero imposes restriction -- input_high can not be equal to input_low
zero_point = input_low - output_low / scale
TODO: steps 5 and 6 are NOT IMPLEMENTED YET
TODO: DOES LPT NEED IT???
Step 5: Having zero_point == 0 is really beneficial for performance, so we try to fuse Subtract up to the Constant.
It is not always possible because of the quantized_dtype possible range of values.
Step 6: (Optional) From the nature of Subtract and Multiply operations they may be optimized out in cases:
zero_point == 0
scale == 1
NOTE: if scale == 0 then zero_point is equal to zero too (achieved through Select operation)
BENEFITS:
Such constant data packing reduces IR size (.bin file size)
Expand Down Expand Up @@ -186,14 +179,24 @@ def dequantize_data(fake_quantize: Node, dst_type: type, quantized_type: type) -
descaled_output_low.in_port(0).connect(out_low)
descaled_output_low.in_port(1).connect(scale.out_port(0))

shift = Sub(graph, {'name': name + '/zero_point'}).create_node()
shift = Sub(graph, {'name': name + '/shift'}).create_node()
shift.in_port(0).connect(in_low)
shift.in_port(1).connect(descaled_output_low.out_port(0))

zero = Const(graph, {'name': name + '/zero', 'value': np.array(0, dtype=dst_type)}).create_node()
scale_eq_zero = Equal(graph, {'name': name + '/scale_eq_zero'}).create_node()
scale_eq_zero.in_port(0).connect(scale.out_port(0))
scale_eq_zero.in_port(1).connect(zero.out_port(0))

zero_point = Select(graph, {'name': name + '/zero_point'}).create_node()
zero_point.in_port(0).connect(scale_eq_zero.out_port(0))
zero_point.in_port(1).connect(zero.out_port(0))
zero_point.in_port(2).connect(shift.out_port(0))

# DeQuantize(x) == Mul(Sub(x, zero_point), scale)
sub_zp = Sub(graph, {'name': name + '/minus_zp'}).create_node()
sub_zp.in_port(0).connect(dequantizing_cast.out_port(0))
sub_zp.in_port(1).connect(shift.out_port(0))
sub_zp.in_port(1).connect(zero_point.out_port(0))

mul_scale = Mul(graph, {'name': name + '/mulpiply_by_scale'}).create_node()
mul_scale.in_port(0).connect(sub_zp.out_port(0))
Expand Down Expand Up @@ -221,6 +224,12 @@ def replace_pattern(self, graph: Graph, match: Dict[str, Node]):


class ZeroPointOptimizer(BackReplacementPattern):
r"""
Step 1: Having zero_point == 0 is really beneficial for performance, so we try to fuse Subtract up to the Constant.
    It is not always possible because of the possible range of values of quantized_dtype.
    Step 2: From the nature of the Subtract operation, it may be optimized out when zero_point == 0.
"""
enabled = True
force_clean_up = True

Expand Down Expand Up @@ -249,25 +258,27 @@ def pattern(self):
)

def replace_pattern(self, graph: Graph, match: Dict[str, Node]):
zero_point = match['const_zp'].out_port(0).data.get_value()
assert zero_point is not None
convert = match['convert']
sub = match['sub']
zero_point = sub.in_port(1).data.get_value()
if zero_point is None or np.allclose(zero_point, 0):
if np.allclose(zero_point, 0):
sub.out_port(0).get_connection().set_source(convert.out_port(0))
return

convert = match['convert']
dst_type = convert.dst_type
weights = convert.in_port(0).data.get_value()
weights = match['const'].out_port(0).data.get_value()
if weights is None or weights.dtype != np.int8:
return
dst_type = convert.dst_type

int8_zero_point = np.round(zero_point).astype(np.int8)
adj_zero_point = (zero_point - int8_zero_point).astype(dst_type)

original = weights.astype(dst_type) - zero_point
transformed = (weights - int8_zero_point).astype(np.int8) - adj_zero_point

if not np.allclose(original, transformed) or not np.allclose(adj_zero_point, 0):
if not np.allclose(original, transformed) or not np.allclose(adj_zero_point, 0, atol=1.e-04):
return

match['const_d']['value'] = (weights - int8_zero_point).astype(np.int8)
match['const_zp_d']['value'] = np.zeros(adj_zero_point.shape, dst_type)
sub.out_port(0).get_connection().set_source(convert.out_port(0))
Original file line number Diff line number Diff line change
Expand Up @@ -254,10 +254,42 @@ class ZeroPointOptimizerTestClass(unittest.TestCase):
@generate(*[
([-10, 7], [-1], [-9, 8], [0]),
([-10, 7], [-0.99999999], [-9, 8], [0]),
])
def test_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
nodes = lambda w, zp: {
**valued_const_with_data('weights', np.array(w, dtype=np.int8)),
**regular_op_with_shaped_data(
'cast', len(w), {'type': 'Convert', 'op': 'Cast', 'infer': Cast.infer, 'dst_type': np.float32}),
**valued_const_with_data('zp', np.array(zp, dtype=np.float32)),
**regular_op_with_shaped_data(
'sub', len(w),
{'type': 'Subtract', 'op': 'Sub', 'infer': lambda node: eltwise_infer(node, Sub.operation)}),
**result()
}
edges = [
*connect("weights:0", "0:cast"),
*connect("cast:0", "0:sub"),
*connect("zp:0", "1:sub"),
*connect("sub:0", "0:output"),
]
graph = build_graph(nodes(weights, zero_point), edges, nodes_with_edges_only=True)
ZeroPointOptimizer().find_and_replace_pattern(graph)
graph.clean_up()

graph_ref = build_graph(nodes(adj_weights, adj_zero_point), [
*connect("weights:0", "0:cast"),
*connect("cast:0", "0:output"),
], nodes_with_edges_only=True)
graph_ref.clean_up()

(flag, resp) = compare_graphs(graph, graph_ref, 'output', check_op_attrs=True)
self.assertTrue(flag, resp)

@generate(*[
([-128, 7], [1], [-128, 7], [1]),
([127, 7], [-1], [127, 7], [-1]),
])
def test_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
def test_negative_zero_point_optimization(self, weights, zero_point, adj_weights, adj_zero_point):
nodes = lambda w, zp: {
**valued_const_with_data('weights', np.array(w, dtype=np.int8)),
**regular_op_with_shaped_data(
Expand Down

0 comments on commit a3825ba

Please sign in to comment.