From 0b64a9fbf4f9a2d5418e6f0b0a2a4b68ea531e02 Mon Sep 17 00:00:00 2001 From: ZihengJiang Date: Mon, 3 Dec 2018 17:59:08 -0800 Subject: [PATCH] [QUANTIZE] Update. --- python/tvm/relay/quantize/quantize.py | 8 +++----- python/tvm/relay/quantize/quantize_ops.py | 12 ++---------- src/relay/pass/quantize.cc | 23 ++++------------------- 3 files changed, 9 insertions(+), 34 deletions(-) diff --git a/python/tvm/relay/quantize/quantize.py b/python/tvm/relay/quantize/quantize.py index 471f901a54695..4ab62ceace8ba 100644 --- a/python/tvm/relay/quantize/quantize.py +++ b/python/tvm/relay/quantize/quantize.py @@ -22,7 +22,6 @@ # TODO: # - gpu -# - realize class QFieldKind(object): @@ -72,7 +71,6 @@ def get_current_qconfig(): def get_config_bit(kind): cfg = get_current_qconfig() return cfg.bit_dict[kind] - # return _expr.Integer(cfg.bit_dict[kind]) @register_relay_node @@ -98,9 +96,9 @@ def attach_simulated_quantize(data, kind): True, "round", kind) -def _build_module(graph, params=None): - model, lib, params = _build(graph, target='llvm', params=params) - module = _runtime.create(model, lib, tvm.cpu(0)) +def _build_module(graph, params=None, target='llvm', ctx=tvm.cpu(0)): + model, lib, params = _build(graph, target=target, params=params) + module = _runtime.create(model, lib, ctx) module.set_input(**params) return module diff --git a/python/tvm/relay/quantize/quantize_ops.py b/python/tvm/relay/quantize/quantize_ops.py index bdcc257f2908b..31f37421ec05c 100644 --- a/python/tvm/relay/quantize/quantize_ops.py +++ b/python/tvm/relay/quantize/quantize_ops.py @@ -60,7 +60,7 @@ def multiply_rewrite(ref_call, new_args, ctx): raise ValueError -#@register_qfield_rewrite("add") +@register_qfield_rewrite("add") def add_rewrite(ref_call, new_args, ctx): cfg = get_current_qconfig() if cfg.counter <= cfg.skip_k_conv: @@ -136,22 +136,14 @@ def simulated_quantize_compute(attrs, inputs, output_type, target): return [topi.identity(data)] # simulate rounding error - # data = debug_print(data, 'original_data') scaled_data = topi.divide(data, scale) - # scaled_data = debug_print(scaled_data, 'scaled_data') clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min) round_data = topi.round(clipped_data) - # round_data = debug_print(round_data, 'round_data') # recover data rdata = topi.multiply(round_data, scale) return [rdata] -def schedule_naive(attrs, outputs, target): - s = tvm.create_schedule([x.op for x in outputs]) - return s - - -_reg.register_schedule("simulated_quantize", schedule_naive) +_reg.register_schedule("simulated_quantize", _reg.schedule_injective) _reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE) diff --git a/src/relay/pass/quantize.cc b/src/relay/pass/quantize.cc index ef430854348c7..d71bf953f400c 100644 --- a/src/relay/pass/quantize.cc +++ b/src/relay/pass/quantize.cc @@ -387,38 +387,23 @@ Expr AddQStateRewrite(const Call& ref_call, } Expr ret = ForwardOp(ref_call, {ldata, rdata}); return QIntStateNode::make(ret, dom_scale, 32, dtype); - } else { - // do add in real domain for now - // QRealState lhs = new_args[0].as()->Convert2Real(); - // QRealState rhs = new_args[1].as()->Convert2Real(); - CHECK(!new_args[0]->derived_from() && !new_args[1]->derived_from()); - Expr ret = ForwardOp(ref_call, new_args); - return ret; - // return QRealStateNode::make(ret); } + return Expr(nullptr); } -//RELAY_REGISTER_OP("add") -//.set_attr("FQStateRewrite", AddQStateRewrite); +RELAY_REGISTER_OP("add") +.set_attr("FQStateRewrite", AddQStateRewrite); Expr ReluQStateRewrite(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { - return Expr(nullptr); CHECK_EQ(new_args.size(), 1); if (const auto* n = new_args[0].as()) { Expr ret = ForwardOp(ref_call, {n->data}); return QIntStateNode::make(ret, n->dom_scale, n->safe_nbit, n->dtype); - // } else if (const auto* n = new_args[0].as()){ - // Expr ret = ForwardOp(ref_call, {n->data}); - // return QRealStateNode::make(ret); - } else { - CHECK(!new_args[0]->derived_from()); - return new_args[0]; - // LOG(FATAL) << "wrong"; } - return new_args[0]; + return Expr(nullptr); } RELAY_REGISTER_OP("nn.relu")