Skip to content

Commit

Permalink
[QUANTIZE] Update.
Browse files Browse the repository at this point in the history
  • Loading branch information
ZihengJiang committed Dec 4, 2018
1 parent d92f41e commit 0b64a9f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 34 deletions.
8 changes: 3 additions & 5 deletions python/tvm/relay/quantize/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@

# TODO:
# - gpu
# - realize


class QFieldKind(object):
Expand Down Expand Up @@ -72,7 +71,6 @@ def get_current_qconfig():
def get_config_bit(kind):
cfg = get_current_qconfig()
return cfg.bit_dict[kind]
# return _expr.Integer(cfg.bit_dict[kind])


@register_relay_node
Expand All @@ -98,9 +96,9 @@ def attach_simulated_quantize(data, kind):
True, "round", kind)


def _build_module(graph, params=None):
model, lib, params = _build(graph, target='llvm', params=params)
module = _runtime.create(model, lib, tvm.cpu(0))
def _build_module(graph, params=None, target='llvm', ctx=tvm.cpu(0)):
model, lib, params = _build(graph, target=target, params=params)
module = _runtime.create(model, lib, ctx)
module.set_input(**params)
return module

Expand Down
12 changes: 2 additions & 10 deletions python/tvm/relay/quantize/quantize_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def multiply_rewrite(ref_call, new_args, ctx):
raise ValueError


#@register_qfield_rewrite("add")
@register_qfield_rewrite("add")
def add_rewrite(ref_call, new_args, ctx):
cfg = get_current_qconfig()
if cfg.counter <= cfg.skip_k_conv:
Expand Down Expand Up @@ -136,22 +136,14 @@ def simulated_quantize_compute(attrs, inputs, output_type, target):
return [topi.identity(data)]

# simulate rounding error
# data = debug_print(data, 'original_data')
scaled_data = topi.divide(data, scale)
# scaled_data = debug_print(scaled_data, 'scaled_data')
clipped_data = topi.maximum(topi.minimum(scaled_data, clip_max), clip_min)
round_data = topi.round(clipped_data)
# round_data = debug_print(round_data, 'round_data')

# recover data
rdata = topi.multiply(round_data, scale)
return [rdata]


def schedule_naive(attrs, outputs, target):
s = tvm.create_schedule([x.op for x in outputs])
return s


_reg.register_schedule("simulated_quantize", schedule_naive)
_reg.register_schedule("simulated_quantize", _reg.schedule_injective)
_reg.register_pattern("simulated_quantize", _reg.OpPattern.OPAQUE)
23 changes: 4 additions & 19 deletions src/relay/pass/quantize.cc
Original file line number Diff line number Diff line change
Expand Up @@ -387,38 +387,23 @@ Expr AddQStateRewrite(const Call& ref_call,
}
Expr ret = ForwardOp(ref_call, {ldata, rdata});
return QIntStateNode::make(ret, dom_scale, 32, dtype);
} else {
// do add in real domain for now
// QRealState lhs = new_args[0].as<QStateNode>()->Convert2Real();
// QRealState rhs = new_args[1].as<QStateNode>()->Convert2Real();
CHECK(!new_args[0]->derived_from<TempExprNode>() && !new_args[1]->derived_from<TempExprNode>());
Expr ret = ForwardOp(ref_call, new_args);
return ret;
// return QRealStateNode::make(ret);
}
return Expr(nullptr);
}

//RELAY_REGISTER_OP("add")
//.set_attr<FForwardRewrite>("FQStateRewrite", AddQStateRewrite);
RELAY_REGISTER_OP("add")
.set_attr<FForwardRewrite>("FQStateRewrite", AddQStateRewrite);


Expr ReluQStateRewrite(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
return Expr(nullptr);
CHECK_EQ(new_args.size(), 1);
if (const auto* n = new_args[0].as<QIntStateNode>()) {
Expr ret = ForwardOp(ref_call, {n->data});
return QIntStateNode::make(ret, n->dom_scale, n->safe_nbit, n->dtype);
// } else if (const auto* n = new_args[0].as<QRealStateNode>()){
// Expr ret = ForwardOp(ref_call, {n->data});
// return QRealStateNode::make(ret);
} else {
CHECK(!new_args[0]->derived_from<TempExprNode>());
return new_args[0];
// LOG(FATAL) << "wrong";
}
return new_args[0];
return Expr(nullptr);
}

RELAY_REGISTER_OP("nn.relu")
Expand Down

0 comments on commit 0b64a9f

Please sign in to comment.