diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py index 9d679d206508..ab98f3c369ab 100644 --- a/python/tvm/relay/quantize/_annotate.py +++ b/python/tvm/relay/quantize/_annotate.py @@ -214,8 +214,10 @@ def multiply_rewrite(ref_call, new_args, ctx): # quantize lhs to INPUT field if lhs_kind == QAnnotateKind.ACTIVATION: lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT) - # quantize rhs to WEIGHT field - rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + if _analysis.check_constant(rhs_expr): + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT) + else: + rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT) expr = _forward_op(ref_call, [lhs_expr, rhs_expr]) return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION) diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc index 4cf84f43e90b..773551a4c690 100644 --- a/src/relay/pass/quantize/realize.cc +++ b/src/relay/pass/quantize/realize.cc @@ -278,13 +278,9 @@ Expr MulRealize(const Call& ref_call, DataType dtype = cfg->dtype_activation; if (lhs->dtype != dtype) { ldata = Cast(ldata, dtype); - } else { - CHECK_EQ(lhs->dtype, dtype); } if (rhs->dtype != dtype) { rdata = Cast(rdata, dtype); - } else { - CHECK_EQ(rhs->dtype, dtype); } Expr ret = ForwardOp(ref_call, {ldata, rdata}); @@ -499,6 +495,9 @@ Expr AvgPoolRealize(const Call& ref_call, RELAY_REGISTER_OP("nn.avg_pool2d") .set_attr("FQRealizeRewrite", AvgPoolRealize); +RELAY_REGISTER_OP("nn.global_avg_pool2d") +.set_attr("FQRealizeRewrite", AvgPoolRealize); + Expr CastHintRealize(const Call& ref_call, const Array& new_args, const NodeRef& ctx) { diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py new file mode 100644 index 000000000000..e4aa36bf9f70 --- /dev/null +++ b/tests/python/relay/test_pass_auto_quantize.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import relay +from tvm.relay import testing + + +def quantize_and_build(out): + f = relay.Function(relay.analysis.free_vars(out), out) + mod, params = testing.create_workload(f) + + with relay.quantize.qconfig(skip_conv_layers=[]): + qmod = relay.quantize.quantize(mod, params) + + relay.build(qmod, "llvm", params=params) + + +def test_mul_rewrite(): + """a test case where rhs of mul is not constant""" + data = relay.var("data", shape=(1, 16, 64, 64)) + multiplier = relay.sigmoid(relay.var("data", shape=(1, 16, 1, 1))) + conv = relay.nn.conv2d(data, relay.var("weight"), + kernel_size=(3, 3), + padding=(1, 1), + channels=16) + act = relay.nn.relu(data=conv) + + quantize_and_build(act * multiplier) + + pool = relay.nn.global_avg_pool2d(data=act) + + quantize_and_build(act * pool) + +if __name__ == "__main__": + test_mul_rewrite()