Commit

enable quantization
zhiics committed May 13, 2019
1 parent 1a42825 commit 0782e4c
Showing 2 changed files with 119 additions and 52 deletions.
79 changes: 73 additions & 6 deletions python/tvm/relay/quantize/quantize.py
@@ -269,6 +269,77 @@ def realize(graph):
return _quantize.realize(graph)


def optimize(func, params=None):
""" Perform "SimplifyInference", "FoldScaleAxis", "FoldConstant", and
"CanonicalizeOps" optimization before quantization.
# TODO(zhiics) These passes are executed one by one so far. We need to
# move them to the pass manager.
Parameters
---------
func: tvm.relay.Function
The original Relay function to be optimized.
params : dict of str to tvm.NDArray
Input parameters to the graph that do not change
during inference time. Used for constant folding.
Returns
-------
ret: tvm.relay.Function
The graph after quantization
"""

opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]

cfg = _build.build_config(add_pass=opt_passes)

    if params:
        # Bind the provided parameters to the function's arguments by name so
        # that they can be folded as constants.
        name_dict = {}
for arg in func.params:
name = arg.name_hint
if name in name_dict:
name_dict[name] = None
else:
name_dict[name] = arg
bind_dict = {}
for k, v in params.items():
if k not in name_dict:
continue
arg = name_dict[k]
if arg is None:
raise ValueError("Multiple args in the function have name %s" % k)
bind_dict[arg] = _expr.const(v)
func = _expr.bind(func, bind_dict)

if "SimplifyInference" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.simplify_inference(func)

if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)

if "FoldScaleAxis" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.backward_fold_scale_axis(func)
func = _ir_pass.infer_type(func)
func = _ir_pass.forward_fold_scale_axis(func)
func = _ir_pass.fold_constant(func)

if "CanonicalizeOps" in cfg.add_pass:
func = _ir_pass.infer_type(func)
func = _ir_pass.canonicalize_ops(func)

if "FoldConstant" in cfg.add_pass:
func = _ir_pass.fold_constant(func)

return func
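
# For reference, a minimal usage sketch of the new helper (not part of this
# commit): it builds a small conv2d workload in the style of the test below and
# runs the pre-quantization passes with the weight bound as a constant. It
# assumes `optimize` is reachable as `qtz.optimize` (i.e. re-exported from
# tvm.relay.quantize) and uses a made-up `weight_np` array.

import numpy as np
import tvm
from tvm import relay
from tvm.relay import quantize as qtz

# Build a tiny conv2d workload, mirroring make_graph in the test below.
n, c, h, w = 1, 3, 224, 224
data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
func = relay.Function(relay.ir_pass.free_vars(out), out)

# Hypothetical constant weights; any float32 array of the right shape works.
weight_np = np.random.uniform(-1, 1, size=(c, c, 3, 3)).astype("float32")
params = {"conv_weight": tvm.nd.array(weight_np)}

# SimplifyInference, FoldScaleAxis, FoldConstant and CanonicalizeOps run in
# sequence; "conv_weight" is bound as a constant so FoldConstant can fold it.
opt_func = qtz.optimize(func, params)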


def quantize(graph, params=None, dataset=None):
""" The quantization procedure. Before running the three main
    procedures of quantization, "annotate", "calibrate" and "realize"
@@ -292,12 +363,8 @@ def quantize(graph, params=None, dataset=None):
ret: Function
The graph after quantization
"""
opt_passes = ["SimplifyInference",
"FoldScaleAxis",
"FoldConstant",
"CanonicalizeOps"]
with _build.build_config(add_pass=opt_passes):
graph = _build.optimize(graph, params=params)
# TODO(zhiics) Move this to the pass manager.
graph = optimize(graph, params)

graph = annotate(graph)
graph = calibrate(graph, dataset)
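
# A hedged sketch of the end-to-end flow this change enables, in the same style
# as the re-enabled test below; `func` and `params` are assumed to exist (for
# example, built as in the sketch above), and the qconfig values are
# illustrative only.

from tvm import relay
from tvm.relay import quantize as qtz

# quantize() now runs optimize() first, then annotate, calibrate and realize.
with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
                 round_for_shift=False, store_lowbit_output=False):
    qfunc = qtz.quantize(func, params)
qfunc = relay.ir_pass.infer_type(qfunc)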
92 changes: 46 additions & 46 deletions tests/python/relay/test_pass_quantize.py
@@ -47,54 +47,54 @@ def test_simulated_quantize():
assert out.args[3].checked_type == relay.ty.TensorType(tuple(), "float32")


# def test_quantize_pass():
# def quantize_weight(arr):
# maximum = np.amax(np.abs(arr.asnumpy()))
# scale = 2**math.ceil(math.log(maximum, 2))
# out = np.around(arr.asnumpy() / scale * 128).astype('int8')
# out = np.clip(out, -127, 127)
# return relay.const(out, 'int8')
#
# n, c, h, w = 1, 3, 224, 224
# def make_graph(data):
# weight = relay.var("conv_weight")
# out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
# out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out
#
# def make_qgraph(data, weight):
# out = data * relay.const(32.0)
# out = relay.round(out)
# out = relay.clip(out, a_min=-127, a_max=127)
# out = out.astype('int8')
#
# out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
# padding=(1, 1), channels=c, out_dtype='int32')
# out = out.astype('float32')
# out = relay.multiply(out, relay.const(0.00024414062))
# out = relay.Function(relay.ir_pass.free_vars(out), out)
# return out
#
# data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
# graph = make_graph(data)
# dataset, params = make_dataset(graph, 10)
#
# with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
# round_for_shift=False, store_lowbit_output=False):
# qgraph0 = qtz.quantize(graph, params)
# qgraph0 = relay.ir_pass.infer_type(qgraph0)
#
# conv_weight = quantize_weight(params['conv_weight'])
# qgraph1 = make_qgraph(data, conv_weight)
# qgraph1 = relay.ir_pass.infer_type(qgraph1)
#
# graph = relay.create_executor('graph')
# res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
# res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
# tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)
def test_quantize_pass():
def quantize_weight(arr):
maximum = np.amax(np.abs(arr.asnumpy()))
scale = 2**math.ceil(math.log(maximum, 2))
out = np.around(arr.asnumpy() / scale * 128).astype('int8')
out = np.clip(out, -127, 127)
return relay.const(out, 'int8')

n, c, h, w = 1, 3, 224, 224
def make_graph(data):
weight = relay.var("conv_weight")
out = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=c)
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out

def make_qgraph(data, weight):
out = data * relay.const(32.0)
out = relay.round(out)
out = relay.clip(out, a_min=-127, a_max=127)
out = out.astype('int8')

out = relay.nn.conv2d(out, weight, kernel_size=(3, 3),
padding=(1, 1), channels=c, out_dtype='int32')
out = out.astype('float32')
out = relay.multiply(out, relay.const(0.00024414062))
out = relay.Function(relay.ir_pass.free_vars(out), out)
return out

data = relay.var("data", relay.TensorType((n, c, h, w), "float32"))
graph = make_graph(data)
dataset, params = make_dataset(graph, 10)

with qtz.qconfig(skip_k_conv=0, global_scale=4.0,
round_for_shift=False, store_lowbit_output=False):
qgraph0 = qtz.quantize(graph, params)
qgraph0 = relay.ir_pass.infer_type(qgraph0)

conv_weight = quantize_weight(params['conv_weight'])
qgraph1 = make_qgraph(data, conv_weight)
qgraph1 = relay.ir_pass.infer_type(qgraph1)

graph = relay.create_executor('graph')
res0 = graph.evaluate(qgraph0)(dataset[0]['data'])
res1 = graph.evaluate(qgraph1)(dataset[0]['data'])
tvm.testing.assert_allclose(res0.asnumpy(), res1.asnumpy(), rtol=1e-3)


if __name__ == "__main__":
np.random.seed(42)
test_simulated_quantize()
# test_quantize_pass()
test_quantize_pass()
