[QUANTIZE] Add nn.batch_flatten as quantizable. (#5805)
* [ONNX] Skip ADD inside Gemm op when vector is zero

* [QUANTIZE] Add nn.batch_flatten as quantizable.
cbalint13 authored Jun 21, 2020
1 parent 2dcfd61 commit 7902d0f
Showing 3 changed files with 44 additions and 9 deletions.
25 changes: 17 additions & 8 deletions python/tvm/relay/quantize/_partition.py
@@ -121,6 +121,21 @@ def add_partition_generic(ref_call, new_args, ctx):

    raise ValueError

+def mul_partition_generic(ref_call, new_args, ctx):
+    """Rewrite function for ewise mul for partition for generic devices"""
+    lhs_cond, lhs = partition_expr_check(new_args[0])
+    rhs_cond, rhs = partition_expr_check(new_args[1])
+
+    if lhs_cond:
+        # introduced by bn: multiply(out, scale)
+        return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
+
+    if not lhs_cond and not rhs_cond:
+        # trivial case
+        return None
+
+    raise ValueError


# TODO(ziheng) enhance `register_partition_function` to dispatch
# for target automatically
@@ -136,11 +151,5 @@ def add_partition_function(ref_call, new_args, ctx):

@register_partition_function("multiply")
def multiply_partition_function(ref_call, new_args, ctx):
"""Rewrite function for ewise add for partition"""
lhs_cond, lhs = partition_expr_check(new_args[0])
rhs_cond, rhs = partition_expr_check(new_args[1])
if lhs_cond:
# introduced by bn: multiply(out, scale)
return QPartitionExpr(_forward_op(ref_call, [lhs, rhs]))
assert (not lhs_cond) and (not rhs_cond)
return None
"""Rewrite function for ewise multiply for partition"""
return mul_partition_generic(ref_call, new_args, ctx)
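For context, the lhs_cond branch above covers the multiply that batch-norm folding introduces: a quantizable conv output scaled by a constant. A minimal hypothetical sketch of that pattern going through the quantizer (not part of this commit; shapes and variable names are illustrative):

import numpy as np
from tvm import relay
from tvm.relay import testing

# conv2d output (lhs, produced inside a quantized partition) multiplied by a
# constant scale (rhs), i.e. the "introduced by bn" case the helper wraps in
# a QPartitionExpr.
data = relay.var("data", shape=(1, 16, 32, 32), dtype="float32")
weight = relay.var("weight", shape=(16, 16, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, kernel_size=(3, 3), padding=(1, 1), channels=16)
scale = relay.const(np.ones((16, 1, 1), dtype="float32"))
out = relay.multiply(conv, scale)

func = relay.Function(relay.analysis.free_vars(out), out)
mod, params = testing.create_workload(func)
with relay.quantize.qconfig(skip_conv_layers=[]):
    qmod = relay.quantize.quantize(mod, params)

Of the remaining branches, the trivial case (neither operand quantized) skips the rewrite by returning None, and any other combination is unexpected and raises ValueError.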
5 changes: 4 additions & 1 deletion src/relay/quantize/realize.cc
@@ -272,7 +272,7 @@ Expr MulRealize(const Call& ref_call, const Array<Expr>& new_args, const ObjectR
    Expr dom_scale = FoldConstantOpt(mul);
    return QRealizeIntExpr(ret, dom_scale, dtype);
  }
-  CHECK(!new_args[0]->IsInstance<TempExprNode>() && !new_args[1]->IsInstance<TempExprNode>());
+  CHECK(!new_args[0]->IsInstance<TempExprNode>() || !new_args[1]->IsInstance<TempExprNode>());
  return Expr(nullptr);
}
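Note on the relaxed CHECK: the old condition aborted if either multiply operand was still a TempExpr; the new one aborts only when both are. With exactly one unrealized operand, MulRealize now returns a null Expr and leaves the call for a later rewrite, a case that seemingly becomes reachable once nn.batch_flatten (registered below) participates in quantization.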

@@ -418,6 +418,9 @@ RELAY_REGISTER_OP("nn.relu").set_attr<FForwardRewrite>("FQRealizeRewrite", Ident

RELAY_REGISTER_OP("strided_slice").set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("nn.batch_flatten")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

RELAY_REGISTER_OP("annotation.stop_fusion")
.set_attr<FForwardRewrite>("FQRealizeRewrite", IdentityRealize);

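IdentityRealize, the same rewrite already used for nn.relu and strided_slice above, forwards an op over its quantized argument and leaves the domain scale and dtype untouched; nn.batch_flatten qualifies because it only reshapes its input. A quick hypothetical check from Python that the attribute landed (assuming a build containing this patch and the tvm.ir.Op attribute API):

from tvm import relay

# nn.batch_flatten should now carry the FQRealizeRewrite attribute that the
# realize pass dispatches on.
op = relay.op.get("nn.batch_flatten")
assert op.get_attr("FQRealizeRewrite") is not None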
23 changes: 23 additions & 0 deletions tests/python/relay/test_pass_auto_quantize.py
@@ -21,6 +21,7 @@
from tvm import te
from tvm import relay
from tvm.relay import testing
+from tvm.relay.expr import Call


def quantize_and_build(out):
@@ -32,6 +33,7 @@ def quantize_and_build(out):

    relay.build(qmod, "llvm", params=params)

+    return qmod
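Returning the quantized module here lets callers inspect the rewritten IR; the batch_flatten test below depends on this rather than only checking that the build succeeds.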

def test_mul_rewrite():
"""a test case where rhs of mul is not constant"""
@@ -49,6 +51,26 @@ def test_mul_rewrite():

    quantize_and_build(act * pool)

+def test_batch_flatten_rewrite():
+
+    data = relay.var("data", shape=(1, 16, 64, 64), dtype="float32")
+
+    out = relay.nn.conv2d(data, relay.var("weight"),
+                          kernel_size=(3, 3),
+                          padding=(1, 1),
+                          channels=16)
+
+    out = relay.nn.batch_flatten(out)
+
+    qmod = quantize_and_build(out)
+
+    def _check_batch_flatten(node):
+        if isinstance(node, Call):
+            if node.op.name == "nn.batch_flatten":
+                assert node.checked_type.dtype == "int8"
+
+    # check if batch_flatten is quantized
+    relay.analysis.post_order_visit(qmod["main"], _check_batch_flatten)
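relay.analysis.post_order_visit invokes the callback on every node of the quantized main function, so the assert covers each nn.batch_flatten call; without the FQRealizeRewrite registration above, the op's checked type would remain float32 and the dtype check would fail.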

def get_calibration_dataset(input_name):
    dataset = []
@@ -83,6 +105,7 @@ def test_calibrate_memory_bound():

if __name__ == "__main__":
    test_mul_rewrite()
+    test_batch_flatten_rewrite()
    test_calibrate_target(False)
    test_calibrate_target(True)
    test_calibrate_memory_bound()
